AtCoder Beginner Contest 169 B問題「Multiplication 2」の解説を行います。
#問題概要
$N$個の整数で構成された数列 ($A_1,...,A_N$)が与えられる。
$A_1×...×A_N$, つまり
\prod_{i=1}^{N} A_i
を求めよ。ただし、この値が$10^{18}$を超える場合は-1
と出力せよ。
##制約
・$2 \leq N \leq 10^5$
・$0 \leq A_i \leq 10^{18}$
・入力は全て整数
#解説
この問題ですが多倍長整数の知識が必要な問題でした。これらの数列の値を、int
型で入力しかけ算をするとオーバーフローが発生してしまいます。また、制約が$10^{18}$までであることから、long int
型で入力しかけ算をしてもオーバーフローが発生してしまいます。
従って、何かしらの対策をしないといけません。
##対策1 (数列を降順にソートする)
降順にソートすると、数列を、数字の大きい順に並び替えることが出来ます。数列をソートしても総積は同じなので、これは有効な手です。
その後、数列の値を掛け合わせて$10^{18}$を超えた時に-1
と出力し処理を終了すればよいです。
計算量は ソートに$O(logN)$, かけ算に$O(N)$かかるので$O(NlogN)$です。
また、処理を終了しないと、一部言語ではTLEになります。理由は単純で、多倍長整数のかけ算に計算時間がかかってしまうからです。
例えば、Python3,C++,Javaで、$100×100$, $10^5×10^5$, $10^9×10^9$のかけ算、また$10^{18}$の複数回のかけ算をそれぞれ実行して、計算時間を比較してみます
(AtCoderのコードテストを利用しました)
(Python3ではint
,C++ではlong long int
, Javaではdouble
型として値を入力して実行しました。)
(オーバーフローが発生する前提で計算をしています)
計算 | Python3 | C++ | Java |
---|---|---|---|
$100×100$ | 18ms | 6ms | 122ms |
$10^5×10^5$ | 21ms | 10ms | 106ms |
$10^9×10^9$ | 18ms | 7ms | 122ms |
$10^{18}×10^{18}$ | 19ms | 8ms | 122ms |
${10^{18}}^{100}$ | 18ms | 6ms | 118ms |
${10^{18}}^{10000}$ | 585ms | 8ms | 113ms |
${10^{18}}^{100000}$ | 10500ms | 9ms | 118ms |
ご覧の通り、C++やJavaでは計算時間にさほど差はみられませんが、Python3だと、計算回数が多くなることで実行時間制限である2sec(2000ms)より多くの計算時間がかかっていることが分かります。
##対策2 (0に注意する)
数列の中に$0$が1つでも存在していると、総積は当然$0$になりますが、対策1と同様のことをすると、実は$0$を含んだ数列でも-1
が出力されてしまいます。
理由は簡単です。ソートすることで$0$でない、値が大きい項により、総積が$10^{18}$を超えたと判定され、-1
と出力されてしまうからです。
これが、コーナーケースであるzero_01.txt
などでWA判定が出てしまう原因になります。
なので、数列を読み込むとき、$0$の項が存在していたら1にするフラグ関数を作ってしまいましょう。そうすると、かけ算をする前に総積が$0$であるかを判定することが出来ます!
Python3では A.count(0)
で$0$の項が存在しているかを確かめることができます。
##まとめ・解答例
つまり、この問題に対する各言語のポイントは以下のようになります。
・Python3
多倍長整数同士のかけ算は計算時間が多くかかるので避けた方がよい
・C++,Java
オーバーフローによる計算の誤りを避けた方が良い
以下、Python3,C++,Javaでの解答例を示します。
Python3での解答例
N = int(input())
A = list(map(int,input().split()))
A.sort()
A.reverse()
if A.count(0) > 0:
print(0)
else:
f = 0
ans = 1
for i in range(N):
ans *= A[i]
if ans > 10**18:
f += 1
print(-1)
break
if f == 0:
print(ans)
C++での解答例
#include<bits/stdc++.h>
using namespace std;
int main(){
int n;
cin >> n;
vector<long int> vec(n);
for (int i = 0; i < n; i++){
long int a;
cin >> a;
vec.at(i) = a;
if(vec.at(i) == 0){
cout << 0 << endl;
return 0;
}
}
sort(vec.begin(),vec.end());
reverse(vec.begin(),vec.end());
long int ans = 1;
for (int i = 0; i < n; i++){
if (vec.at(i) > 1000000000000000000/ans){
cout << -1 << endl;
return 0;
}else{
ans *= vec.at(i);
}
}
cout << ans << endl;
}
Javaでの解答例
import java.util.Scanner;
import java.util.Arrays;
public class Main{
public static void main(String[] args){
Scanner scan = new Scanner(System.in);
int n = scan.nextInt();
long [] a = new long[n];
for (int i = 0; i < n; i++){
a[i] = scan.nextLong();
if (a[i] == 0){
System.out.println(0);
return;
}
}
Arrays.sort(a);
long ans = 1;
for (int i = n-1; i >= 0; i--){
if (a[i] > 1000000000000000000L/ans){
System.out.println(-1);
return;
}else{
ans *= a[i];
}
}
System.out.println(ans);
}
}