D - Digits Parade
「10^9+7で割ったあまりを答えて」なので答えが大変大きくなることがわかる
が、とりあえず素直に実装してみる
?をすべて9に変更した数字が最大になるので、
それまでのすべての13の倍数を列挙して、
Sの数字部分と一致できるかどうかでカウントアップ
S = input()
c = {}
for i, j in enumerate(S):
if j != "?":
c[i] = j
T = [s if s != "?" else "9" for s in S]
T = int(''.join(T))
ans = 0
for i in range(T // 13 + 1):
f = False
s = '{0:0>{1}}'.format(i * 13 + 5, len(S))
for k, v in c.items():
if s[k] != v:
f = True
break
if f is not True:
ans += 1
print(ans % (10**9 + 7))
結果:9/30
ほとんどがTLE(知ってた…)
定番のDPで解いてみる
mod(a+b) = mod(mod(a)+mod(b))なので
すべての桁に分解して数え上げる
SS:見えている数字部分で13を割った余り
c:j*10^iの13を割った余り
遷移では
dp[k+1][i] = sum(dp[k][j] for j in range(13) if (j+r)%13 == i)
S = input()
T = [s if s != "?" else "0" for s in S]
T = int(''.join(T))
SS = T % 13
c = {}
for i in range(len(S)):
if S[len(S) - i - 1] != "?":
continue
tc = {i: 0 for i in range(13)}
for j in range(0, 10):
t = j * 10 ** i
tc[t % 13] += 1
c[i] = tc
dp = []
tdp = {i: 0 for i in range(13)}
tdp[SS] += 1
dp.append(tdp)
for k, v in c.items():
tdp = {i: 0 for i in range(13)}
for r, c in v.items():
for _r, _v in dp[-1].items():
tdp[(_r + r) % 13] += _v * c
dp.append(tdp)
ans = dp[-1][5]
print(ans % (10**9 + 7))
結果:19/30
10^5個の?を入れてみると手元ですら9秒くらいかかる…
見てみるとbuiltins.execが異常に時間がかかっている
いろいろ考えてみると、すごく大きな数字をずっとそのまま扱っているせいでは、ということに
$ python -m cProfile main2.py < in9
822158247
179999 function calls in 8.878 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
2 0.000 0.000 0.000 0.000 codecs.py:319(decode)
2 0.000 0.000 0.000 0.000 codecs.py:331(getstate)
1 8.831 8.831 8.878 8.878 main2.py:1(<module>)
9999 0.012 0.000 0.012 0.000 main2.py:11(<dictcomp>)
1 0.000 0.000 0.000 0.000 main2.py:18(<dictcomp>)
9999 0.015 0.000 0.015 0.000 main2.py:22(<dictcomp>)
1 0.001 0.001 0.001 0.001 main2.py:3(<listcomp>)
2 0.000 0.000 0.000 0.000 {built-in method _codecs.utf_8_decode}
1 0.000 0.000 8.878 8.878 {built-in method builtins.exec}
1 0.001 0.001 0.001 0.001 {built-in method builtins.input}
10000 0.001 0.000 0.001 0.000 {built-in method builtins.len}
1 0.003 0.003 0.003 0.003 {built-in method builtins.print}
10000 0.001 0.000 0.001 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
139987 0.012 0.000 0.012 0.000 {method 'items' of 'dict' objects}
1 0.000 0.000 0.000 0.000 {method 'join' of 'str' objects}
定番のDPで解いてみる(数字をサボらずに小さくしながら解く)
S = input()
con = 10**9 + 7
T = [s if s != "?" else "0" for s in S]
T = int(''.join(T))
SS = T % 13
def get_mod(S):
c = {}
iii = 1
for i in range(len(S)):
if S[len(S) - i - 1] != "?":
iii *= 10
iii %= 13
continue
c[i] = {(j * iii) % 13: 1 for j in range(0, 10)}
iii *= 10
iii %= 13
return c
c = get_mod(S)
def get_tdp(pdp):
tdp = {i: 0 for i in range(13)}
for r in v.keys():
for _r, _v in pdp.items():
i = (_r + r) % 13
tdp[i] += _v
for i in range(13):
tdp[i] %= con
return tdp
pdp = {i: 0 for i in range(13)}
pdp[SS] += 1
for k, v in c.items():
pdp = get_tdp(pdp)
ans = pdp[5]
print(ans)
少なくとも手元ではだいぶ速くなった
結果:26/30
$ python -m cProfile main3.py < in9
822158247
150002 function calls in 0.274 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
2 0.000 0.000 0.000 0.000 codecs.py:319(decode)
2 0.000 0.000 0.000 0.000 codecs.py:331(getstate)
1 0.004 0.004 0.274 0.274 main3.py:1(<module>)
9999 0.018 0.000 0.018 0.000 main3.py:16(<dictcomp>)
9999 0.223 0.000 0.238 0.000 main3.py:22(get_tdp)
9999 0.008 0.000 0.008 0.000 main3.py:23(<dictcomp>)
1 0.000 0.000 0.000 0.000 main3.py:32(<dictcomp>)
1 0.001 0.001 0.001 0.001 main3.py:4(<listcomp>)
1 0.010 0.010 0.029 0.029 main3.py:8(get_mod)
2 0.000 0.000 0.000 0.000 {built-in method _codecs.utf_8_decode}
1 0.000 0.000 0.274 0.274 {built-in method builtins.exec}
1 0.001 0.001 0.002 0.002 {built-in method builtins.input}
10000 0.001 0.000 0.001 0.000 {built-in method builtins.len}
1 0.000 0.000 0.000 0.000 {built-in method builtins.print}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
99991 0.007 0.000 0.007 0.000 {method 'items' of 'dict' objects}
1 0.000 0.000 0.000 0.000 {method 'join' of 'str' objects}
9999 0.001 0.000 0.001 0.000 {method 'keys' of 'dict' objects}
リスト内包表記にしたり、
関数の中に閉じ込めると速くなるので、
いろいろ試してみたが、結局パスせず
もう無理…
Python3からPyPy3に変更して提出してみる
結果:30/30 (930 ms)
外部ライブラリを利用していないのならPyPyを使うとサクッと解けるかもしれないという教訓を得た
※気力があったらPythonのままがんばるかも
※本番でPython使って解けてた方の回答は大変頭良かったので後で研究してみる…