LoginSignup
0
0

AVX2で32bit整数/float要素のトップ3を調べる

Last updated at Posted at 2023-11-09

要素のtop-3をSIMDで高速に調べたかったのでちょっと書いてみました。本当はfloat版が必要なのですが、とりあえず32bit整数版から。

top3.c
// SPDX-License-Identifier: Apache-2.0
#include <limits.h>
#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>

/*
 * Debug helper: print `prefix`, then the 256-bit vector as sixteen 16-bit
 * hex fields (element 0 first), each followed by a space, then a newline.
 */
void print_m256i(const char *prefix, __m256i vector) {
    uint16_t halves[16] __attribute__((aligned(32)));
    _mm256_store_si256((__m256i *)halves, vector); /* buffer is 32-byte aligned */

    fputs(prefix, stdout);
    for (int k = 0; k < 16; ++k) {
        printf("%.4x ", (unsigned)halves[k]);
    }
    putchar('\n');
}

// https://stackoverflow.com/questions/21622212/how-to-perform-the-inverse-of-mm256-movemask-epi8-vpmovmskb
// Copyright: Satya Arjunan
// License: CC-BY-SA 3.0
__m256i get_mask3(const uint32_t mask) {
    /*
     * Inverse of _mm256_movemask_epi8: byte j of the result is 0xFF when bit j
     * of `mask` is set and 0x00 otherwise.
     *
     * Step 1: copy mask byte (j / 8) into result byte j. vpshufb indexes
     * within each 128-bit half; both halves hold the same four mask bytes
     * (set1_epi32), so indices 0..3 select them in either half.
     */
    const __m256i byte_route = _mm256_setr_epi64x(0x0000000000000000,
        0x0101010101010101, 0x0202020202020202, 0x0303030303030303);
    const __m256i spread = _mm256_shuffle_epi8(_mm256_set1_epi32(mask), byte_route);

    /* Step 2: byte j keeps only bit (j % 8); it is all-ones iff that bit was set. */
    const __m256i bit_select = _mm256_set1_epi64x((long long)0x8040201008040201ULL);
    return _mm256_cmpeq_epi8(_mm256_and_si256(spread, bit_select), bit_select);
}

// https://stackoverflow.com/questions/48726032/using-a-variable-to-index-a-simd-vector-with-mm256-extract-epi32-intrinsic
// Copyright: wim
// License: CC-BY-SA 3.0
uint32_t mm256_extract_epi32_var_indx(const __m256i vec, const unsigned int i) {
    /* Return 32-bit element i of `vec` (vpermd uses the low 3 bits of i).
     * Every permute index equals i, so element 0 of the result is vec[i]. */
    const __m256i moved = _mm256_permutevar8x32_epi32(vec, _mm256_set1_epi32((int)i));
    return (uint32_t)_mm_cvtsi128_si32(_mm256_castsi256_si128(moved));
}


/*
 * Find the three largest int32 elements among `len` rows of eight lanes and
 * report their values and flat indices (row * 8 + lane), largest first.
 *
 * simdData     : `len` vectors; lane k of simdData[row] has flat index
 *                row * 8 + k.
 * len          : number of rows. If len <= 0 nothing is read; every value is
 *                set to INT_MIN and every index to -1.
 * top3_values  : receives the three largest values in descending order.
 * top3_indices : receives the matching flat indices.
 *
 * INT_MIN is used as the empty-slot sentinel. Among equal values, an earlier
 * row within a lane and a lower lane across lanes are reported first.
 * Requires AVX2.
 *
 * Phase 1 keeps a sorted top-3 per lane. Phase 2 pops the overall maximum
 * three times, advancing only the lane it came from.
 */
void top3_values_and_indices_simd(__m256i simdData[], int len, int top3_values[3], int top3_indices[3]) {
    if (len <= 0) {
        for (int k = 0; k < 3; k++) {
            top3_values[k] = INT_MIN;
            top3_indices[k] = -1;
        }
        return;
    }

    /* Phase 1: per-lane max_1st >= max_2nd >= max_3rd, plus source rows. */
    __m256i max_1st = simdData[0];
    __m256i max_2nd = _mm256_set1_epi32(INT_MIN);
    __m256i max_3rd = _mm256_set1_epi32(INT_MIN);
    __m256i max_1st_col = _mm256_setzero_si256();
    __m256i max_2nd_col = _mm256_setzero_si256();
    __m256i max_3rd_col = _mm256_setzero_si256();

    for (int i = 1; i < len; i++) {
        const __m256i x = simdData[i];
        const __m256i col = _mm256_set1_epi32(i);

        /* Stage 1: the larger of x / max_1st stays; the smaller (t) moves on.
         * Strict > keeps the earlier row ahead on ties. */
        const __m256i gt = _mm256_cmpgt_epi32(x, max_1st);
        const __m256i t = _mm256_blendv_epi8(x, max_1st, gt);
        const __m256i t_col = _mm256_blendv_epi8(col, max_1st_col, gt);
        max_1st = _mm256_blendv_epi8(max_1st, x, gt);
        max_1st_col = _mm256_blendv_epi8(max_1st_col, col, gt);

        /* Stage 2: same with t against max_2nd; the smaller (t2) moves on. */
        const __m256i gt2 = _mm256_cmpgt_epi32(t, max_2nd);
        const __m256i t2 = _mm256_blendv_epi8(t, max_2nd, gt2);
        const __m256i t2_col = _mm256_blendv_epi8(t_col, max_2nd_col, gt2);
        max_2nd = _mm256_blendv_epi8(max_2nd, t, gt2);
        max_2nd_col = _mm256_blendv_epi8(max_2nd_col, t_col, gt2);

        /* Stage 3: t2 lands in max_3rd whenever t > max_3rd. Testing t rather
         * than t2 makes a value bumped out of max_2nd always move down, even
         * when it equals max_3rd. */
        const __m256i gt3 = _mm256_cmpgt_epi32(t, max_3rd);
        max_3rd = _mm256_blendv_epi8(max_3rd, t2, gt3);
        max_3rd_col = _mm256_blendv_epi8(max_3rd_col, t2_col, gt3);
    }

    /* Phase 2: pop the overall maximum three times. */
    const __m256i lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
    for (int k = 0; k < 3; k++) {
        /* Signed horizontal max broadcast to every lane. An unsigned max would
         * rank negative values (and the INT_MIN sentinel) above positives. */
        __m256i vmax = max_1st;
        vmax = _mm256_max_epi32(vmax, _mm256_alignr_epi8(vmax, vmax, 4));
        vmax = _mm256_max_epi32(vmax, _mm256_alignr_epi8(vmax, vmax, 8));
        vmax = _mm256_max_epi32(vmax, _mm256_permute2x128_si256(vmax, vmax, 0x01));

        /* movemask_ps yields one bit per 32-bit lane; the lowest set bit is
         * the lowest lane holding the maximum (at least one lane matches). */
        const __m256i is_max = _mm256_cmpeq_epi32(max_1st, vmax);
        const int lane = __builtin_ctz((unsigned)_mm256_movemask_ps(_mm256_castsi256_ps(is_max)));
        const __m256i sel = _mm256_set1_epi32(lane);

        /* vpermd with every index == lane moves that lane to element 0. */
        top3_values[k] = _mm_cvtsi128_si32(_mm256_castsi256_si128(
            _mm256_permutevar8x32_epi32(max_1st, sel)));
        const int row = _mm_cvtsi128_si32(_mm256_castsi256_si128(
            _mm256_permutevar8x32_epi32(max_1st_col, sel)));
        top3_indices[k] = row * 8 + lane;

        /* Advance exactly one lane: its 2nd/3rd candidates move up. The mask
         * must cover 4 bytes only; an 8-byte mask (255 << ctz) would also
         * discard the neighbouring lane's maximum. */
        const __m256i pop = _mm256_cmpeq_epi32(lane_ids, sel);
        max_1st = _mm256_blendv_epi8(max_1st, max_2nd, pop);
        max_1st_col = _mm256_blendv_epi8(max_1st_col, max_2nd_col, pop);
        max_2nd = _mm256_blendv_epi8(max_2nd, max_3rd, pop);
        max_2nd_col = _mm256_blendv_epi8(max_2nd_col, max_3rd_col, pop);
    }
}

/*
 * Demo: run the top-3 search on a fixed table of 10 rows x 8 int32 lanes and
 * print the values and their flat indices (row * 8 + lane) in hex.
 */
int main() {
    /* _mm256_set_epi32 lists lanes from 7 down to 0. */
    __m256i rows[10] = {
        _mm256_set_epi32(0x5C, 0x22, 0x4C, 0x37, 0x1C, 0x3D, 0x31, 0x55),
        _mm256_set_epi32(0x13, 0x25, 0x42, 0x2F, 0x0C, 0x58, 0x34, 0x1F),
        _mm256_set_epi32(0x49, 0x1D, 0x2A, 0x45, 0x36, 0x5D, 0x17, 0x2D),
        _mm256_set_epi32(0x3F, 0x54, 0x1E, 0x3B, 0x26, 0x4F, 0x0A, 0x16),
        _mm256_set_epi32(0x50, 0x46, 0x39, 0x2C, 0x0F, 0x47, 0x5F, 0x27),
        _mm256_set_epi32(0x29, 0x44, 0x4A, 0x14, 0x51, 0x11, 0x23, 0x62),
        _mm256_set_epi32(0x41, 0x4D, 0x56, 0x1B, 0x33, 0x20, 0x0D, 0x48),
        _mm256_set_epi32(0x1A, 0x3A, 0x0E, 0x24, 0x43, 0x35, 0x59, 0x4B),
        _mm256_set_epi32(0x2B, 0x5A, 0x40, 0x15, 0x4E, 0x2E, 0x57, 0x0B),
        _mm256_set_epi32(0x10, 0x21, 0x32, 0x5E, 0x3E, 0x19, 0x30, 0x28),
    };
    const int row_count = (int)(sizeof rows / sizeof rows[0]);
    int best[3];
    int where[3];

    top3_values_and_indices_simd(rows, row_count, best, where);

    printf("Top-3 values: 0x%x, 0x%x, 0x%x\n", best[0], best[1], best[2]);
    printf("Top-3 indices: 0x%x, 0x%x, 0x%x\n", where[0], where[1], where[2]);

    return 0;
}

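$ gcc -mavx -mavx2 top3.c
$ ./a.out
Top-3 values: 0x62, 0x5e, 0x5d
Top-3 indices: 0x28, 0x4c, 0x12

※ 注: 上記の出力は正しくありません。このデータの正しいトップ3は 0x62, 0x5f, 0x5e（インデックス 0x28, 0x21, 0x4c）です。掲載時のコードの `get_mask3(255 << ctz)` は8バイト（＝2レーン）分のマスクを作ります。そのため 0x62 を取り出したときに、隣のレーン1の最大値 0x5f まで一緒に繰り上げて捨ててしまっています。1レーン（4バイト）分なら `0xFu << ctz` が正しいマスクです。また `255 << 28` は int のオーバーフローとなり、未定義動作でもあります。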

ChatGPTに頼ったらサクッと書けるかなと思ったらトンチンカンな答えばかりで結局stackoverflowに頼りまくってます。
もっと命令発行ポートとか最適化できそうな気もしますが本命はfloat版なのでとりあえずここまでで。

追記: float版も作りました。col側も__m256にしてますが、__m256iとどっちが良いかは不明です。

top3_float.c

#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>
#include <float.h>

/*
 * Debug helper: print `prefix`, then the 256-bit vector as sixteen 16-bit
 * hex fields (element 0 first), each followed by a space, then a newline.
 */
void print_m256i(const char *prefix, __m256i vector) {
    uint16_t halves[16] __attribute__((aligned(32)));
    _mm256_store_si256((__m256i *)halves, vector); /* buffer is 32-byte aligned */

    fputs(prefix, stdout);
    for (int k = 0; k < 16; ++k) {
        printf("%.4x ", (unsigned)halves[k]);
    }
    putchar('\n');
}

/*
 * Debug helper: print the eight floats of `vector` (element 0 first) with
 * four decimals, each followed by a space, then a newline.
 */
void print_m256(__m256 vector) {
    float lanes[8] __attribute__((aligned(32)));
    _mm256_store_ps(lanes, vector); /* buffer is 32-byte aligned */

    for (int k = 0; k < 8; ++k) {
        printf("%.4f ", lanes[k]);
    }
    putchar('\n');
}

// https://stackoverflow.com/questions/21622212/how-to-perform-the-inverse-of-mm256-movemask-epi8-vpmovmskb
// Copyright: Satya Arjunan
// License: CC-BY-SA 3.0
__m256i get_mask3(const uint32_t mask) {
    /*
     * Inverse of _mm256_movemask_epi8: byte j of the result is 0xFF when bit j
     * of `mask` is set and 0x00 otherwise.
     *
     * Step 1: copy mask byte (j / 8) into result byte j. vpshufb indexes
     * within each 128-bit half; both halves hold the same four mask bytes
     * (set1_epi32), so indices 0..3 select them in either half.
     */
    const __m256i byte_route = _mm256_setr_epi64x(0x0000000000000000,
        0x0101010101010101, 0x0202020202020202, 0x0303030303030303);
    const __m256i spread = _mm256_shuffle_epi8(_mm256_set1_epi32(mask), byte_route);

    /* Step 2: byte j keeps only bit (j % 8); it is all-ones iff that bit was set. */
    const __m256i bit_select = _mm256_set1_epi64x((long long)0x8040201008040201ULL);
    return _mm256_cmpeq_epi8(_mm256_and_si256(spread, bit_select), bit_select);
}

// https://stackoverflow.com/questions/48726032/using-a-variable-to-index-a-simd-vector-with-mm256-extract-epi32-intrinsic
// Copyright: wim
// License: CC-BY-SA 3.0
uint32_t mm256_extract_epi32_var_indx(const __m256i vec, const unsigned int i) {
    /* Return 32-bit element i of `vec` (vpermd uses the low 3 bits of i).
     * Every permute index equals i, so element 0 of the result is vec[i]. */
    const __m256i moved = _mm256_permutevar8x32_epi32(vec, _mm256_set1_epi32((int)i));
    return (uint32_t)_mm_cvtsi128_si32(_mm256_castsi256_si128(moved));
}

float mm256_extract_ps_var_indx(const __m256 vec, const unsigned int i) {
    /* Return float element i of `vec` (vpermps uses the low 3 bits of i).
     * Every permute index equals i, so element 0 of the result is vec[i]. */
    const __m256 moved = _mm256_permutevar8x32_ps(vec, _mm256_set1_epi32((int)i));
    return _mm_cvtss_f32(_mm256_castps256_ps128(moved));
}

/*
 * Find the three largest floats among `len` rows of eight lanes and report
 * their values and flat indices (row * 8 + lane), largest first.
 *
 * simdData     : `len` vectors; lane k of simdData[row] has flat index
 *                row * 8 + k.
 * len          : number of rows. If len <= 0 nothing is read; every value is
 *                set to -FLT_MAX and every index to -1.
 * top3_values  : receives the three largest values in descending order.
 * top3_indices : receives the matching flat indices.
 *
 * -FLT_MAX is the empty-slot sentinel. FLT_MIN must not be used: it is the
 * smallest *positive* float and would block or outrank negative data.
 * Among equal values, an earlier row within a lane and a lower lane across
 * lanes are reported first. NaN inputs are not supported. Requires AVX2.
 */
void top3_values_and_indices_simd(__m256 simdData[], int len, float top3_values[3], int top3_indices[3]) {
    if (len <= 0) {
        for (int k = 0; k < 3; k++) {
            top3_values[k] = -FLT_MAX;
            top3_indices[k] = -1;
        }
        return;
    }

    /* Phase 1: per-lane max_1st >= max_2nd >= max_3rd, plus source rows.
     * The *_col vectors hold int32 row numbers as raw bits. blendv_ps only
     * selects bits, so keeping them in __m256 is lossless. */
    __m256 max_1st = simdData[0];
    __m256 max_2nd = _mm256_set1_ps(-FLT_MAX);
    __m256 max_3rd = _mm256_set1_ps(-FLT_MAX);
    __m256 max_1st_col = _mm256_setzero_ps();
    __m256 max_2nd_col = _mm256_setzero_ps();
    __m256 max_3rd_col = _mm256_setzero_ps();

    for (int i = 1; i < len; i++) {
        const __m256 x = simdData[i];
        const __m256 col = _mm256_castsi256_ps(_mm256_set1_epi32(i));

        /* Stage 1: the larger of x / max_1st stays; the smaller (t) moves on.
         * Strict > keeps the earlier row ahead on ties. */
        const __m256 gt = _mm256_cmp_ps(x, max_1st, _CMP_GT_OS);
        const __m256 t = _mm256_blendv_ps(x, max_1st, gt);
        const __m256 t_col = _mm256_blendv_ps(col, max_1st_col, gt);
        max_1st = _mm256_blendv_ps(max_1st, x, gt);
        max_1st_col = _mm256_blendv_ps(max_1st_col, col, gt);

        /* Stage 2: same with t against max_2nd; the smaller (t2) moves on. */
        const __m256 gt2 = _mm256_cmp_ps(t, max_2nd, _CMP_GT_OS);
        const __m256 t2 = _mm256_blendv_ps(t, max_2nd, gt2);
        const __m256 t2_col = _mm256_blendv_ps(t_col, max_2nd_col, gt2);
        max_2nd = _mm256_blendv_ps(max_2nd, t, gt2);
        max_2nd_col = _mm256_blendv_ps(max_2nd_col, t_col, gt2);

        /* Stage 3: t2 lands in max_3rd whenever t > max_3rd. Testing t rather
         * than t2 makes a value bumped out of max_2nd always move down. */
        const __m256 gt3 = _mm256_cmp_ps(t, max_3rd, _CMP_GT_OS);
        max_3rd = _mm256_blendv_ps(max_3rd, t2, gt3);
        max_3rd_col = _mm256_blendv_ps(max_3rd_col, t2_col, gt3);
    }

    /* Phase 2: pop the overall maximum three times. */
    const __m256i lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
    for (int k = 0; k < 3; k++) {
        /* Horizontal max broadcast to every lane: rotate by 1 and 2 within
         * each 128-bit half, then swap the halves. */
        __m256 vmax = max_1st;
        vmax = _mm256_max_ps(vmax, _mm256_permute_ps(vmax, _MM_SHUFFLE(0, 3, 2, 1)));
        vmax = _mm256_max_ps(vmax, _mm256_permute_ps(vmax, _MM_SHUFFLE(1, 0, 3, 2)));
        vmax = _mm256_max_ps(vmax, _mm256_permute2f128_ps(vmax, vmax, 0x01));

        /* movemask_ps yields one bit per lane; the lowest set bit is the
         * lowest lane holding the maximum. */
        const __m256 is_max = _mm256_cmp_ps(max_1st, vmax, _CMP_EQ_US);
        const int lane = __builtin_ctz((unsigned)_mm256_movemask_ps(is_max));
        const __m256i sel = _mm256_set1_epi32(lane);

        /* Permuting with every index == lane moves that lane to element 0. */
        top3_values[k] = _mm_cvtss_f32(_mm256_castps256_ps128(
            _mm256_permutevar8x32_ps(max_1st, sel)));
        const int row = _mm_cvtsi128_si32(_mm256_castsi256_si128(
            _mm256_permutevar8x32_epi32(_mm256_castps_si256(max_1st_col), sel)));
        top3_indices[k] = row * 8 + lane;

        /* Advance exactly one lane: its 2nd/3rd candidates move up. An
         * 8-byte mask (255 << ctz) would also discard the neighbouring lane. */
        const __m256 pop = _mm256_castsi256_ps(_mm256_cmpeq_epi32(lane_ids, sel));
        max_1st = _mm256_blendv_ps(max_1st, max_2nd, pop);
        max_1st_col = _mm256_blendv_ps(max_1st_col, max_2nd_col, pop);
        max_2nd = _mm256_blendv_ps(max_2nd, max_3rd, pop);
        max_2nd_col = _mm256_blendv_ps(max_2nd_col, max_3rd_col, pop);
    }
}

/*
 * Demo: run the top-3 search on a fixed table of 10 rows x 8 float lanes and
 * print the values and their flat indices (row * 8 + lane, in hex).
 */
int main() {
    /* _mm256_set_ps lists lanes from 7 down to 0. */
    __m256 rows[10] = {
        _mm256_set_ps(0.732, 0.981, 0.564, 0.417, 0.237, 0.819, 0.645, 0.123),
        _mm256_set_ps(0.897, 0.654, 0.378, 0.209, 0.527, 0.486, 0.763, 0.321),
        _mm256_set_ps(0.412, 0.896, 0.123, 0.789, 0.456, 0.642, 0.318, 0.547),
        _mm256_set_ps(0.645, 0.237, 0.819, 0.732, 0.931, 0.564, 0.417, 0.123),
        _mm256_set_ps(0.564, 0.819, 0.237, 0.732, 0.417, 0.961, 0.123, 0.645),
        _mm256_set_ps(0.123, 0.564, 0.897, 0.645, 0.378, 0.732, 0.321, 0.642),
        _mm256_set_ps(0.486, 0.378, 0.732, 0.897, 0.209, 0.527, 0.547, 0.763),
        _mm256_set_ps(0.642, 0.456, 0.564, 0.321, 0.819, 0.732, 0.237, 0.123),
        _mm256_set_ps(0.209, 0.123, 0.789, 0.897, 0.547, 0.642, 0.456, 0.237),
        _mm256_set_ps(0.819, 0.732, 0.564, 0.645, 0.237, 0.123, 0.911, 0.417),
    };
    const int row_count = (int)(sizeof rows / sizeof rows[0]);
    float best[3];
    int where[3];

    top3_values_and_indices_simd(rows, row_count, best, where);

    printf("Top-3 values: %f, %f, %f\n", best[0], best[1], best[2]);
    printf("Top-3 indices: 0x%x, 0x%x, 0x%x\n", where[0], where[1], where[2]);

    return 0;
}
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0