競プロでたまに符号つき64bit整数の乗算のオーバーフロー判定が必要になることがある。(途中の乗算でオーバーフローすると誤判定するため)
GCCにはビルトイン関数があってそれでできるが、Visual C++にはないので、組み込み関数の乗算を利用して書いてみた。結果は2つの64bitワードとして取得できる。
以下の場合はオーバーフローしていない。
- 上位ワードがゼロで、下位ワードが非負のとき
- 上位ワードが-1で、下位ワードが負のとき
上記以外の場合はオーバーフローしている。
- 上位ワードが1以上なら64bitに収まっていない
- 上位ワードがゼロで下位ワードが負ならLLONG_MAXを超えている
- 最上位ワードが-1より小さい場合は64bitに収まっていない
- 最上位ワードが-1で、下位ワードが正のときは、LLONG_MINより小さい
#include <cassert>
#include <climits>
#include <iostream>
#ifdef _MSC_VER
#include <intrin.h>
#pragma intrinsic(_mul128)
#endif
inline bool is_mul_overflow(long long a, long long b) {
#ifdef _MSC_VER
long long high, low = _mul128(a, b, &high);
return !((high == 0 && low >= 0) || (high == -1 && low < 0));
#else
long long low;
return __builtin_smulll_overflow(a, b, &low);
#endif
}
int main(int argc, char* argv[]) {
assert(is_mul_overflow(LLONG_MAX, 1) == false);
assert(is_mul_overflow(LLONG_MAX / 2 + 1, 2) == true); // LLONG_MAX + 1
assert(is_mul_overflow(LLONG_MIN, 1) == false);
assert(is_mul_overflow(LLONG_MIN, 2) == true); // 0xFFFFFFFF00000000
return 0;
}