LoginSignup
0
0

More than 1 year has passed since last update.

ABC077 C - Snuke Festival を 3通りで解く 二分探索/イベントソート

Posted at

  • 解法1: 公式解説通りです
  • 解法2: イベントソートのように解きます
  • 解法3: LeetCodeっぽい解き方をします。入力がソートされていればO(N)です。

解法1: すべてソートしておき、B[i]をもとに二分探索

公式解答通りです。
時間計算量はソートに$O(NlogN)$、各探索に$O(NlogN)$です。
空間計算量は$O(N)$です。

実装(Python)
from bisect import bisect_left, bisect_right
n = int(input())
d1 = list(map(int, input().split()))
d2 = list(map(int, input().split()))
d3 = list(map(int, input().split()))
d1.sort()
d3.sort()
ans = 0
for center in d2:
    ind1 = bisect_left(d1, center)
    ind3 = bisect_right(d3, center)
    ans += (ind1) * (n - ind3)
print(ans)

解法2: イベントソート

上中下の各値を時間とみなしてイベントソートします。ここで、$cnt1$をこれまでに出現した上パーツの個数、$cnt3$をこれまでに出現した下パーツの個数とします。

  • 中パーツのイベントを処理するとき、$それまでの1の数 \times (n - それまでの3の数)$のパーツを使うことができます
  • 同一の時間に上中下のパーツを処理するときの処理に気を付けます。上の式を睨み、下パーツのイベント、中パーツのイベント、上パーツのイベントの順で処理できるようにします。今回の実装ではそれぞれのイベントを$111, 22, 3$としてソートすることで実装しました。

時間計算量はイベントの作成に$O(NlogN)、$ソートに$O(NlogN)$、各探索に$O(1)$です。
空間計算量は$O(N)$です。

実装
n = int(input())
d1 = list(map(int, input().split()))
d2 = list(map(int, input().split()))
d3 = list(map(int, input().split()))
event = []
for x in d1: event.append( (x, 111) )
for x in d2: event.append( (x, 22) )
for x in d3: event.append( (x, 3) )
event.sort()
cnt1, cnt3 = 0, 0
ans = 0
for _, e in event:
    if e == 111: cnt1 += 1
    elif e == 3: cnt3 += 1
    elif e == 22: ans += cnt1 * (n - cnt3)
print(ans)

解法3: ソートして順に見る(ソートされているならばO(N))

上中下パーツがソートされているときに高速に動作します。

  • まず、すべてのインデックスを$i1=i2=i3=0$とし、それぞれ上中下パーツのindexとします。
  • i2をインクリメントしながら次の処理を行います。
  • i1が今のパーツよりも小さい(<)間、i1++します。これにより、i1の位置はi2のパーツよりも大きな最初のindexです。
  • i3が今のパーツ以下(<=)の間、i1++します。これにより、i3の位置はi2のパーツよりも大きな最初のindexです。
  • つまり、$i1 \times (n - i3)$個のパーツを使うことができます

これは先のイベントソートと同じように思えますが、イベントのソートが不要のため、最初からソートされているときは高速に動作します。

時間計算量は$ソートに$O(NlogN)$、各探索に$O(1)$です。
空間計算量は$O(N)$です。

実装(Python)
n = int(input())
d1 = list(map(int, input().split()))
d2 = list(map(int, input().split()))
d3 = list(map(int, input().split()))
d1.sort()
d2.sort()
d3.sort()
i1, i3 = 0, 0
ans = 0
for i2 in range(0, n):
    while i1 < n and d1[i1] <  d2[i2]: i1 += 1
    while i3 < n and d3[i3] <= d2[i2]: i3 += 1
    ans += i1 * (n-i3)
print(ans)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0