avx512
More than 3 years have passed since last update.

AVX拡張によって、オペランドが3個に増えたx86命令でしたが、3入力1出力を必要とするfmaでも、4オペランドに拡張されることはなく、入力の組み合わせによって、積和のオペランドの順序が違うみっつの命令が用意されることになりました。

  • vfmadd132 src1 = src1*src3 + src2
  • vfmadd213 src1 = src2*src1 + src3
  • vfmadd231 src1 = src2*src3 + src1

http://news.mynavi.jp/photo/articles/2012/10/03/idf_haswell_hpc_01/images/002l.jpg

結構「わぁい!やったぁー!」という感じの名前の命令ですが、実際にプログラムを書くときは、intrinsicsで書く場合が多く、コンパイラによって隠蔽されてしまうため、"213"という文字列をソースコード上で見ることはほぼありません。悲しみ…!

ところが、安心してください!AVX-512では、命令にマスクが付けられるようになったため、機械語中のオペランドの並び順に別の意味が出てきてしまい、vfma のオペランド数が足りない問題が intrinsics では隠蔽しきれなくなりました。

その結果、

_mm512_mask3_fmadd_ps

というintrinsicが誕生しました。mask3!やったね!

_mm512_mask_fmadd_ps(A,B,C) は、 マスクのビットが立っていた場合、A*B+Cの値を、マスクのビットが立っていなかった場合、A の値を出力します。

Intel Intrinsics Guide(https://software.intel.com/sites/landingpage/IntrinsicsGuide/ ) の疑似コードによると、_mm512_mask_fmadd_ps は

FOR j := 0 to 15
    i := j*32
    IF k[j]
        dst[i+31:i] := (a[i+31:i] * b[i+31:i]) + c[i+31:i]
    ELSE
        dst[i+31:i] := a[i+31:i]
    FI
ENDFOR

dst[MAX:512] := 0

と、なっています。

_mm512_mask3_fmadd_ps(A,B,C) は、 マスクのビットが立っていなかった場合、C の値を出力します。

Intrinsics Guide では、_mm512_mask3_fmadd_ps は、

FOR j := 0 to 15
    i := j*32
    IF k[j]
        dst[i+31:i] := (a[i+31:i] * b[i+31:i]) + c[i+31:i]
    ELSE
        dst[i+31:i] := c[i+31:i]
    FI
ENDFOR

dst[MAX:512] := 0

と、なっていますね。

まあそもそも、_mm512_mask_fmadd_ps の 「A * B + C のうち、 マスクのビットが立っていなかったらAを選ぶ」という機能自体、どういう状況で使うんだよ…という感じもありますが、まあSIMD書くプログラマなんて、こういう機能を無理矢理使うのを楽しいと感じるプログラマか、こういう機能に対してボロクソに文句言いながら、無理矢理使うのを楽しいと感じるプログラマしかいないので基本的には問題無いです。

あとこれ調べてて気付きましたが、fmadd に限らず、全体的に maskz というのもあって、マスクビットが立っていなければゼロクリアする機能もあります。まあこれは使える可能性はありますね。

まとめると、↓こんな感じです。

#include <immintrin.h>
#include <stdio.h>

float in0[16] = {1.1f,1.1f,1.1f,1.1f,
                 1.1f,1.1f,1.1f,1.1f,
                 1.1f,1.1f,1.1f,1.1f,
                 1.1f,1.1f,1.1f,1.1f};

float in1[16] = {2.2f,2.2f,2.2f,2.2f,
                 2.2f,2.2f,2.2f,2.2f,
                 2.2f,2.2f,2.2f,2.2f,
                 2.2f,2.2f,2.2f,2.2f};

float in2[16] = {3.3f,3.3f,3.3f,3.3f,
                 3.3f,3.3f,3.3f,3.3f,
                 3.3f,3.3f,3.3f,3.3f,
                 3.3f,3.3f,3.3f,3.3f};



float out_nomask[16];
float out_mask[16];
float out_mask3[16];
float out_maskz[16];

int
main(void)
{
    __m512 v0;
    __m512 v1;
    __m512 v2;

    __m512 dst;

    int i;

    __mmask16 mask_16 = 0xffaaU;

    v0 = _mm512_loadu_ps(in0);
    v1 = _mm512_loadu_ps(in1);
    v2 = _mm512_loadu_ps(in2);

    dst = _mm512_fmadd_ps(v0, v1, v2);
    _mm512_storeu_ps(out_nomask, dst);

    for (i=0; i<16; i++) {
        printf("nomask %2d:%6.2f\n", i, out_nomask[i]);
    }


    puts("");

    dst = _mm512_mask_fmadd_ps(v0, mask_16, v1, v2);
    _mm512_storeu_ps(out_mask, dst);

    for (i=0; i<16; i++) {
        printf("mask   %2d:%6.2f\n", i, out_mask[i]);
    }


    puts("");

    dst = _mm512_mask3_fmadd_ps(v0, v1, v2, mask_16);
    _mm512_storeu_ps(out_mask3, dst);

    for (i=0; i<16; i++) {
        printf("mask3 %2d:%6.2f\n", i, out_mask3[i]);
    }


    puts("");

    dst = _mm512_maskz_fmadd_ps(mask_16, v0, v1, v2);
    _mm512_storeu_ps(out_maskz, dst);

    for (i=0; i<16; i++) {
        printf("maskz %2d:%6.2f\n", i, out_maskz[i]);
    }
}
$ ./sde -- ./a.out 
nomask  0:  5.72
nomask  1:  5.72
nomask  2:  5.72
nomask  3:  5.72
nomask  4:  5.72
nomask  5:  5.72
nomask  6:  5.72
nomask  7:  5.72
nomask  8:  5.72
nomask  9:  5.72
nomask 10:  5.72
nomask 11:  5.72
nomask 12:  5.72
nomask 13:  5.72
nomask 14:  5.72
nomask 15:  5.72

mask    0:  1.10
mask    1:  5.72
mask    2:  1.10
mask    3:  5.72
mask    4:  1.10
mask    5:  5.72
mask    6:  1.10
mask    7:  5.72
mask    8:  5.72
mask    9:  5.72
mask   10:  5.72
mask   11:  5.72
mask   12:  5.72
mask   13:  5.72
mask   14:  5.72
mask   15:  5.72

mask3  0:  3.30
mask3  1:  5.72
mask3  2:  3.30
mask3  3:  5.72
mask3  4:  3.30
mask3  5:  5.72
mask3  6:  3.30
mask3  7:  5.72
mask3  8:  5.72
mask3  9:  5.72
mask3 10:  5.72
mask3 11:  5.72
mask3 12:  5.72
mask3 13:  5.72
mask3 14:  5.72
mask3 15:  5.72

maskz  0:  0.00
maskz  1:  5.72
maskz  2:  0.00
maskz  3:  5.72
maskz  4:  0.00
maskz  5:  5.72
maskz  6:  0.00
maskz  7:  5.72
maskz  8:  5.72
maskz  9:  5.72
maskz 10:  5.72
maskz 11:  5.72
maskz 12:  5.72
maskz 13:  5.72
maskz 14:  5.72
maskz 15:  5.72

明日は、@tanakmura が、今Intrinsics Guide 見ていてぱっと見何やってるかよくわからなかった、_mm_fixupimm_sd について書く可能性があります。