この解法は?
https://atcoder.jp/contests/abc247/editorial/3736
の最後にある$O(NlogN)$の解法です。
前処理
解説の前半(Bを求めること)と同じです。
例えば、A=$[10, 3,2,4,3,4,3,2,1, 3, 6, 2,4, 4]$という入力があり、$X=4, Y=2$だったとすると、2未満、5を超過する数でこの配列を分割し、$[3,2,4,3,4,3,2,1]$と$[3]$と$[2,4,4]$を得ます。それぞれの配列に対して、問題と同じ組み合わせの数を求めてその和が最終的な答えとなります。(ある区間が分割する対象の数というのはその数を含んだ瞬間にそれ以上の条件を満たす区間が作れなくなるためです)
補足:この発想は次の処理を簡単にするために非常に有効なので、「次の処理を思い浮かんだからこの前処理をしたのか?」と思われるかもしれませんが私はこの問題を読んだときに、分割できるな。と思い分割しました。
明らかに一般性を失わずに入力を分割できる場合、実装を簡単にできる(例えばindexの調整や境界の配慮をしなくていい)ため、明確であればこのように分割を行うことは適切です。
各区間内の処理
今、$[3,2,4,3,4,3,2,1]$に対する組み合わせの数を考えます。この入力に対してすべての条件を満たすindexの対$(L,R)$をすべて求めたいですが、2重ループではこの長さNに対して時間計算量$O(N^2)$となってしまいます。
$L$を固定することを考えます。
- Lを含めて右にある最初のX及びYの位置を
ix
,iy
とします。(含めて、なので、iX
=Lやiy
=Lになることもあります) - ixとiyが片方でも見つからない場合、Lを固定した範囲は存在しません。なぜなら、(XかYがLより後に含まれないkからです)
- 両者のmaxよりも後のRはすべて条件を満たします。なぜなら、すでにXもYも含まれており、前処理によってX未満、Yより大きい数は排除されているため、この区間内の最後まで条件を満たせることは明らかです。
- L = 0と固定します(緑字)。最も近いXであるindex ix=2, Yであるindex iy=1であるため、max(ix,iy) = 2です。このため、(L,R) = (0,2)は条件を満たし、それ以降のRも条件を満たします。つまり、(0,2)...(0,6)の5通りが考えられます。
- 先の繰り返しになりますが、(0,2)の区間ですでにX,Yの最大・最小の条件は満たせており、この区間には最大・最小を変化させる値は(前処理をしたので)ありません。
- L = 3の時(青字)を考えます。この時、ix, iy = 4, 6です。(L,R) = (4,6)は条件を満たし、Rはこれ以上の値をとれないので、L=3の時は(3,6)のみのただ1通りとなります。
- これよりも前のRを考えると例えば(3,5)では、最小値が3となり、条件を満たせません。
- L = 5(紺色)を考えます。この時、ixが見つかりません。(L=5以降にX=4は存在しません)。このため、条件を満たすような(L,R)=(5,k)は存在しません。
実装方針と計算量の検討
前処理について、入力された配列を見ていきYからXの範囲を超える値でsplitすればよいです。
前処理の後の各区間の処理は、最初にすべての要素を$O(N)$で操作し、X, Yと一致するindexをリストの形で持ちます。各Lに対してそれぞれのリストを二分探索してやれば$O(logN)$でそれぞれixとiyが求められます。各LはN個であるので、$O(NlogN)$です。
実装
前処理はよくある配列の分割を行います。
前処理は上記の実装方針の通りに行います。lower_bound(bisect_left)で結果が見つからなかった時の処理に気を付けます。
実装(Python)
from bisect import bisect_left
n, x, y = map(int, input().split())
dat = list(map(int, input().split()))
buf = []
q = []
# 入力の配列を、min未満, max超過の数字でsplitする
for val in dat:
if y <= val <= x:
q.append(val)
continue
if len(q) > 0: buf.append(q)
q = []
if len(q) > 0: buf.append(q)
# 各分割について数を計算する
ans = 0
for dat in buf:
milist = [] # minのindex list
malist = [] # maxのindex list
for i in range(len(dat)): # を作る
if dat[i] == y: milist.append(i)
if dat[i] == x: malist.append(i)
for l in range(len(dat)): # lを固定して満たせるrを探索
r = bisect_left(malist, l) # index=l以上のxのindex見つける
if len(malist) <= r: continue # xが見つからない時は無理
a = malist[r]
r = bisect_left(milist, l) # index=l以上のyのindex見つける
if len(milist) <= r: continue # yが見つからないときは無理
b = milist[r]
r = max(a, b) # xかyの見つかったindexの大きい方
ans += len(dat) - r #それ以上は全部良いRなので組み合わせの数としてadd
print(ans)