マスク付きvpandd(http://qiita.com/tanakmura/items/0838f0244bd93cc7a2b6 )で書いたようにAVX-512では、多くの命令にマスクが付けられます。
C言語上では、__mmask16 は、unsigned shortのtypedefになっているようですが、機械語上は、マスクの値は汎用レジスタではなく、k0-k7 という別のレジスタに入れられます。k0-k7 はAVX512Fでは、16bit、AVX512BW拡張が入っていれば64bitのレジスタです。
AVX2 までは、比較結果は、ベクタレジスタに入れられるようになっていましたが、AVX-512では、比較結果などは、このマスクレジスタに入るようになります。
#include <immintrin.h>
#include <stdio.h>
double in0[8] = {1,2,3,4,5,6,7,8};
double in1[8] = {1,2,1,2,5,6,5,6};
double in2[8] = {2,2,4,4,6,6,8,8};
int
main()
{
__m512d v0 = _mm512_loadu_pd(in0);
__m512d v1 = _mm512_loadu_pd(in1);
__m512d v2 = _mm512_loadu_pd(in2);
__mmask8 cmp0 = _mm512_cmp_pd_mask(v0, v1, _CMP_EQ_OQ);
__mmask8 cmp1 = _mm512_cmp_pd_mask(v0, v2, _CMP_EQ_OQ);
printf("%x\n", cmp0);
printf("%x\n", cmp1);
}
$ ./sde -- ./a.out
33
aa
__mmask16 値はC言語上は、スカラ値と同じように扱えるようですが、普通に使うと汎用レジスタに入れるためのkmovという命令が入ります。
#include <immintrin.h>
#include <stdio.h>
double in0[8] = {1,2,3,4,5,6,7,8};
double in1[8] = {1,2,1,2,5,6,5,6};
double in2[8] = {2,2,4,4,6,6,8,8};
int a;
int
main()
{
__m512d v0 = _mm512_loadu_pd(in0);
__m512d v1 = _mm512_loadu_pd(in1);
__mmask8 cmp0 = _mm512_cmp_pd_mask(v0, v1, _CMP_EQ_OQ);
a = cmp0;
}
main:
.LFB2245:
.cfi_startproc
vmovupd in0(%rip), %zmm0
vmovupd in1(%rip), %zmm1
vcmppd $0, %zmm1, %zmm0, %k1
kmovw %k1, %eax # k1 → eax へ転送
movzbl %al, %eax
movl %eax, a(%rip)
ret
マスク値を加算したりすることはあまり無いかもしれませんが、条件が複数あるときなどは、論理和、論理積が欲しい場合はそれなりにあると思います。
このとき、いちいち汎用レジスタに入れたり出したりするのは無駄があるような気がしますね。なので、マスクレジスタだけで、いくらかの論理演算ができるようになっています。
korwは、ふたつのマスクレジスタの16bit論理和をとります。
#include <immintrin.h>
#include <stdio.h>
double in0[8] = {1,2,3,4,5,6,7,8};
double in1[8] = {1,2,1,2,5,6,5,6};
double in2[8] = {2,2,4,4,6,6,8,8};
double out[8];
void
dump_bits8(__mmask8 a)
{
for (int i=0; i<8; i++) {
if (a & (1<<(7-i))) {
putchar('1');
} else {
putchar('0');
}
}
puts("");
}
int
main()
{
__m512d v0 = _mm512_loadu_pd(in0);
__m512d v1 = _mm512_loadu_pd(in1);
__m512d v2 = _mm512_loadu_pd(in2);
__mmask8 cmp0 = _mm512_cmp_pd_mask(v0, v1, _CMP_EQ_OQ);
__mmask8 cmp1 = _mm512_cmp_pd_mask(v0, v2, _CMP_EQ_OQ);
dump_bits8(cmp0);
dump_bits8(cmp1);
dump_bits8(_mm512_kor(cmp0,cmp1));
}
$ ./sde -- ./a.out
00110011
10101010
10111011
以下のようなプログラムを書けば、無駄な転送が減っていることが確認できますね。
#include <immintrin.h>
#include <stdio.h>
double in0[8] = {1,2,3,4,5,6,7,8};
double in1[8] = {1,2,1,2,5,6,5,6};
double out[8];
int a;
int
main()
{
__m512d v0 = _mm512_loadu_pd(in0);
__m512d v1 = _mm512_loadu_pd(in1);
__mmask8 cmp0 = _mm512_cmp_pd_mask(v0, v1, _CMP_EQ_OQ);
__mmask8 cmp1 = _mm512_cmp_pd_mask(v0, v1, _CMP_LT_OQ);
__mmask8 cmpor = _mm512_kor(cmp0, cmp1);
_mm512_mask_storeu_pd(out, cmpor, v0);
}
main:
.LFB2229:
.cfi_startproc
vmovupd in0(%rip), %zmm1
xorl %eax, %eax
vmovupd in1(%rip), %zmm0
vcmppd $0, %zmm0, %zmm1, %k2
vcmppd $17, %zmm0, %zmm1, %k1
korw %k2, %k1, %k1
vmovupd %zmm1, out(%rip){%k1}
ret
明日は、@tanakmura が kortestw について書きます。