はじめに
プログラミングにほぼ触れたことのない友人から「Python教えて!」と言われて勉強会を開いていたところ、ABC163C - managementを「良問だ!」と叫び倒しておられました。
実際、1つの問題の中で「コードを書けるようになる」というフェーズから「計算量を改善することができる」というフェーズへのシフトが興味深かったので、やりとりをかい摘んで記しそうと思います。
この記事の対象者
for・if・リスト等の文法はある程度覚えて、ロジックには自信があるのになぜかTLEになってしまう人。
問題概要
社員番号$1, 2, \cdots, N$のそれぞれの直属の部下の数を出力せよ。$i$番目の社員の上司を$A_i$として、$A_2, \cdots, A_N$が与えられる。ただし、ある社員の社員番号よりも上司の社員番号は若い番号である。
計算量改善の軌跡
2回TLEになって、3回目でACだったという流れです。
1. まずは愚直に
方針
社員番号1からNまで順番に、上司リストAの中身を検索しヒット数をprintします。リストに対して指定した要素の個数を出力させるにはcount()
メソッドが使えるので、それを使います。
N = int(input())
A = list(map(int, input().split()))
for i in range(1, N + 1):
count = A.count(i)
print(count)
提出!
TLE!→なんでダメなの?
制約が$2 \le N \le 2 \times 10^5$なのに、計算量が$O(N^2) \approx 10^{10}$となってしまうためです。
count()
というメソッドはリストの要素を端から端まで比較しながらカウントします。
リストの要素数Nとすると、count()
が呼び出される度にN回計算する。それがN回ループするfor文の中に書いてあるので、$O(N^2)$です。
なお、プログラミングコンテストチャレンジブック(通称:蟻本)には、コンパイル言語であるC++で制限時間が1秒の場合
$10^6$ 余裕を持って間に合う
$10^7$ おそらく間に合う
$10^8$ 非常にシンプルな処理でない限り厳しい
とあります。Pythonで計算量$10^{10}$など、到底間に合いません。
2. 途中で打ち切ってみる
方針
count()
で得た部下の数の総和が$n - 1$になった時点で、それ以上このメソッドに頼る必要はありません。そこで、総和を保持する変数total
を初期値0で定義して、count()
を呼び出すたびにその結果を足し合わせます。
ただし、それだけだと余った社員分0
を何回出力すれば良いのか分からないので、打ち切り時の社員番号をt
という変数に保存します。
N = int(input())
A = list(map(int, input().split()))
total = 0
t = 0
for i in range(1, N):
count = A.count(i)
print(count)
total += count
if total == N - 1:
t = i
break
for _ in range(t, N):
print(0)
提出!
2度目のTLE!→これでもダメなの?
最悪計算量 が$O(N^2)$となってしまうためです。例えば、
7
1 2 3 4 5 6
というケースは、社員番号1から始まり、6まで見たところでようやくcount()
を使うforループを打ち切ることができます。このように、社員$1, \cdots, N - 1$まで見ないと計算が終わらない最悪のケースを想定すると、結局は$O(N^2)$なのです。
ということで、社員ごとに部下の数を毎度計算していたのではとてもダメそうです。
3. 別の配列に結果を保存する
方針
これまでは各社員に対して毎回直属の部下の数を計算していましたが、各社員の部下の数が収められたリストをどうにかして$O(N)$で作ることができないか考えてみます。最終的にそのリストの内容を一つずつ出力すれば良いということです。
そのためには、直属の上司のリスト$A$を端から端まで見ていくときにそれぞれの社員番号を何回見たか記録する必要があります。記録先として$N$人分のデータを保持できるものが必要です1が、それはインデックスがそれぞれの社員に対応しており、要素が部下の数であるような長さ$N$のリストとして保持するのが良さそうです。
図にすると、こんな感じです。上のカードが、リストA、下の容れ物が「記録先」です。
N = int(input())
A = list(map(int, input().split()))
B = [0] * N
for i in A:
B[i - 1] += 1
for i in B:
print(i)
めでたくAC!2
さいごに
この計算量でダメだなと概算して、異なる方針を考えるというのは本当に大事なことだと思います。
補足/list.countの実装について
list.count
の実装がどうなっているのか、cpython/Objectsから該当部分を抜粋します。
/*[clinic input]
list.count
value: object
/
Return number of occurrences of value.
[clinic start generated code]*/
static PyObject *
list_count(PyListObject *self, PyObject *value)
/*[clinic end generated code: output=b1f5d284205ae714 input=3bdc3a5e6f749565]*/
{
Py_ssize_t count = 0;
Py_ssize_t i;
for (i = 0; i < Py_SIZE(self); i++) {
PyObject *obj = self->ob_item[i];
if (obj == value) {
count++;
continue;
}
Py_INCREF(obj);
int cmp = PyObject_RichCompareBool(obj, value, Py_EQ);
Py_DECREF(obj);
if (cmp > 0)
count++;
else if (cmp < 0)
return NULL;
}
return PyLong_FromSsize_t(count);
}
C言語を普段書いているわけではないので厳密な議論はしませんが、
- for文でlistを端から端まで見て:
for (i = 0; i < Py_SIZE(self); i++)
- if文で一致性を確認し:
if (obj == value)
- 一致していればカウンタを1進めて:
count++;
いますね。ということでlist.count
はO(N)と言えそうです。