LoginSignup
27
8

More than 3 years have passed since last update.

行列行列積の計算オーダーは$O(N^3)$より小さくできる

Posted at

はじめに

D論やばい
こんな記事を書いてる場合じゃない

通常の行列行列積

通常$N \times N$行列$A, B$の行列行列積$C=AB$は

for(int i=0; i<N; ++i){
  for(int j=0; j<N; ++j){
    for(int k=0; k<N; ++k){
      C[i][k] += A[i][j]*B[j][k];
    }
  }
}

と計算し,計算量は$O(N^3)$として扱うことが多いです.
しかし,行列行列積の計算量はもっと少なくすることができることが知られています.

Strassenのアルゴリズム

通常の行列行列積の方法を,少し形を変えて見てみます.


A = \begin{pmatrix}
A_{11}, A_{12} \\
A_{21}, A_{22}
\end{pmatrix}, 
B = \begin{pmatrix}
B_{11}, B_{12} \\
B_{21}, B_{22}
\end{pmatrix},  
C = \begin{pmatrix}
C_{11}, C_{12} \\
C_{21}, C_{22}
\end{pmatrix},

と置いて計算をすると,


\begin{align}
C_{11} = A_{11} B_{11} + A_{12} B_{21} \\
C_{12} = A_{11} B_{12} + A_{12} B_{22} \\
C_{21} = A_{21} B_{11} + A_{22} B_{21} \\
C_{22} = A_{21} B_{12} + A_{22} B_{22}
\end{align}

となり,N/2 x N/2 のサイズの行列行列積を8回計算することになります.

Strassenはこれを


\begin{align}
M_{1} &= \left(A_{11} + A_{22}\right) \left(B_{11} + B_{22}\right) \\
M_{2} &= \left(A_{21} + A_{22}\right) B_{11} \\
M_{3} &= A_{11} \left(B_{12} - B_{22}\right) \\
M_{4} &= A_{22} \left(B_{21} - B_{11}\right) \\
M_{5} &= \left(A_{11} + A_{12}\right) B_{22} \\
M_{6} &= \left(A_{21} - A_{11}\right) \left(B_{11} + B_{12}\right) \\
M_{7} &= \left(A_{12} - A_{22}\right) \left(B_{21} + B_{22}\right) \\
\space
\\
\space
\\
C_{11} &= M_1 + M_4 - M_5 + M_7 \\
C_{12} &= M_3 + M_5 \\
C_{21} &= M_2 + M_4 \\
C_{22} &= M_1 - M_2 + M_3 + M_6
\end{align}

と計算することで,行列行列積を7回に抑える方法を提案しました12.

確認

\begin{align}
C_{11} &= M_1 + M_4 - M_5 + M_7 \\
&= (A_{11} + A_{22}) (B_{11} + B_{22}) + A_{22} (B_{21} - B_{11}) \\
&- (A_{11} + A_{12}) B_{22} + (A_{12} - A_{22})(B_{21} + B_{22}) \\
&= A_{11}B_{11} + A_{11} B_{22} + A_{22} B_{11} + A_{22} B_{22} \\
&+ A_{22} B_{21} - A_{22} B_{11} - A_{11} B_{22} - A_{12} B_{22} \\
&+ + A_{12} B_{21} + A_{12} B_{22} - A_{22} B_{21} - A_{22} B_{22} \\
&= A_{11} B_{11} + A_{11} B_{22} - A_{11} B_{22} - A_{12} B{22} \\
&+ A_{12} B_{21} + A_{12} B_{22} + A_{22} B_{11} + A_{22} B_{22} \\
&+ A_{22} B_{21} - A_{22} B_{11} - A_{22} B_{21} - A_{22} B_{22} \\
&= A_{11} B_{11} A_{12} B_{21} \\
\space\\
C_{12} &= M_3 + M_5 \\
&= A_{11} (B_{12} - B_{22}) + (A_{11} + A_{12}) B_{22} \\
&= A_{11} B_{12} - A_{11} B_{22} + A_{11} B_{22} + A_{12} B_{22} \\
&= A_{11} B_{12} + A_{12} B_{22} \\
\space\\
C_{21} &= M_2 + M_4 \\
&= (A_{21} + A_{22})B_{11} + A_{22}(B_{21} - B_{11}) \\
&= A_{21} B_{11} + A_{22} B_{11} + A_{22} B_{21} - A_{22} B_{11} \\
&= A_{21} B_{11} + A_{22} B_{21} \\
&\space\\
C_{22} &= M_1 - M_2 + M_3 + M_6\\
&= (A_{11} + A_{22}) (B_{11} + B_{22}) - (A_{21} + A_{22}) B_{11} \\
&+ A_{11} (B_{12} - B_{22}) + (A_{21} - A_{11})(B_{11} + B_{12}) \\
&= A_{11} B_{11} + A_{11} B_{22} + A_{22} B_{11} + A_{22} B_{22} \\
&- A_{21} B_{11} - A_{22} B_{11} + A_{11} B_{12} - A_{11} B_{22} \\
&- + A_{21} B_{11} + A_{21} B_{12} - A_{11} B_{11} - A_{11} B_{12} \\
&= A_{11} B_{11} + A_{11} B_{22} + A_{11} B_{12} - A_{11} B_{22} \\
&- A_{11} B_{11} - A_{11} B_{12} - A_{21} B_{11} + A_{21} B_{11} \\
&- + A_{21} B_{12} + A_{22} B_{11} + A_{22} B_{22} - A_{22} B_{11} \\
&= A_{21} B_{12} + A_{22} B_{22}
\end{align}

計算量の確認

NxNの行列行列積$1$回 $\to$ N/2 x N/2 の行列行列積$7$回 $\to$ N/4 x N/4 の行列行列積$7^2$回 ...と,どんどん小さく行列に分解して計算することで,計算量は

$O(7^{\log_2 N}) = O((2^{\log_2 7})^{log_2 N}) = O((2^{\log_2 N})^{\log_2 7}) = O(N^{\log_2 7}) = O(N^{2.807})$

となります.元の計算は$O(N^3) = O(N^{\log_2 8})$なので,明らかに計算量が減っています.

これが最も低い計算量というわけではなく,$O(N^{2.375477})$のCoppersmith–Winogradのアルゴリズム34や$O(N^{2.3728639})$
のGallのアルゴリズム5などがあるみたいです.

実装

じゃあなんでこの方法を使わないのかというと,実際に使うには遅いという問題があります.
単純にStrassenのアルゴリズムを組んで,行列行列積の時間を計測してみます.

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<ll> vll;
typedef vector<vector<ll> > vvll;

vvll ordinary(vvll& A, vvll& B, ll N){
  vvll Ans(N, vll(N, 0.0));
  for(ll i=0; i<N; ++i){
    for(ll k=0; k<N; ++k){
      for(ll j=0; j<N; ++j){
        Ans[i][j] += A[i][k]*B[k][j];
      }
    }
  }
  return Ans;
}

vvll Strassen(vvll& A, vvll& B, ll N){
  if(N==1){
    vvll Ans(1, vll(1, A[0][0]*B[0][0]));
    return Ans;
  }

  vvll A1_1(N/2, vll(N/2, 0.0));
  vvll A1_2(N/2, vll(N/2, 0.0));
  vvll A2_1(N/2, vll(N/2, 0.0));
  vvll A2_2(N/2, vll(N/2, 0.0));

  vvll B1_1(N/2, vll(N/2, 0.0));
  vvll B1_2(N/2, vll(N/2, 0.0));
  vvll B2_1(N/2, vll(N/2, 0.0));
  vvll B2_2(N/2, vll(N/2, 0.0));

  for(ll i=0; i<N/2; ++i){
    for(ll j=0; j<N/2; ++j){
      A1_1[i][j] = A[i][j];
      A1_2[i][j] = A[i][N/2+j];
      A2_1[i][j] = A[N/2+i][j];
      A2_2[i][j] = A[N/2+i][N/2+j];

      B1_1[i][j] = B[i][j];
      B1_2[i][j] = B[i][N/2+j];
      B2_1[i][j] = B[N/2+i][j];
      B2_2[i][j] = B[N/2+i][N/2+j];
    }
  }

  vvll M1_A(N/2, vll(N/2, 0.0));
  vvll M2_A(N/2, vll(N/2, 0.0));
  vvll M5_A(N/2, vll(N/2, 0.0));
  vvll M6_A(N/2, vll(N/2, 0.0));
  vvll M7_A(N/2, vll(N/2, 0.0));

  vvll M1_B(N/2, vll(N/2, 0.0));
  vvll M3_B(N/2, vll(N/2, 0.0));
  vvll M4_B(N/2, vll(N/2, 0.0));
  vvll M6_B(N/2, vll(N/2, 0.0));
  vvll M7_B(N/2, vll(N/2, 0.0));

  for(ll i=0; i<N/2; ++i){
    for(ll j=0; j<N/2; ++j){
      M1_A[i][j] = A1_1[i][j] + A2_2[i][j];
      M2_A[i][j] = A2_1[i][j] + A2_2[i][j];
      M5_A[i][j] = A1_1[i][j] + A1_2[i][j];
      M6_A[i][j] = A2_1[i][j] - A1_1[i][j];
      M7_A[i][j] = A1_2[i][j] - A2_2[i][j];

      M1_B[i][j] = B1_1[i][j] + B2_2[i][j];
      M3_B[i][j] = B1_2[i][j] - B2_2[i][j];
      M4_B[i][j] = B2_1[i][j] - B1_1[i][j];
      M6_B[i][j] = B1_1[i][j] + B1_2[i][j];
      M7_B[i][j] = B2_1[i][j] + B2_2[i][j];
    }
  }

  vvll M1 = Strassen(M1_A, M1_B, N/2);
  vvll M2 = Strassen(M2_A, B1_1, N/2);
  vvll M3 = Strassen(A1_1, M3_B, N/2);
  vvll M4 = Strassen(A2_2, M4_B, N/2);
  vvll M5 = Strassen(M5_A, B2_2, N/2);
  vvll M6 = Strassen(M6_A, M6_B, N/2);
  vvll M7 = Strassen(M7_A, M7_B, N/2);

  vvll Ans(N, vll(N, 0.0));

  for(ll i=0; i<N/2; ++i){
    for(ll j=0; j<N/2; ++j){
      Ans[i][j] = M1[i][j] + M4[i][j] - M5[i][j] + M7[i][j];
      Ans[i][N/2+j] = M5[i][j] + M3[i][j];
      Ans[N/2+i][j] = M4[i][j] + M2[i][j];
      Ans[N/2+i][N/2+j] = M1[i][j] - M2[i][j] + M3[i][j] + M6[i][j];
    }
  }

  return Ans;
}


int main(void){
  ll K = 16;
  srand(time(NULL));

  FILE *fp;
  char filename[32];
  sprintf(filename, "test1.dat");
  fp = fopen(filename, "w");
  fclose(fp);

  for(ll n=5; n<8; ++n){
    printf("n=%lld start\n", n);
    ll N = (ll)(1<<n);

    double time1 = 0.0;
    double time2 = 0.0;

    for(ll k=0; k<K; ++k){

      vvll A(N, vll(N, 0.0));
      vvll B(N, vll(N, 0.0));
      for(ll i=0; i<N; ++i){
        for(ll j=0; j<N; ++j){
          A[i][j] = ((double)rand())/RAND_MAX;
          B[i][j] = ((double)rand())/RAND_MAX;
        }
      }

      printf("start ordinary\n");
      clock_t start1 = clock();
      vvll Ans1 = ordinary(A, B, N);
      clock_t end1 = clock();
      printf("end ordinary\n");
      printf("start Strassen\n");
      clock_t start2 = clock();
      vvll Ans2 = Strassen(A, B, N);
      clock_t end2 = clock();
      printf("end Strassen\n");

      time1 += (double)(end1-start1)/CLOCKS_PER_SEC;
      time2 += (double)(end2-start2)/CLOCKS_PER_SEC;
    }

    fp = fopen(filename, "a");
    fprintf(fp, "%lld %.12lf %.12lf\n", N, time1/K, time2/K);
    fclose(fp);
  }

  return 0;
}

g++ -O2 でコンパイルして実行した結果が次の表です.

N ordinary [sec] Strassen [sec] Strassen / ordinary
32 0.0000249375 0.0037718750 151.253132832
64 0.0001463750 0.0260155000 177.731853117
128 0.0012548125 0.1829435625 145.793544852

普通の方法より100倍以上遅い...

再帰呼び出しをしてその中で配列を作ってるのが遅い原因なので,再帰を手でバラして配列を外に置くことにします


#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<ll> vll;
typedef vector<vll> vvll;
typedef vector<vvll> vvvll;
typedef vector<vvvll> vvvvll;

vvvvll MAs(20, vvvll(7));
vvvvll MBs(20, vvvll(7));
vvvvll Ms(20, vvvll(7));

void mat_diff(vvll& Ans1, vvll& Ans2, ll N){
  double diff = 0.0;
  for(ll i=0; i<N; ++i){
    for(ll j=0; j<N; ++j){
      diff += (double)abs(Ans1[i][j] - Ans2[i][j]);
    }
  }
  printf("diff %lf\n", diff);
}

void ordinary(vvll& A, vvll& B, vvll& Ans, ll N){
  for(ll i=0; i<N; ++i){
    for(ll k=0; k<N; ++k){
      for(ll j=0; j<N; ++j){
        Ans[i][j] += A[i][k]*B[k][j];
      }
    }
  }
  return;
}

void Strassen_before(ll N, ll layer, vvll& A, vvll& B){
  for(ll i=0; i<N/2; ++i){
    for(ll j=0; j<N/2; ++j){
      MAs[layer][0][i][j] = A[i][j]     + A[N/2+i][N/2+j];
      MAs[layer][1][i][j] = A[N/2+i][j] + A[N/2+i][N/2+j];
      MAs[layer][2][i][j] = A[i][j]     + A[i][N/2+j];
      MAs[layer][3][i][j] = A[N/2+i][j] - A[i][j];
      MAs[layer][4][i][j] = A[i][N/2+j] - A[N/2+i][N/2+j];
      MAs[layer][5][i][j] = A[i][j];
      MAs[layer][6][i][j] = A[N/2+i][N/2+j];
      /*
    }
  }
  for(ll i=0; i<N/2; ++i){
    for(ll j=0; j<N/2; ++j){
    */
      MBs[layer][0][i][j] = B[i][j]     + B[N/2+i][N/2+j];
      MBs[layer][1][i][j] = B[i][N/2+j] - B[N/2+i][N/2+j];
      MBs[layer][2][i][j] = B[N/2+i][j] - B[i][j];
      MBs[layer][3][i][j] = B[i][j]     + B[i][N/2+j];
      MBs[layer][4][i][j] = B[N/2+i][j] + B[N/2+i][N/2+j];
      MBs[layer][5][i][j] = B[i][j];
      MBs[layer][6][i][j] = B[N/2+i][N/2+j];
    }
  }
}

void Strassen_after(vvll& Ans, ll N, ll layer){
  for(ll i=0; i<N/2; ++i){
    for(ll j=0; j<N/2; ++j){
      Ans[i][j]         = Ms[layer][0][i][j] + Ms[layer][3][i][j] - Ms[layer][4][i][j] + Ms[layer][6][i][j];
      Ans[i][N/2+j]     = Ms[layer][4][i][j] + Ms[layer][2][i][j];
      Ans[N/2+i][j]     = Ms[layer][3][i][j] + Ms[layer][1][i][j];
      Ans[N/2+i][N/2+j] = Ms[layer][0][i][j] - Ms[layer][1][i][j] + Ms[layer][2][i][j] + Ms[layer][5][i][j];
    }
  }
}

void Strassen0(vvll& A, vvll& B, vvll& Ans, ll N){
  Ans[0][0] = A[0][0]*B[0][0];
  return;
}

void Strassen1(vvll& A, vvll& B, vvll& Ans, ll N){
  Strassen_before(N, 1, A, B);

  Strassen0(MAs[1][0], MBs[1][0], Ms[1][0], N/2);
  Strassen0(MAs[1][1], MBs[1][5], Ms[1][1], N/2);
  Strassen0(MAs[1][5], MBs[1][1], Ms[1][2], N/2);
  Strassen0(MAs[1][6], MBs[1][2], Ms[1][3], N/2);
  Strassen0(MAs[1][2], MBs[1][6], Ms[1][4], N/2);
  Strassen0(MAs[1][3], MBs[1][3], Ms[1][5], N/2);
  Strassen0(MAs[1][4], MBs[1][4], Ms[1][6], N/2);

  Strassen_after(Ans, N, 1);

  return;
}

void Strassen2(vvll& A, vvll& B, vvll& Ans, ll N){
  Strassen_before(N, 2, A, B);

  Strassen1(MAs[2][0], MBs[2][0], Ms[2][0], N/2);
  Strassen1(MAs[2][1], MBs[2][5], Ms[2][1], N/2);
  Strassen1(MAs[2][5], MBs[2][1], Ms[2][2], N/2);
  Strassen1(MAs[2][6], MBs[2][2], Ms[2][3], N/2);
  Strassen1(MAs[2][2], MBs[2][6], Ms[2][4], N/2);
  Strassen1(MAs[2][3], MBs[2][3], Ms[2][5], N/2);
  Strassen1(MAs[2][4], MBs[2][4], Ms[2][6], N/2);

  Strassen_after(Ans, N, 2);

  return;
}

void Strassen3(vvll& A, vvll& B, vvll& Ans, ll N){
  Strassen_before(N, 3, A, B);

  Strassen2(MAs[3][0], MBs[3][0], Ms[3][0], N/2);
  Strassen2(MAs[3][1], MBs[3][5], Ms[3][1], N/2);
  Strassen2(MAs[3][5], MBs[3][1], Ms[3][2], N/2);
  Strassen2(MAs[3][6], MBs[3][2], Ms[3][3], N/2);
  Strassen2(MAs[3][2], MBs[3][6], Ms[3][4], N/2);
  Strassen2(MAs[3][3], MBs[3][3], Ms[3][5], N/2);
  Strassen2(MAs[3][4], MBs[3][4], Ms[3][6], N/2);

  Strassen_after(Ans, N, 3);

  return;
}

void Strassen4(vvll& A, vvll& B, vvll& Ans, ll N){
  Strassen_before(N, 4, A, B);

  Strassen3(MAs[4][0], MBs[4][0], Ms[4][0], N/2);
  Strassen3(MAs[4][1], MBs[4][5], Ms[4][1], N/2);
  Strassen3(MAs[4][5], MBs[4][1], Ms[4][2], N/2);
  Strassen3(MAs[4][6], MBs[4][2], Ms[4][3], N/2);
  Strassen3(MAs[4][2], MBs[4][6], Ms[4][4], N/2);
  Strassen3(MAs[4][3], MBs[4][3], Ms[4][5], N/2);
  Strassen3(MAs[4][4], MBs[4][4], Ms[4][6], N/2);

  Strassen_after(Ans, N, 4);

  return;
}

void Strassen5(vvll& A, vvll& B, vvll& Ans, ll N){
  Strassen_before(N, 5, A, B);

  Strassen4(MAs[5][0], MBs[5][0], Ms[5][0], N/2);
  Strassen4(MAs[5][1], MBs[5][5], Ms[5][1], N/2);
  Strassen4(MAs[5][5], MBs[5][1], Ms[5][2], N/2);
  Strassen4(MAs[5][6], MBs[5][2], Ms[5][3], N/2);
  Strassen4(MAs[5][2], MBs[5][6], Ms[5][4], N/2);
  Strassen4(MAs[5][3], MBs[5][3], Ms[5][5], N/2);
  Strassen4(MAs[5][4], MBs[5][4], Ms[5][6], N/2);

  Strassen_after(Ans, N, 5);

  return;
}

void Strassen6(vvll& A, vvll& B, vvll& Ans, ll N){
  Strassen_before(N, 6, A, B);

  Strassen5(MAs[6][0], MBs[6][0], Ms[6][0], N/2);
  Strassen5(MAs[6][1], MBs[6][5], Ms[6][1], N/2);
  Strassen5(MAs[6][5], MBs[6][1], Ms[6][2], N/2);
  Strassen5(MAs[6][6], MBs[6][2], Ms[6][3], N/2);
  Strassen5(MAs[6][2], MBs[6][6], Ms[6][4], N/2);
  Strassen5(MAs[6][3], MBs[6][3], Ms[6][5], N/2);
  Strassen5(MAs[6][4], MBs[6][4], Ms[6][6], N/2);

  Strassen_after(Ans, N, 6);

  return;
}

void Strassen7(vvll& A, vvll& B, vvll& Ans, ll N){
  Strassen_before(N, 7, A, B);

  Strassen6(MAs[7][0], MBs[7][0], Ms[7][0], N/2);
  Strassen6(MAs[7][1], MBs[7][5], Ms[7][1], N/2);
  Strassen6(MAs[7][5], MBs[7][1], Ms[7][2], N/2);
  Strassen6(MAs[7][6], MBs[7][2], Ms[7][3], N/2);
  Strassen6(MAs[7][2], MBs[7][6], Ms[7][4], N/2);
  Strassen6(MAs[7][3], MBs[7][3], Ms[7][5], N/2);
  Strassen6(MAs[7][4], MBs[7][4], Ms[7][6], N/2);

  Strassen_after(Ans, N, 7);

  return;
}

void Strassen8(vvll& A, vvll& B, vvll& Ans, ll N){
  Strassen_before(N, 8, A, B);

  Strassen7(MAs[8][0], MBs[8][0], Ms[8][0], N/2);
  Strassen7(MAs[8][1], MBs[8][5], Ms[8][1], N/2);
  Strassen7(MAs[8][5], MBs[8][1], Ms[8][2], N/2);
  Strassen7(MAs[8][6], MBs[8][2], Ms[8][3], N/2);
  Strassen7(MAs[8][2], MBs[8][6], Ms[8][4], N/2);
  Strassen7(MAs[8][3], MBs[8][3], Ms[8][5], N/2);
  Strassen7(MAs[8][4], MBs[8][4], Ms[8][6], N/2);

  Strassen_after(Ans, N, 8);

  return;
}

void Strassen9(vvll& A, vvll& B, vvll& Ans, ll N){
  Strassen_before(N, 9, A, B);

  Strassen8(MAs[9][0], MBs[9][0], Ms[9][0], N/2);
  Strassen8(MAs[9][1], MBs[9][5], Ms[9][1], N/2);
  Strassen8(MAs[9][5], MBs[9][1], Ms[9][2], N/2);
  Strassen8(MAs[9][6], MBs[9][2], Ms[9][3], N/2);
  Strassen8(MAs[9][2], MBs[9][6], Ms[9][4], N/2);
  Strassen8(MAs[9][3], MBs[9][3], Ms[9][5], N/2);
  Strassen8(MAs[9][4], MBs[9][4], Ms[9][6], N/2);

  Strassen_after(Ans, N, 9);

  return;
}

void Strassen10(vvll& A, vvll& B, vvll& Ans, ll N){
  Strassen_before(N, 10, A, B);

  Strassen9(MAs[10][0], MBs[10][0], Ms[10][0], N/2);
  Strassen9(MAs[10][1], MBs[10][5], Ms[10][1], N/2);
  Strassen9(MAs[10][5], MBs[10][1], Ms[10][2], N/2);
  Strassen9(MAs[10][6], MBs[10][2], Ms[10][3], N/2);
  Strassen9(MAs[10][2], MBs[10][6], Ms[10][4], N/2);
  Strassen9(MAs[10][3], MBs[10][3], Ms[10][5], N/2);
  Strassen9(MAs[10][4], MBs[10][4], Ms[10][6], N/2);

  Strassen_after(Ans, N, 10);

  return;
}

void Strassen_base(vvll& A, vvll& B, vvll& Ans, ll N){
  ll layer=10;
  Strassen_before(N, layer, A, B);

  Strassen9(MAs[layer][0], MBs[layer][0], Ms[layer][0], N/2);
  Strassen9(MAs[layer][1], MBs[layer][5], Ms[layer][1], N/2);
  Strassen9(MAs[layer][5], MBs[layer][1], Ms[layer][2], N/2);
  Strassen9(MAs[layer][6], MBs[layer][2], Ms[layer][3], N/2);
  Strassen9(MAs[layer][2], MBs[layer][6], Ms[layer][4], N/2);
  Strassen9(MAs[layer][3], MBs[layer][3], Ms[layer][5], N/2);
  Strassen9(MAs[layer][4], MBs[layer][4], Ms[layer][6], N/2);

  Strassen_after(Ans, N, layer);

  return;
}


int main(void){
  ll K = 16;
  srand(time(NULL));

  std::vector<std::function<void(vvll&, vvll&, vvll&, ll)>> functions(11);
  functions[0] = Strassen0;
  functions[1] = Strassen1;
  functions[2] = Strassen2;
  functions[3] = Strassen3;
  functions[4] = Strassen4;
  functions[5] = Strassen5;
  functions[6] = Strassen6;
  functions[7] = Strassen7;
  functions[8] = Strassen8;
  functions[9] = Strassen9;
  functions[10] = Strassen10;

  FILE *fp;
  char filename[32];
  sprintf(filename, "test3.dat");
  fp = fopen(filename, "w");
  fclose(fp);

  for(ll n=5; n<11; ++n){
    printf("n=%lld start\n", n);
    ll N = (ll)(1<<n);

    double time1 = 0.0;
    double time2 = 0.0;

    for(ll k=0; k<K; ++k){

      vvll A(N, vll(N, 0.0));
      vvll B(N, vll(N, 0.0));
      vvll Ans1(N, vll(N, 0.0));
      vvll Ans2(N, vll(N, 0.0));
      for(ll i=0; i<N; ++i){
        for(ll j=0; j<N; ++j){
          A[i][j] = ((double)rand()*100)/RAND_MAX;
          B[i][j] = ((double)rand()*100)/RAND_MAX;
        }
      }

      fprintf(stderr, "start ordinary\n");
      clock_t start1 = clock();
      ordinary(A, B, Ans1, N);
      clock_t end1 = clock();
      fprintf(stderr, "end ordinary\n");
      fprintf(stderr, "start Strassen\n");
      clock_t start2 = clock();
      int layer=0;
      int tmp=1;
      for(; tmp<N; ++layer){
        tmp*=2;
      }
      for(ll i=0; i<layer+1; ++i){
        for(ll j=0; j<7; ++j){
          MAs[i][j].resize(N/2, vll(N/2));
          MBs[i][j].resize(N/2, vll(N/2));
          Ms[i][j].resize(N/2, vll(N/2));
        }
      }
      functions[layer](A, B, Ans2, N);
      clock_t end2 = clock();
      fprintf(stderr, "end Strassen\n");

      time1 += (double)(end1-start1)/CLOCKS_PER_SEC;
      time2 += (double)(end2-start2)/CLOCKS_PER_SEC;

      //mat_diff(Ans1, Ans2, N);
    }

    fp = fopen(filename, "a");
    fprintf(fp, "%lld %.12lf %.12lf\n", N, time1/K, time2/K);
    fclose(fp);
  }

  return 0;
}

再びg++ -O2でコンパイルして実行すると,

N ordinary [sec] Strassen [sec] Strassen / ordinary
32 0.0000274375 0.0002223750 8.10478359909
64 0.0001561250 0.0010284375 6.58726981585
128 0.0012133750 0.0075082500 6.18790563511
256 0.0095398125 0.0531371250 5.57003871931
512 0.0746498125 0.4024858125 5.3916520219
1024 0.6649019375 3.0327997500 4.56127374422

まだ普通の方法より4倍以上遅い...

ま,まぁ$N$が大きいところで勝てば...ということで,このデータを$t(N) = aN^b$で最小二乗法でfittingしてます(対数plotするとt(N)が直線になるので最小二乗法が使えます)

Source a b
ordinary 8.98e-10 2.93
Strassen 1.15e-08 2.78

$\log_2 7 = 2.81$なので,計算量の推定がそれなりに上手くいってそうです.
ここからStrassenのアルゴリズムが勝つNが分かります.ずばり

$ N \approx 30000000 $

実際にここまで大きな行列を扱う時は,双方もっと最適化した方法を使うので,また違う結果になると思いますが,単純にこの程度大きくない場合はStrassenのアルゴリズムを使うメリットがない,という悲しい結果になります.

終わりに

Strassenのアルゴリズム以外の方法のアルゴリズムの実装方法が論文を読んでも分からなかったので,分かる方教えてください.

27
8
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
27
8