最適化
皆さんご存じの通りC++という言語は速度しか取り柄がありません。複雑怪奇な構文、異常量のコンパイルエラー、何が起きるかわからない未定義動作たち、、、それでもなお私たちがC++を使い続けるのは、そう、極めて速いからです。ちょっと前に流行ったD言語や最近流行りのRustが速いという噂ですが、結局ほとんどの場面で速度においてC++に勝ててはいません。個人的な偏見ですがC++のように速度のみを重視し、安全性を捨てない限りRustに未来はないでしょう。まあ、人から聞いた噂によるとRustの標準ライブラリの中身はunsafeなコードが敷き詰められているらしく、なかなか速いようですが。いずれRustが廃れてUnsafe Rustの方が人気が出たりするかもしれませんね。とりあえず、最も速度面で優れている言語はC++といっても過言ではないでしょう。
そして、その圧倒的な速度を支えるのが何といってもC++コンパイラによる強力な最適化でしょう。近年では最適化があまりに強力すぎて、適当なコードを書いても高速化を意識したコードと大して速度が変わらない、ということも頻発しています。というわけで、コンパイラがどのようにコードを最適化するのか見ていきましょう。
実例
条件分岐(if文)とループ(for文)におけるコンパイラの最適化を見ていきましょう。今回は未定義動作のためにアセンブリコード量を大きく減らせるような関数の最適化は扱いません。扱うコンパイラはGCC10.1、Clang10.0.0、MSVC19.24です。すべてC++2aでコンパイルしています。
条件分岐
条件分岐のあるコードを適当に書いてみましょう。
int f1(int a, int b) { // A
if(a == 0 || b == 0) { // B
return 0; // C
}
if(a == 0 && b == 0) { // D
return -1; // E
}
return 1; // F
} // G
最適化なし
さあコンパイルしてみましょう。一旦最適化なし(-O0)で、GCCでコンパイルしてみます。アセンブリは以下のようになります。
f1(int, int):
pushq %rbp # ベースポインタの値をスタックへ保存 (A)
movq %rsp, %rbp # スタックポインタの値をベースポインタへ格納 (A)
movl %edi, -4(%rbp) # 第一引数の値をスタックフレーム上のaへ保存 (A)
movl %esi, -8(%rbp) # 第二引数の値をスタックフレーム上のbへ保存 (A)
cmpl $0, -4(%rbp) # a == 0 か? (B)
je .L2 # もしそうなら.L2タグへジャンプ (B)
cmpl $0, -8(%rbp) # b == 0 か? (B)
jne .L3 # もしそうでない.L3タグへジャンプ (B)
.L2:
movl $0, %eax # 戻り値に0を格納 return 0 (C)
jmp .L4 # .L4タグへジャンプ (C)
.L3:
cmpl $0, -4(%rbp) # a == 0 か? (D)
jne .L5 # もしそうでないなら.L5へジャンプ (D)
cmpl $0, -8(%rbp) # b == 0 か? (D)
jne .L5 # .L5タグへジャンプ (D)
movl $-1, %eax # 戻り値を-1に (E)
jmp .L4 # .L4タグへジャンプ (E)
.L5:
movl $1, %eax # 戻り値を1に (F)
.L4:
popq %rbp # スタックへ保存したベースポインタの値を復帰 (G)
ret # 呼び出し元へ復帰 (G)
ソースコードが忠実にアセンブリに翻訳されていますね。一瞬、a == 0 が評価された段階で(b == 0 を評価する前に)、ジャンプしてしまうのはおかしいと感じるかもしれませんが、規格では論理演算子operator||とoperator&&は左の項がそれぞれtrue/falseと評価されたなら右の項は評価されないので、正しい挙動です。つまりこういうことです。
int i = 0;
bool f() { i++; return true; }
int main() {
true || f(); // f()は評価されない(i == 0 のまま)
false && f(); // f()は評価されない(i == 0 のまま)
assert(i == 0);
return 0;
}
最適化あり
話が逸れましたが、元のコード(f1.cpp)に戻って、改めて最適化(-O2)をかけた状態でコンパイルしてみましょう。
f1(int, int):
testl %edi, %edi # aとaで論理積を取る。a == 0 なら0、それ以外なら1になる (B)
sete %al # 上の結果を%eaxに代入(%alは%eaxの下位8bit分) (B?)
testl %esi, %esi # bとbで論理積を取る。b == 0 なら0、それ以外なら1になる (B)
sete %dl # 上の結果を%edxに代入(%dlは%edxの下位8bit分) (B?)
orl %edx, %eax # %eax || %edx を計算し%eaxへ入れる (C)
xorl $1, %eax # 1と%eaxの排他的論理和をとって%eaxへ入れる (C, F)
movzbl %al, %eax # %eax下位8bitを%eaxへコピー(上位24bitを0詰めするのが目的か?) (C, F)
ret # 呼び出し元へ返る (G)
皆さん既にお気付きかもしれませんが、この関数では-1を返すことはありません。というのも a == 0 かつ b == 0 のときは、当然 a == 0 または b == 0 も満たしますから、関数は0を返します。つまり、コード内のD, Eの部分の条件分岐は余分であるということです。一般に条件分岐はジャンプ命令を生み、これは比較的計算時間コストが大きいため、できれば余計な条件分岐は避けるべきです。しかしながら、仮に私たちがうっかり余計な分岐を入れてしまっても、コンパイラは気づいてその部分のコードを削除してくれます。優秀ですね。
え? このくらいだったらコンパイラに頼らずとも自分で気付いて消せる? まあ普通のプログラマーならそうでしょう。しかし、その先の最適化までできた人はいるでしょうか? 改めてアセンブリコード(GCC_O2_f1.asm)を見直してみてください。このコードにはジャンプが一つもありませんね? これはつまり全ての条件分岐を消し去ったことを意味しています。コードのD, Eの部分を消せると気付く人は多いかもしれませんが、Aの部分の条件分岐まで消せると分かった人は少ないのではないでしょうか。
ソースコード再考
要するに真に高速化するためには以下のf2.cppのようなコードではなく、f3.cppのようなコードを書くべきということになります。
int f2(int a, int b) {
if (a == 0 || b == 0) {
return 0;
}
return 1;
}
int f3(int a, int b) {
return ((a == 0) | (b == 0)) ^ 1;
}
さらにこれらを最適化なしでコンパイルすると、
f2(int, int):
pushq %rbp
movq %rsp, %rbp
movl %edi, -4(%rbp)
movl %esi, -8(%rbp)
cmpl $0, -4(%rbp)
je .L1
cmpl $0, -8(%rbp)
jne .L2
.L1:
movl $0, %eax
jmp .L3
.L2:
movl $1, %eax
.L3:
popq %rbp
ret
f3(int, int):
pushq %rbp
movq %rsp, %rbp
movl %edi, -4(%rbp)
movl %esi, -8(%rbp)
cmpl $0, -4(%rbp)
sete %dl
cmpl $0, -8(%rbp)
sete %al
orl %edx, %eax
xorl $1, %eax
movzbl %al, %eax
popq %rbp
ret
f3.cppまで条件分岐を削って初めて、最適化なしでもジャンプが起こらなくなります。ちなみに、GCCで-O2の最適化オプションをつけてコンパイルすると、f2.cppもf3.cppもf1.cppと完全に同じアセンブリコード(GCC_O2_f1.asm)が生成されます。GCCは初めから最適なプログラムを見抜いているということですね。
あまり記事が長くなりすぎても良くないので、条件分岐最適化の話はここまでにしておきますが、興味のある人はf1.cppまたはf2.cppの最後の return 1; を return 2; に変えてみてください。こうするとGCCは諦めてジャンプするコードを生成しますが、Clangでコンパイルすると諦めずにすべてのジャンプを取り除きます。あと、MSVCはf1.cppとf2.cppでは return 1; の状態で既にジャンプします。
f1.cpp, f2.cpp, f3.cpp : https://godbolt.org/z/ctqCoZ (すべて一つのファイルにまとめてあります)
ループ
ある日、あなたは突然 $\sum_{i=0}^{n-1} (i^3-4i^2+2i-1)$ を求めたくて仕方がなくなりました。このくらいのプログラムならば、ループを使えば考えるまでもなく簡単に書けるでしょう。しかし、あなたは一瞬躊躇います。もしかすると今後各項にかかる係数を変えたくなるかもしれない。そこであなたはとりあえず、$\sum_{i=0}^{n-1} 1$、$\sum_{i=0}^{n-1} i$、$\sum_{i=0}^{n-1} i^2$、$\sum_{i=0}^{n-1} i^3$をそれぞれ求めることにしました。
定数和
まずは $\sum_{i=0}^{n-1} 1$ からいきましょう。
int sum0(int n) {
int sum = 0;
for (int i = 0;i < n; i++) {
sum++;
}
return sum;
}
皆さん当然お気づきかと思いますが、このループは無駄です。この関数は計算量O(n)ですが変数sumをn回インクリメントするだけなので、結果はnと等しくなるに決まってます。あなた同様コンパイラも当然この事実に気づいています。Clangで-O2でコンパイルすると、
sum_i0(int):
xorl %eax, %eax
testl %edi, %edi
cmovnsl %edi, %eax
retq
というように当然の如くループを削除します。GCCやMSVCもループを削除します。当然あなたもsum0(int)を改良して
int sum0_noloop(int n) {
if (n <= 0) {
return 0;
}
return n;
}
こうすることでしょう。
線形和
続けて、$\sum_{i=0}^{n-1} i$ にいきましょう。
int sum1(int n) {
int sum = 0;
for (int i = 0; i < n; i++) {
sum += i;
}
return sum;
}
ここまで書いて、優秀なあなたはこのコードが非効率であることに気付きます。というのも、昔習ったような公式
\sum_{i=0}^{n-1} i = \frac{n(n-1)}{2}
を使えば定数時間で計算可能だからです。GCCとMSVCは残念ながらこの公式を覚えていなかったようですが、優秀なClangはあなた同様この計算が定数時間で終わることに気付きます。Clangでコンパイルすると、
sum_i1(int):
testl %edi, %edi # nとnの論理積を取る
jle .LBB2_1 # 論理積の符号が非正なら.LBB2_1タグへ飛ぶ
leal -1(%rdi), %eax # n-1を格納
leal -2(%rdi), %ecx # n-2を格納
imulq %rax, %rcx # 64bitとしてn-1とn-2を掛け算(オーバーフローさせない目的か?)
shrq %rcx # (n-1)*(n-2)を右シフト(2で割ることに相当)
leal (%rcx,%rdi), %eax # (n-1)*(n-2)とnを足す
addl $-1, %eax # -1を足す
retq # 呼び出し元へ返る
.LBB2_1:
xorl %eax, %eax # 同じ値のXORを取ることで戻り値を0に設定する
retq # 呼び出し元へ返る
となり、これは
\sum_{i=0}^{n-1} i = \frac{n(n-1)}{2} = \frac{(n-1)(n-2)}{2} + n - 1
と変形して計算されていることがわかります。 そして、あなたも定数時間で計算できるようにコードを改良します。
int sum1_noloop(int n) {
if (n <= 0) {
return 0;
}
return static_cast<int>((static_cast<long long>(n) * static_cast<long long>(n - 1)) >> 1);
}
地味に注意が必要ですが、sum1(int)と全く同じ挙動を取らせるには掛け算でオーバーフローしないようにしましょう。int型のみで計算すると、引数によっては足し算の時には起こらなかったオーバーフローが掛け算で起こってしまいます。
二乗和、三乗和
ここまで順調ですね。この調子で $ \sum_{i=0}^{n-1} i^2 $ と $\sum_{i=0}^{n-1} i^3 $も求めましょうか。
int sum2(int n) {
int sum = 0;
for (int i = 0; i < n; i++) {
sum += i * i;
}
return sum;
}
int sum3(int n) {
int sum = 0;
for (int i = 0;i < n; i++) {
sum += i * i * i;
}
return sum;
}
やっと最後まで書ききれましたね。しかし、流れを見ればわかってしまうかもしれませんが、やはりこれら二つも定数時間で計算可能なのです。皆さんはいつか習ったはずのこの公式がパッと頭に浮かびますか? あるいは素早く導出できますか? 何乗の和でもアルゴリズミックに導出可能ですが、高次になるほど計算は複雑になりがちです。計算するとこうなります。
\sum_{i=0}^{n-1} i^2 = \frac{n(n-1)(2n-1)}{6},\space\space\space \sum_{i=0}^{n-1} i^3 = \frac{n^2(n-1)^2}{4}
しかし、優秀なClangは公式を覚えているのか瞬時にループレスなコードに最適化します。まず二乗和から見てみましょう。
sum2(int):
testl %edi, %edi # nとnの論理積を取る
jle .LBB4_1 # 論理積の符号が非正なら.LBB4_1タグへ飛ぶ
leal -1(%rdi), %eax # n-1を格納
leal -2(%rdi), %ecx # n-2を格納
imulq %rax, %rcx # 64bitとしてn-1とn-2を掛ける
leal -3(%rdi), %eax # n-3を格納
imulq %rcx, %rax # 64bitとして(n-1)*(n-2)とn-3を掛ける
shrq %rax # 右シフト(3で割る前の下準備?)
imull $1431655766, %eax, %eax # 0x55555556 == 0x100000000 / 3 (切り上げ) を掛ける(3で割ることに相当?)
addl %edi, %eax # nを足す
shrq %rcx # (n-1)*(n-2)を右シフト(2で割り算)
leal (%rcx,%rcx,2), %ecx # (n-1)*(n-2)/2と(n-1)*(n-2)/2*2を足し合わせる
addl %ecx, %eax # (n-1)*(n-2)*(n-3)/6+nと3*(n-1)*(n-2)/2を足し合わせる
addl $-1, %eax # -1を足す
retq # 呼び出し元へ返る
.LBB4_1:
xorl %eax, %eax # 戻り値を0にする
retq # 呼び出し元へ返る
複雑ですが、Clangがやっていることは
\sum_{i=0}^{n-1}i^2 = \frac{n(n-1)(2n-1)}{6} = \frac{(n-1)(n-2)(n-3)}{3} + n + \frac{3(n-1)(n-2)}{2} - 1
と思われます。多分。正直3で割り算しているあたりはよくわかりません。0x55555556をかけてから右に31bitシフトするのであればわかるのだけど、、、
続けて三乗和。
sum3(int):
testl %edi, %edi
jle .LBB3_1 # n <= 0 ならジャンプ
leal -1(%rdi), %eax # n-1
leal -2(%rdi), %ecx # n-2
imulq %rax, %rcx # (n-1)*(n-2)
movq %rcx, %rax
shrq %rax # (n-1)*(n-2)/2
leal (,%rax,8), %edx # (n-1)*(n-2)/2*8
subl %eax, %edx # (n-1)*(n-2)/2*8 - (n-1)*(n-2)/2
addl %edi, %edx # 7*(n-1)*(n-2)/2 + n
leal -3(%rdi), %eax # n-3
addl $-4, %edi # n-4
imulq %rax, %rdi # (n-3)*(n-4)
imulq %rcx, %rdi # (n-1)*(n-2)*(n-3)*(n-4)
imull %eax, %ecx # (n-1)*(n-2)*(n-3)
andl $-2, %ecx # (n-1)*(n-2)*(n-3)の最下位bitを0に(-2 == 0xFFFFFFFE、無意味な操作?)
leal (%rdx,%rcx,2), %eax # 7*(n-1)*(n-2)/2+n + (n-1)(n-2)(n-3)*2
shrq $2, %rdi # (n-1)*(n-2)*(n-3)*(n-4)/4
andl $-2, %edi # (n-1)*(n-2)*(n-3)*(n-4)/4の最下位bitを0に(無意味な操作?)
addl %edi, %eax # (n-1)*(n-2)*(n-3)*(n-4)/4 + 7*(n-1)*(n-2)/2+n+(n-1)(n-2)(n-3)*2
addl $-1, %eax # (n-1)*(n-2)*(n-3)*(n-4)/4+7*(n-1)*(n-2)/2+n+(n-1)(n-2)(n-3)*2 - 1
retq
.LBB3_1:
xorl %eax, %eax
retq
もう何が何やらよくわかりませんね。途中で謎の論理積を取って最下位bitを0にしていますが、3つの連続する整数の積はもともと偶数ですし、4つの連続する整数の積は8の倍数になることが保証されますから、特に意味のない操作に思えます。もしかしたら何か深い意味があるのかもしれませんが(オーバーフロー関連か?)。おそらくですが、Clangは下のような形で和を求めています。
\begin{eqnarray}
\sum_{i=0}^{n-1} i^3 &=& \frac{n^2(n-1)^2}{4}\\
&=& \frac{(n-1)(n-2)(n-3)(n-4)}{4}+\frac{7(n-1)(n-2)}{2}+n+2(n-1)(n-2)(n-3)-1
\end{eqnarray}
ここまでみて、何となくClangもアルゴリズミックに級数を求めていそうな気がしますね。$d-1$乗の和に対して、定数項と$n$と$ \space_{n-1} P_k=(n-1)(n-2)\cdots(n-k) \space\space (k=2,3,...,d)$ の項をそれぞれ計算して、各項の係数を調整しているように見えてきますね。え? 見えない?
複雑な級数
ここまで来て、やっと本来求めたかった $\sum_{i=0}^{n-1} (i^3-4i^2+2i-1)$ が求められます。ところで、今回はのちのちのことを考えて各項分けて級数を求めましたが、もしも各累乗和に分割せずにそのまま求めたらどうなっていたのでしょうか。Clangはちゃんと定数時間のコードを書いてくれたでしょうか。試してみましょう。
int sum_all(int n) {
int sum = 0;
for (int i = 0; i < n; i++) {
sum += i * (i * (i - 4) + 2) - 1; // i^3 - 4*i^2 + 2*i - 1
}
return sum;
}
これをコンパイルすると、
sum_all(int):
testl %edi, %edi
jle .LBB4_1
leal -1(%rdi), %eax
leal -2(%rdi), %ecx
imulq %rax, %rcx
leal -3(%rdi), %eax
imulq %rcx, %rax
leal -4(%rdi), %edx
imulq %rax, %rdx
shrq $2, %rdx
addl %edi, %edi
shrq %rcx
shrq %rax
orl $1, %edx
subl %edi, %edx
leal (%rcx,%rcx,2), %ecx
subl %ecx, %edx
imull $-1431655764, %eax, %eax
addl %edx, %eax
retq
.LBB4_1:
xorl %eax, %eax
retq
うん、定数時間。何やってるかはよくわからないけど、とりあえずClangは賢い。
sum0.cpp、sum1.cpp、sum2.cpp、sum3.cpp、summ_all.cpp : https://godbolt.org/z/YU7W3T (すべて一つのファイルにまとめてあります)
おわりに
適当に書いてもC++は速い。それでも我々はより速いコードを書かねばならない。お、ちょっとかっこいいこと言った気がする。
おしまい