5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

LJの力計算を組み込み関数で書いてみる

Posted at

#はじめに

LJの力計算をフルアセンブラで書いてみるの続き。

思想

フルアセンブラでループ内SIMD化してみたのはいいけど、それでループアンロールとかやりたくない。intrinsicで書けばコンパイラがやってくれないかと期待。

計算環境

  • Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz
  • icpc (ICC) 16.0.3
  • g++ (GCC) 4.8.5

コード

長いが全部載せる。それぞれ以下の通り。

  • calc_intrin intrinsicで書き下したSIMDコード
  • calc_asm アセンブラで書いたコード
  • check ナイーブに書いたコード (icpc向け)
  • check_opt 内側のループで余計なロード/ストアをしなくて良いように明示的に指定したコード (g++向け)
test.cpp
//------------------------------------------------------------------------
#include <immintrin.h>
#include <stdio.h>
//------------------------------------------------------------------------
enum {X, Y, Z};
const int N = 20000;
const double dt = 0.01;
double __attribute__((aligned(32))) q[N][4] = {};
double __attribute__((aligned(32))) p[N][4] = {};
double __attribute__((aligned(32))) c24[4] = {24 * dt, 24 * dt, 24 * dt, 0};
double __attribute__((aligned(32))) c48[4] = { -48 * dt, -48 * dt, -48 * dt, 0};
//------------------------------------------------------------------------
#ifdef MYINTRIN
typedef double v4df __attribute__((vector_size(32)));
void
calc_intrin(void){
  const double c24[4] = {24*dt,24*dt,24*dt,0};
  const double c48[4] = {48*dt,48*dt,48*dt,0};
  const v4df vc24 = _mm256_load_pd((double*)(c24));
  const v4df vc48 = _mm256_load_pd((double*)(c48));
  for (int i = 0; i < N; i++) {
    const v4df vqi = _mm256_load_pd((double*)(q+i));
    v4df vpi = _mm256_load_pd((double*)(p+i));
    for (int j = i + 1; j < N; j++) {
      v4df vqj = _mm256_load_pd((double*)(q+j));
      v4df vdq = (vqj - vqi);
      v4df vd1 = vdq*vdq;
      v4df vd2 = _mm256_permute4x64_pd(vd1,201);
      v4df vd3 = _mm256_permute4x64_pd(vd1,210);
      v4df vr2 = vd1 + vd2 + vd3;
      v4df vr6 = vr2 * vr2 * vr2;
      v4df vdf = (vc24 * vr6 - vc48)/(vr6*vr6*vr2);
      v4df vpj = _mm256_load_pd((double*)(p+j));
      vpi += vdq * vdf;
      vpj -= vdq * vdf;
      _mm256_store_pd((double*)(p+j),vpj);
    }
    _mm256_store_pd((double*)(p+i),vpi);
  }
}
#endif
//------------------------------------------------------------------------
#ifdef MYASM
static void
calc_asm(void) {
  __asm__ (
    "xor %r10, %r10\n\t" // r10 = i = 0
    "mov $q, %ecx\n\t"
    "mov $p, %edx\n\t"
    "vmovapd c24(%rip), %ymm8\n\t" // ymm4 = (24,24,24,0)
    "vmovapd c48(%rip), %ymm9\n\t" // ymm5 = (48,48,48,0)
    "I.LOOP:\n\t"
    "vmovapd (%ecx), %ymm7\n\t" // qi
    "vmovapd (%edx), %ymm6\n\t" // pi
    "mov %ecx, %ebx\n\t"
    "mov %edx, %eax\n\t"
    "mov %r10, %r11\n\t" // r11 = j = i
    "inc %r11\n\t"
    "J.LOOP:\n\t"
    "add $32,%ebx\n\t"
    "add $32,%eax\n\t"
    "vmovapd (%ebx), %ymm1\n\t" // qj
    "vsubpd %ymm1, %ymm7, %ymm0\n\t" // ymm0 = (dx, dy, dz,0 )
    "vmulpd %ymm0, %ymm0, %ymm1\n\t" // ymm1 = (dx**2, dy**2, dz**2,0)
    "vpermpd $201, %ymm1, %ymm2\n\t" // ymm2 = (dy**2, dz**2, dx**2,0)
    "vpermpd $210, %ymm1, %ymm3\n\t" // ymm3 = (dz**2, dx**2, dy**2,0)
    "vaddpd %ymm1, %ymm2, %ymm1\n\t" // ymm1 = (xy,yz,zx,0)
    "vaddpd %ymm1, %ymm3, %ymm1\n\t" // ymm1 = (r2,r2,r2,0)
    "vmulpd %ymm1, %ymm1, %ymm2\n\t" // ymm2 = (r4,r4,r4,0)
    "vmulpd %ymm1, %ymm2, %ymm2\n\t" // ymm2 = (r6,r6,r6,0)
    "vmulpd %ymm2, %ymm2, %ymm3\n\t" // ymm3 = (r12,r12,r12,0)
    "vmulpd %ymm1, %ymm3, %ymm3\n\t" // ymm3 = (r14,r14,r14,0)
    "vfmadd132pd %ymm8, %ymm9, %ymm2\n\t" // ymm2 = (24.0*r6-48)*dt
    "vdivpd %ymm3, %ymm2, %ymm1\n\t" //ymm1 = (24.0*r6-48)/r14*dt =df
    "vmulpd %ymm1, %ymm0, %ymm0\n\t" // ymm0 = (df*dx, df*dy, df*dz, 0)
    "vaddpd (%eax), %ymm0, %ymm5\n\t"// pj -= df*dt;
    "vsubpd %ymm0, %ymm6, %ymm6\n\t"// pi += df*dt;
    "vmovapd %ymm5, (%eax)\n\t" // ymm5 -> pj
    "inc %r11\n\t" //j++
    "cmp $20000,%r11\n\t"
    "jl J.LOOP\n\t"
    "vmovapd %ymm6, (%edx)\n\t" // ymm6 -> pi
    "add $32,%ecx\n\t"
    "add $32,%edx\n\t"
    "inc %r10\n\t" //i++
    "cmp $19999,%r10\n\t"
    "jl I.LOOP\n\t"
  );
}
#endif
//------------------------------------------------------------------------
void
init(void) {
  for (int i = 0; i < N; i++) {
    q[i][X] = 1.0 + 0.4 * i;
    q[i][Y] = 2.0 + 0.5 * i;
    q[i][Z] = 3.0 + 0.6 * i;
    p[i][X] = 0.0;
    p[i][Y] = 0.0;
    p[i][Z] = 0.0;
  }
}
//------------------------------------------------------------------------
void
check(void) {
  for (int i = 0; i < N; i++) {
    for (int j = i + 1; j < N; j++) {
      const double dx = q[j][X] - q[i][X];
      const double dy = q[j][Y] - q[i][Y];
      const double dz = q[j][Z] - q[i][Z];
      const double r2 = (dx * dx + dy * dy + dz * dz);
      double r6 = r2 * r2 * r2;
      double df = (24.0 * r6 - 48.0) / (r6 * r6 * r2) * dt;
      p[i][X] += df * dx;
      p[i][Y] += df * dy;
      p[i][Z] += df * dz;
      p[j][X] -= df * dx;
      p[j][Y] -= df * dy;
      p[j][Z] -= df * dz;
    }
  }
}
//------------------------------------------------------------------------
void
check_opt(void) {
  for (int i = 0; i < N; i++) {
    const double qix = q[i][X];
    const double qiy = q[i][Y];
    const double qiz = q[i][Z];
    double pix = p[i][X];
    double piy = p[i][Y];
    double piz = p[i][Z];
    for (int j = i + 1; j < N; j++) {
      const double dx = q[j][X] - qix;
      const double dy = q[j][Y] - qiy;
      const double dz = q[j][Z] - qiz;
      const double r2 = (dx * dx + dy * dy + dz * dz);
      double r6 = r2 * r2 * r2;
      double df = (24.0 * r6 - 48.0) / (r6 * r6 * r2) * dt;
      pix += df * dx;
      piy += df * dy;
      piz += df * dz;
      p[j][X] -= df * dx;
      p[j][Y] -= df * dy;
      p[j][Z] -= df * dz;
    }
    p[i][X] = pix;
    p[i][Y] = piy;
    p[i][Z] = piz;
  }
}
//------------------------------------------------------------------------
int
main(void) {
  init();
#ifdef MYINTRIN
  calc_intrin();
#elif MYASM
  calc_asm();
#elif GCC
  check_opt();
#else
  check();
#endif

  for (int i = 0; i < 10; i++) {
    printf("%f %f %f\n", p[i][X], p[i][Y], p[i][Z]);
  }
}
//------------------------------------------------------------------------

結果

コンパイルオプション 速度 [s] 備考
icpc -O3 -xHOST -DMYINTRIN 1.996 intrinsic使った奴
g++ -O3 -DMYASM test.cpp 2.000 アセンブラで書いた奴
icpc -O3 -xHOST -O3 0.916 インテルコンパイラに任せた奴
g++ -O3 -mavx -DGCC 1.732 g++に任せた奴

GCCはこのコードでは最内側ループで不変な式を外出しできないので、それを助けてやると速くなる。で、結局インテルコンパイラにもGCCにも負けた。インテルコンパイラは、こちらが余計なことをしないでナイーブなコードを食わせた方が速いコードを吐く。

intrinsic使ったけど、アセンブラと速度が変わらない。っていうかループアンロールもしてくれてないし、吐いてるアセンブリが手で書いた奴と大差ない。そもそもmovapdを出してくれてない(movupdになってる)。なんでだ?

やっぱりこの思想のSIMD化はダメかしらん?

まとめ

人外への道は遠かった。

5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?