avx512
More than 3 years have passed since last update.

マスク付きvpandd(http://qiita.com/tanakmura/items/0838f0244bd93cc7a2b6 )で書いたようにAVX-512では、多くの命令にマスクが付けられます。

C言語上では、__mmask16 は、unsigned shortのtypedefになっているようですが、機械語上は、マスクの値は汎用レジスタではなく、k0-k7 という別のレジスタに入れられます。k0-k7 はAVX512Fでは、16bit、AVX512BW拡張が入っていれば64bitのレジスタです。

AVX2 までは、比較結果は、ベクタレジスタに入れられるようになっていましたが、AVX-512では、比較結果などは、このマスクレジスタに入るようになります。

#include <immintrin.h>
#include <stdio.h>

double in0[8] = {1,2,3,4,5,6,7,8};
double in1[8] = {1,2,1,2,5,6,5,6};
double in2[8] = {2,2,4,4,6,6,8,8};

int
main()
{
    __m512d v0 = _mm512_loadu_pd(in0);
    __m512d v1 = _mm512_loadu_pd(in1);
    __m512d v2 = _mm512_loadu_pd(in2);

    __mmask8 cmp0 = _mm512_cmp_pd_mask(v0, v1, _CMP_EQ_OQ);
    __mmask8 cmp1 = _mm512_cmp_pd_mask(v0, v2, _CMP_EQ_OQ);

    printf("%x\n", cmp0);
    printf("%x\n", cmp1);

}
 $ ./sde  -- ./a.out 
33
aa

__mmask16 値はC言語上は、スカラ値と同じように扱えるようですが、普通に使うと汎用レジスタに入れるためのkmovという命令が入ります。

#include <immintrin.h>
#include <stdio.h>

double in0[8] = {1,2,3,4,5,6,7,8};
double in1[8] = {1,2,1,2,5,6,5,6};
double in2[8] = {2,2,4,4,6,6,8,8};

int a;

int
main()
{
    __m512d v0 = _mm512_loadu_pd(in0);
    __m512d v1 = _mm512_loadu_pd(in1);

    __mmask8 cmp0 = _mm512_cmp_pd_mask(v0, v1, _CMP_EQ_OQ);

    a = cmp0;
}
main:
.LFB2245:
    .cfi_startproc
    vmovupd in0(%rip), %zmm0
    vmovupd in1(%rip), %zmm1
    vcmppd  $0, %zmm1, %zmm0, %k1
    kmovw   %k1, %eax                # k1 → eax へ転送
    movzbl  %al, %eax
    movl    %eax, a(%rip)
    ret

マスク値を加算したりすることはあまり無いかもしれませんが、条件が複数あるときなどは、論理和、論理積が欲しい場合はそれなりにあると思います。

このとき、いちいち汎用レジスタに入れたり出したりするのは無駄があるような気がしますね。なので、マスクレジスタだけで、いくらかの論理演算ができるようになっています。

korwは、ふたつのマスクレジスタの16bit論理和をとります。

#include <immintrin.h>
#include <stdio.h>

double in0[8] = {1,2,3,4,5,6,7,8};
double in1[8] = {1,2,1,2,5,6,5,6};
double in2[8] = {2,2,4,4,6,6,8,8};

double out[8];

void
dump_bits8(__mmask8 a)
{
    for (int i=0; i<8; i++) {
        if (a & (1<<(7-i))) {
            putchar('1');
        } else {
            putchar('0');
        }
    }
    puts("");
}

int
main()
{
    __m512d v0 = _mm512_loadu_pd(in0);
    __m512d v1 = _mm512_loadu_pd(in1);
    __m512d v2 = _mm512_loadu_pd(in2);

    __mmask8 cmp0 = _mm512_cmp_pd_mask(v0, v1, _CMP_EQ_OQ);
    __mmask8 cmp1 = _mm512_cmp_pd_mask(v0, v2, _CMP_EQ_OQ);


    dump_bits8(cmp0);
    dump_bits8(cmp1);
    dump_bits8(_mm512_kor(cmp0,cmp1));
}
 $ ./sde  -- ./a.out 
00110011
10101010
10111011

以下のようなプログラムを書けば、無駄な転送が減っていることが確認できますね。

#include <immintrin.h>
#include <stdio.h>

double in0[8] = {1,2,3,4,5,6,7,8};
double in1[8] = {1,2,1,2,5,6,5,6};
double out[8];
int a;

int
main()
{
    __m512d v0 = _mm512_loadu_pd(in0);
    __m512d v1 = _mm512_loadu_pd(in1);
    __mmask8 cmp0 = _mm512_cmp_pd_mask(v0, v1, _CMP_EQ_OQ);
    __mmask8 cmp1 = _mm512_cmp_pd_mask(v0, v1, _CMP_LT_OQ);
    __mmask8 cmpor = _mm512_kor(cmp0, cmp1);

    _mm512_mask_storeu_pd(out, cmpor, v0);
}
main:
.LFB2229:
    .cfi_startproc
    vmovupd in0(%rip), %zmm1
    xorl    %eax, %eax
    vmovupd in1(%rip), %zmm0
    vcmppd  $0, %zmm0, %zmm1, %k2
    vcmppd  $17, %zmm0, %zmm1, %k1
    korw    %k2, %k1, %k1
    vmovupd %zmm1, out(%rip){%k1}
    ret

明日は、@tanakmura が kortestw について書きます。