

Posted at 2023-11-09


// SPDX-License-Identifier: Apache-2.0
#include <limits.h>
#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>

/*
 * Debug helper: print `prefix`, then the 256-bit vector as sixteen 16-bit
 * words in hex (lowest-addressed word first), then a newline.
 */
void print_m256i(const char *prefix, __m256i vector) {
    __attribute__ ((aligned (32))) uint16_t elements[16];
    _mm256_storeu_si256((__m256i *)elements, vector);

    printf("%s", prefix);
    for (int i = 0; i < 16; i++) {
        printf("%.4x ", elements[i]);
    }
    /* Terminate the line so consecutive dumps stay readable. */
    putchar('\n');
}

// https://stackoverflow.com/questions/21622212/how-to-perform-the-inverse-of-mm256-movemask-epi8-vpmovmskb
// Copyright: Satya Arjunan
// License: CC-BY-SA 3.0
/*
 * Inverse of _mm256_movemask_epi8: expand a 32-bit mask into a __m256i whose
 * byte j is 0xFF when bit j of `mask` is set and 0x00 otherwise.
 */
__m256i get_mask3(const uint32_t mask) {
    /* Replicate the four mask bytes into every 32-bit element. */
    __m256i vmask = _mm256_set1_epi32((int)mask);

    /* Spread mask byte k over output bytes 8k..8k+7. pshufb works per 128-bit
       lane, but every lane holds a full copy of the four mask bytes, so the
       upper lane can address bytes 2 and 3 directly. */
    const __m256i shuffle = _mm256_setr_epi64x(0x0000000000000000,
        0x0101010101010101, 0x0202020202020202, 0x0303030303030303);
    vmask = _mm256_shuffle_epi8(vmask, shuffle);

    /* In output byte 8k+b, set every bit except bit b. The byte becomes 0xFF
       exactly when bit b of the copied mask byte was already set. */
    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
    vmask = _mm256_or_si256(vmask, bit_mask);
    return _mm256_cmpeq_epi8(vmask, _mm256_set1_epi64x(-1));
}

// https://stackoverflow.com/questions/48726032/using-a-variable-to-index-a-simd-vector-with-mm256-extract-epi32-intrinsic
// Copyright: wim
// License: CC-BY-SA 3.0
/*
 * Return 32-bit element `i` of `vec`, where `i` is a runtime value.
 * Only the low 3 bits of `i` are used (vpermd semantics), so the index is
 * effectively taken modulo 8.
 */
uint32_t mm256_extract_epi32_var_indx(const __m256i vec, const unsigned int i) {
    /* Only element 0 of the index vector matters: we read just element 0 of
       the permuted result, so the undefined upper half of the cast is harmless. */
    __m128i indx = _mm_cvtsi32_si128((int)i);
    __m256i val  = _mm256_permutevar8x32_epi32(vec, _mm256_castsi128_si256(indx));
    /* Portable equivalent of _mm256_cvtsi256_si32, which older GCC lacks. */
    return (uint32_t)_mm_cvtsi128_si32(_mm256_castsi256_si128(val));
}

/*
 * Find the three largest int32 values in `len` rows of 8 int32 lanes each,
 * together with their flat indices (lane + 8 * row).
 *
 * Phase 1 keeps, independently per lane, a sorted top-3
 * (max_1st >= max_2nd >= max_3rd) and the row each entry came from, using
 * branch-free compare/blend insertion.
 * Phase 2 repeatedly takes the horizontal maximum of max_1st, records it, and
 * promotes only that lane's 2nd/3rd entries upward.
 *
 * simdData     : input rows (assumed to hold `len` valid vectors)
 * len          : number of rows; if <= 0, values are set to INT_MIN and
 *                indices to -1
 * top3_values  : out, largest value first
 * top3_indices : out, flat index of each reported value; among equal maxima
 *                the lowest lane is reported first
 *
 * Uses get_mask3 and mm256_extract_epi32_var_indx defined above.
 */
void top3_values_and_indices_simd(__m256i simdData[], int len, int top3_values[3], int top3_indices[3]) {
    if (len <= 0) {
        for (int k = 0; k < 3; k++) {
            top3_values[k] = INT_MIN;
            top3_indices[k] = -1;
        }
        return;
    }

    __m256i max_1st = simdData[0];
    __m256i max_2nd = _mm256_set1_epi32(INT_MIN);
    __m256i max_3rd = _mm256_set1_epi32(INT_MIN);
    __m256i max_1st_col = _mm256_setzero_si256();
    __m256i max_2nd_col = _mm256_setzero_si256();
    __m256i max_3rd_col = _mm256_setzero_si256();

    for (int i = 1; i < len; i++) {
        const __m256i v = simdData[i];
        const __m256i col = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(i));

        /* Insert v into 1st place; t is what falls out (the smaller one). */
        __m256i gt = _mm256_cmpgt_epi32(v, max_1st);
        __m256i t = _mm256_blendv_epi8(v, max_1st, gt);
        __m256i t_col = _mm256_blendv_epi8(col, max_1st_col, gt);
        max_1st = _mm256_blendv_epi8(max_1st, v, gt);
        max_1st_col = _mm256_blendv_epi8(max_1st_col, col, gt);

        /* Insert t into 2nd place; t2 is what falls out. */
        __m256i gt2 = _mm256_cmpgt_epi32(t, max_2nd);
        __m256i t2 = _mm256_blendv_epi8(t, max_2nd, gt2);
        __m256i t2_col = _mm256_blendv_epi8(t_col, max_2nd_col, gt2);
        max_2nd = _mm256_blendv_epi8(max_2nd, t, gt2);
        max_2nd_col = _mm256_blendv_epi8(max_2nd_col, t_col, gt2);

        /* Insert t2 into 3rd place. Testing t instead of t2 gives the same
           values: if t > max_2nd then t2 == max_2nd >= max_3rd, otherwise
           t2 == t. It also avoids waiting on the t2 blend. */
        __m256i gt3 = _mm256_cmpgt_epi32(t, max_3rd);
        max_3rd = _mm256_blendv_epi8(max_3rd, t2, gt3);
        max_3rd_col = _mm256_blendv_epi8(max_3rd_col, t2_col, gt3);
    }

    for (int k = 0; k < 3; k++) {
        /* Broadcast the horizontal maximum of max_1st to all lanes. The data
           is signed (INT_MIN is the empty-slot sentinel), so use the signed
           max; the unsigned max would rank negatives above positives. */
        __m256i vmax = max_1st;
        vmax = _mm256_max_epi32(vmax, _mm256_alignr_epi8(vmax, vmax, 4));
        vmax = _mm256_max_epi32(vmax, _mm256_alignr_epi8(vmax, vmax, 8));
        vmax = _mm256_max_epi32(vmax, _mm256_permute2x128_si256(vmax, vmax, 0x01));

        /* Lanes equal to the maximum (there may be several). movemask_epi8
           yields 4 bits per 32-bit lane. The mask is never zero because vmax
           is one of the lane values, so __builtin_ctz is well defined. */
        __m256i vmax_mask = _mm256_cmpeq_epi32(max_1st, vmax);
        uint32_t mask = (uint32_t)_mm256_movemask_epi8(vmax_mask);
        unsigned lane = (unsigned)__builtin_ctz(mask) >> 2;

        /* Byte mask covering exactly that one lane (4 bytes). */
        __m256i lane_mask = get_mask3(0xFu << (4u * lane));

        top3_values[k] = (int)mm256_extract_epi32_var_indx(max_1st, lane);
        top3_indices[k] = (int)lane + 8 * (int)mm256_extract_epi32_var_indx(max_1st_col, lane);

        /* Remove the reported maximum from its lane and promote that lane's
           runners-up. */
        max_1st = _mm256_blendv_epi8(max_1st, max_2nd, lane_mask);
        max_1st_col = _mm256_blendv_epi8(max_1st_col, max_2nd_col, lane_mask);
        max_2nd = _mm256_blendv_epi8(max_2nd, max_3rd, lane_mask);
        max_2nd_col = _mm256_blendv_epi8(max_2nd_col, max_3rd_col, lane_mask);
    }
}


int main(void) {
    /* 10 rows x 8 lanes; _mm256_set_epi32 lists lanes from 7 down to 0. */
    __m256i simdData[10] = {
        _mm256_set_epi32(0x5C, 0x22, 0x4C, 0x37, 0x1C, 0x3D, 0x31, 0x55),
        _mm256_set_epi32(0x13, 0x25, 0x42, 0x2F, 0x0C, 0x58, 0x34, 0x1F),
        _mm256_set_epi32(0x49, 0x1D, 0x2A, 0x45, 0x36, 0x5D, 0x17, 0x2D),
        _mm256_set_epi32(0x3F, 0x54, 0x1E, 0x3B, 0x26, 0x4F, 0x0A, 0x16),
        _mm256_set_epi32(0x50, 0x46, 0x39, 0x2C, 0x0F, 0x47, 0x5F, 0x27),
        _mm256_set_epi32(0x29, 0x44, 0x4A, 0x14, 0x51, 0x11, 0x23, 0x62),
        _mm256_set_epi32(0x41, 0x4D, 0x56, 0x1B, 0x33, 0x20, 0x0D, 0x48),
        _mm256_set_epi32(0x1A, 0x3A, 0x0E, 0x24, 0x43, 0x35, 0x59, 0x4B),
        _mm256_set_epi32(0x2B, 0x5A, 0x40, 0x15, 0x4E, 0x2E, 0x57, 0x0B),
        _mm256_set_epi32(0x10, 0x21, 0x32, 0x5E, 0x3E, 0x19, 0x30, 0x28),
    };
    int top3_values[3];
    int top3_indices[3];

    top3_values_and_indices_simd(simdData, (int)(sizeof simdData / sizeof simdData[0]),
                                 top3_values, top3_indices);

    /* Print the top-3 values and their flat indices (lane + 8 * row). */
    printf("Top-3 values: 0x%x, 0x%x, 0x%x\n", top3_values[0], top3_values[1], top3_values[2]);
    printf("Top-3 indices: 0x%x, 0x%x, 0x%x\n", top3_indices[0], top3_indices[1], top3_indices[2]);

    return 0;
}

$ gcc -mavx -mavx2 top3_4.c 
$ ./a.out
Top-3 values: 0x62, 0x5e, 0x5d
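Top-3 indices: 0x28, 0x4c, 0x12

Note: for this input the three largest values are actually 0x62, 0x5f and 0x5e (flat indices 0x28, 0x21, 0x4c). The transcript above reports 0x5d instead of 0x5f. That build cleared the extracted lane with `get_mask3(255 << ctz)`, which selects 8 bytes, i.e. two adjacent 32-bit lanes, rather than one. As a result, lane 1's maximum 0x5f was discarded together with lane 0's 0x62.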


追記: float版も作りました。col側も__m256にしてますが、__m256iとどっちが良いかは不明です。


#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>
#include <float.h>

/*
 * Debug helper: print `prefix`, then the 256-bit vector as sixteen 16-bit
 * words in hex (lowest-addressed word first), then a newline.
 */
void print_m256i(const char *prefix, __m256i vector) {
    __attribute__ ((aligned (32))) uint16_t elements[16];
    _mm256_storeu_si256((__m256i *)elements, vector);

    printf("%s", prefix);
    for (int i = 0; i < 16; i++) {
        printf("%.4x ", elements[i]);
    }
    /* Terminate the line so consecutive dumps stay readable. */
    putchar('\n');
}

/*
 * Debug helper: print the eight floats of `vector` (lane 0 first) with four
 * decimals each, then a newline.
 */
void print_m256(__m256 vector) {
    __attribute__ ((aligned (32))) float elements[8];
    _mm256_storeu_ps(elements, vector);
    for (int i = 0; i < 8; i++) {
        printf("%.4f ", elements[i]);
    }
    /* Terminate the line so consecutive dumps stay readable. */
    putchar('\n');
}

// https://stackoverflow.com/questions/21622212/how-to-perform-the-inverse-of-mm256-movemask-epi8-vpmovmskb
// Copyright: Satya Arjunan
// License: CC-BY-SA 3.0
/*
 * Inverse of _mm256_movemask_epi8: expand a 32-bit mask into a __m256i whose
 * byte j is 0xFF when bit j of `mask` is set and 0x00 otherwise.
 */
__m256i get_mask3(const uint32_t mask) {
    /* Replicate the four mask bytes into every 32-bit element. */
    __m256i vmask = _mm256_set1_epi32((int)mask);

    /* Spread mask byte k over output bytes 8k..8k+7. pshufb works per 128-bit
       lane, but every lane holds a full copy of the four mask bytes, so the
       upper lane can address bytes 2 and 3 directly. */
    const __m256i shuffle = _mm256_setr_epi64x(0x0000000000000000,
        0x0101010101010101, 0x0202020202020202, 0x0303030303030303);
    vmask = _mm256_shuffle_epi8(vmask, shuffle);

    /* In output byte 8k+b, set every bit except bit b. The byte becomes 0xFF
       exactly when bit b of the copied mask byte was already set. */
    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
    vmask = _mm256_or_si256(vmask, bit_mask);
    return _mm256_cmpeq_epi8(vmask, _mm256_set1_epi64x(-1));
}

// https://stackoverflow.com/questions/48726032/using-a-variable-to-index-a-simd-vector-with-mm256-extract-epi32-intrinsic
// Copyright: wim
// License: CC-BY-SA 3.0
/*
 * Return 32-bit element `i` of `vec`, where `i` is a runtime value.
 * Only the low 3 bits of `i` are used (vpermd semantics), so the index is
 * effectively taken modulo 8.
 */
uint32_t mm256_extract_epi32_var_indx(const __m256i vec, const unsigned int i) {
    /* Only element 0 of the index vector matters: we read just element 0 of
       the permuted result, so the undefined upper half of the cast is harmless. */
    __m128i indx = _mm_cvtsi32_si128((int)i);
    __m256i val  = _mm256_permutevar8x32_epi32(vec, _mm256_castsi128_si256(indx));
    /* Portable equivalent of _mm256_cvtsi256_si32, which older GCC lacks. */
    return (uint32_t)_mm_cvtsi128_si32(_mm256_castsi256_si128(val));
}

/*
 * Return float element `i` of `vec`, where `i` is a runtime value.
 * Only the low 3 bits of `i` are used (vpermps semantics), so the index is
 * effectively taken modulo 8.
 */
float mm256_extract_ps_var_indx(const __m256 vec, const unsigned int i) {
    /* Only element 0 of the index vector matters: we read just element 0 of
       the permuted result, so the undefined upper half of the cast is harmless. */
    __m128i indx = _mm_cvtsi32_si128((int)i);
    __m256 val  = _mm256_permutevar8x32_ps(vec, _mm256_castsi128_si256(indx));
    /* Portable equivalent of _mm256_cvtss_f32, which older compilers lack. */
    return _mm_cvtss_f32(_mm256_castps256_ps128(val));
}

void top3_values_and_indices_simd(__m256 simdData[], int len, float top3_values[3], int top3_indices[3]) {

    __m256 max_1st = simdData[0], max_2nd = _mm256_set1_ps(FLT_MIN), max_3rd = _mm256_set1_ps(FLT_MIN);
//    print_m256("1st: ", max_1st);
//    print_m256("2nd: ", max_2nd);
//    print_m256("3rd: ", max_3rd);
    __m256 max_1st_col = (__m256)_mm256_setzero_si256(), max_2nd_col = (__m256)_mm256_setzero_si256(), max_3rd_col = (__m256)_mm256_setzero_si256();
    for(register int i = 1; i < len; i++){
        __m128i col_t = _mm_cvtsi32_si128(i);
        __m256 col = (__m256)_mm256_broadcastd_epi32(col_t); // or col = _mm256_sub_epi32(col, _mm256_cmpeq_epi32(col, col))
//        print_m256("col:", col);

        __m256 gt = _mm256_cmp_ps(simdData[i], max_1st, 14/*_CMP_GT_OS*/);
//        print_m256("gt : ", gt);
        __m256 t = _mm256_blendv_ps(simdData[i], max_1st, gt); // min
        __m256 t_col = _mm256_blendv_ps(col, max_1st_col, gt);
//        print_m256("t  : ", t);
        max_1st = _mm256_blendv_ps(max_1st, simdData[i], gt); // max
        max_1st_col = _mm256_blendv_ps(max_1st_col, col, gt);
//        print_m256("1st: ", max_1st);
//        printf("%d\n", i);
        __m256 gt2 = _mm256_cmp_ps(t, max_2nd, 14/*_CMP_GT_OS*/);
//        print_m256("gt2: ", gt2);
        __m256 t2 = _mm256_blendv_ps(t, max_2nd, gt2);
        __m256 t2_col = _mm256_blendv_ps(t_col, max_2nd_col, gt2);
//        print_m256("t2 : ", t2);
        max_2nd = _mm256_blendv_ps(max_2nd, t, gt2);
        max_2nd_col = _mm256_blendv_ps(max_2nd_col, t_col, gt2);
//        print_m256("2nd: ", max_2nd);
        __m256 gt3 = _mm256_cmp_ps(t, max_3rd, 14/*_CMP_GT_OS*/);
//        print_m256("gt3: ", gt3);
        max_3rd = _mm256_blendv_ps(max_3rd, t2, gt3);
        max_3rd_col = _mm256_blendv_ps(max_3rd_col, t2_col, gt3);
//        print_m256("3rd: ", max_3rd);

    for(int i=0; i < 3; i++){
        __m256 vmax = max_1st;
        // borrowed from Public Domain codes
        // https://stackoverflow.com/questions/23590610/find-index-of-maximum-element-in-x86-simd-vector
        vmax = _mm256_max_ps(vmax, _mm256_permute_ps(vmax, _MM_SHUFFLE(0, 3, 2, 1)));
        vmax = _mm256_max_ps(vmax, _mm256_permute_ps(vmax, _MM_SHUFFLE(1, 0, 3, 2)));
        vmax = _mm256_max_ps(vmax, _mm256_permute2f128_ps(vmax, vmax, 0x01));

        __m256 vmax_mask = _mm256_cmp_ps(max_1st, vmax, 24/*_CMP_EQ_US*/); // 2個以上の同じ値も含まれる
        // ffff ffff 0000 0000 0000 0000 0000 0000 ffff ffff 0000 0000 0000 0000 0000 0000

//        print_m256("test: ", vmax_mask);
        uint32_t mask = _mm256_movemask_epi8((__m256i)vmax_mask);
//        printf("%x\n", mask);
        int32_t ctz = __builtin_ctz(mask);
//        printf("%i\n", ctz);

        __m256 lowest_vmax_mask = (__m256)get_mask3(255 << ctz);
//        print_m256("mask: ", lowest_vmax_mask);

        top3_values[i] = mm256_extract_ps_var_indx(max_1st, ctz >> 2);
        top3_indices[i] = (ctz >> 2) + 8*mm256_extract_epi32_var_indx((__m256i)max_1st_col, ctz >> 2);
//        printf("%uth: %u\n", i, ctz >> 2);

        // maxを取り除いて繰り上げ
//        print_m256("b1st: ", max_1st);
        max_1st = _mm256_blendv_ps(max_1st, max_2nd, lowest_vmax_mask);
        max_1st_col = _mm256_blendv_ps(max_1st_col, max_2nd_col, lowest_vmax_mask);
//        print_m256("a1st: ", max_1st);
//        print_m256("b2nd: ", max_2nd);
        max_2nd = _mm256_blendv_ps(max_2nd, max_3rd, lowest_vmax_mask);
        max_2nd_col = _mm256_blendv_ps(max_2nd_col, max_3rd_col, lowest_vmax_mask);
//        print_m256("a2nd: ", max_2nd);


int main(void) {
    /* 10 rows x 8 lanes; _mm256_set_ps lists lanes from 7 down to 0. */
    __m256 simdData[10] = {
        _mm256_set_ps(0.732, 0.981, 0.564, 0.417, 0.237, 0.819, 0.645, 0.123),
        _mm256_set_ps(0.897, 0.654, 0.378, 0.209, 0.527, 0.486, 0.763, 0.321),
        _mm256_set_ps(0.412, 0.896, 0.123, 0.789, 0.456, 0.642, 0.318, 0.547),
        _mm256_set_ps(0.645, 0.237, 0.819, 0.732, 0.931, 0.564, 0.417, 0.123),
        _mm256_set_ps(0.564, 0.819, 0.237, 0.732, 0.417, 0.961, 0.123, 0.645),
        _mm256_set_ps(0.123, 0.564, 0.897, 0.645, 0.378, 0.732, 0.321, 0.642),
        _mm256_set_ps(0.486, 0.378, 0.732, 0.897, 0.209, 0.527, 0.547, 0.763),
        _mm256_set_ps(0.642, 0.456, 0.564, 0.321, 0.819, 0.732, 0.237, 0.123),
        _mm256_set_ps(0.209, 0.123, 0.789, 0.897, 0.547, 0.642, 0.456, 0.237),
        _mm256_set_ps(0.819, 0.732, 0.564, 0.645, 0.237, 0.123, 0.911, 0.417),
    };
    float top3_values[3];
    int top3_indices[3];

    top3_values_and_indices_simd(simdData, (int)(sizeof simdData / sizeof simdData[0]),
                                 top3_values, top3_indices);

    /* Print the top-3 values and their flat indices (lane + 8 * row). */
    printf("Top-3 values: %f, %f, %f\n", top3_values[0], top3_values[1], top3_values[2]);
    printf("Top-3 indices: 0x%x, 0x%x, 0x%x\n", top3_indices[0], top3_indices[1], top3_indices[2]);

    return 0;
}

