思ったより実装に時間かかった。
題意
- $n \leq 10^5$個の各要素 $- 10^9 \leq a_i \leq 10^9 $からなる配列$a$が与えられます。
- この連続している部分配列の和が、$- 10^{12} \leq t \leq 10^{12} $未満のものは全部でいくつあるか?(配列は重複あり、ソートなし)
こう考えた
問題は言い換えると、区間$[l, r)$の部分和が$t$未満の部分和となる、$l,r$の数を求めたい。と言えます。
累積和を取りセグメントツリーに入れます。つまり、$a_0$からある$a_i$までの和をindexとして、出現回数を値にするのセグメントツリー$st$を持ちます。ただし、値は負を取りうることと、値の範囲がとても大きいため、座圧を行います。例えば、
例題を例にとります。$[5,-1,3,4,-1]$であれば、累積和は$[5,4,7,11,10]$となるので、座圧のためのテーブル$[4,5,7,10,11]$を持ち、$1,0,2,4,3$に変換します。
この時、stは$[1,1,1,1,1]$というテーブルになります。
ここで、$l=0$と固定したとします。この時、$t=4$であると、累積和のテーブルの中でt未満となる数を列挙すればよいので、上記で計算した(累積和のカウントを座圧した数値をセグメントツリーにしている)stにクエリすればよいです。クエリの対象は、$t=4$未満ですが、座圧後の値にする必要があります。$t=0$とは、座圧するとindexが$0$です。$st.query[0,0) = 0$が答えになります。
つぎに、$l=1$と固定したとしましょう。まず、この処理に入る前に、$i=0$に対応する累積和の要素を消して、セグメントツリーの該当要素を$-1$します。
この時、再度、累積和を計算したくなりますが、これを繰り返すと$O(N^2)$の時間がかかります。このため、逆に、$t$を累積和に合わせて変更します。さて、最初に計算した累積和$[5,4,7,11,10]$ですが、$l=1$から累積和を行うと$[-1, 2, 6, 5]$です。さて、これを見ると、最初の累積和の2要素目以降から1要素目の$5$を引いたものがわかります。累積和の性質を考えると1要素目が抜けたため、明らかです。
ということは、$l=1$のとき最初に求めた各要素からa_0の要素の値(この場合は5)を引いて、t=4未満の要素であればいいので、言い換えれば、最初に求めた各要素からa_0の要素の値がt=9未満であればよいです。
ここで、座圧テーブル$[4,5,7,10,11]$で$9$未満のindexを考えると$3$です。この際、lower bound(Pythonならbisect left)で考えます。このまま、[l, index)のクエリを行うと、まるで、$l=0$の時のクエリと同じように、$l=1$の時の組み合わせが求められます。
実装
def do():
st = segmentTreeSum()
n, t = map(int, input().split())
dat = list(map(int, input().split()))
dattotal = []
total = 0
segtreeList = [0] * (200000 + 10)
zatsu = set()
for x in dat:
total += x
dattotal.append(total)
zatsu.add(total)
zatsu = list(zatsu)
zatsu.sort()
zatsuTable = dict()
zatsuTableRev = dict()
for ind, val in enumerate(zatsu):
zatsuTable[val] = ind
zatsuTableRev[ind] = val
buf = []
for x in dattotal:
buf.append(zatsuTable[x])
segtreeList[zatsuTable[x]] += 1
st.load(segtreeList)
from bisect import bisect_left, bisect_right
offset = 0
res = 0
for i in range(n): # x = total from 0 to curren
x = dattotal[i]
curvalind = zatsuTable[x]
targetval = t + offset# target val
targetind = bisect_left(zatsu, targetval)
cnt = st.query(0, targetind )
res += cnt
st.addValue(curvalind, -1)
offset += dat[i] # for next offset
print(res)
do()