0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Atcoder ABC377 F問題解説「わかりやすく解説したいシリーズ」

Posted at

方針

  • マスi、jにクイーンが置かれるとi、jマスの横、縦、斜め1(右上がり)、斜め2(右下がり)の四つのラインがコマを置けないゾーンになる。この四つのラインをsetに追加していく。
    • 横のラインはi、縦のラインはj、斜め1はi+j、斜め2はi-jで管理できます。
    • setを使う理由はラインを追加していくときに重複するラインを重ねて追加せずに済むから。
F.py
N, M = map(int,input().split())
queen = [tuple(map(int,input().split())) for _ in range(M)]
hor = set()
ver = set()
dia1 = set() # i+j
dia2 = set() # i-j

for a, b in queen:
    if not a in hor:
        hor.add(a)
    if not b in ver:
        ver.add(b)
    if not a+b in dia1:
        dia1.add(a+b)
    if not a-b in dia2:
        dia2.add(a-b)

縦のラインはhor、横のラインは、斜め1はdia1、斜め2はdia2に追加しています。

  • コマが置けなくなるラインをsetで受け取りました。次のステップはansの初期値をN^2とし、ansからsetで管理している各ラインが塞いでしまうマスの数だけ引き算していきます。
    • 初期値がN^2なのは最初何もない状態で、N^2個全てのますに駒を置くことができるからです。
F.py
ans = N**2

for x in hor:
    ans -= N
for y in ver:
    ans -= N
for xpy in dia1:
    ans -= N - abs(N+1-xpy)
for xmy in dia2:
    ans -= N - abs(xmy)
    

縦、横のラインはそれぞれのラインがNマスづつ塞ぐのでansからNを引き算します。斜めは単純にはいかないです。
斜め1はsetから取り出したi+jの値をxpyとすると N - abs(N+1-xpy) を引くとうまくいきます。
斜め2はsetから取り出したi-jの値をxmyとすると N - abs(xmy) を引くとうまくいきます。
腑に落ちない方は3x3ぐらいのサイズでいいので手を動かして確かめるとわかると思います。

  • 今までライン同士の重なりを考えていないので現時点ではansから引き算し過ぎている状態です。ここからはラインが重なってるマスの数だけansに足し算してあげます。そのためにまずはライン同士が重なっている点を調べてsetに追加しいきます。
    • ラインが二つ重なってる点があれば、さっきのステップで1多く引き算してしまっているので+1して戻します。同じように三つ重なってる点があれば+2して戻します。
F.py
def ok(x,y):
    return 1<=x<=N and 1<=y<=N

mult = set()
for x in hor:
    for y in ver:
        cx = x
        cy = y
        if ok(cx,cy):
            mult.add((cx,cy))
for x in hor:
    for xpy in dia1:
        cx = x
        cy = xpy - x
        if ok(cx,cy):
            mult.add((cx,cy))
for x in hor:
    for xmy in dia2:
        cx = x
        cy = x - xmy
        if ok(cx,cy):
            mult.add((cx,cy))
for y in ver:
    for xpy in dia1:
        cx = xpy - y
        cy = y
        if ok(cx,cy):
            mult.add((cx,cy))
for y in ver:
    for xmy in dia2:
        cx = xmy + y
        cy = y
        if ok(cx,cy):
            mult.add((cx,cy))
for xpy in dia1:
    for xmy in dia2:
        cx = (xpy + xmy)//2
        cy = (xpy - xmy)//2
        if ok(cx,cy):
            mult.add((cx,cy))

縦、横、斜め1、斜め2の交わる組み合わせは全部で6通りあるので全部やります。
ラインの交わる点(cx、cy)が盤面の中に収まっているかを関数:okで確認してからset:multに追加しています。
交点の集合を取れたのでいよいよ仕上げになります。

  • 各交点が縦、横、斜め1、斜め2のラインに含まれているかを調べます。2つのラインに含まれれば、ansから1を引きます。3つのラインに含まれれば、ansから2を引きます。4つのラインに含まれればansから3を引きます。
F.py
for x, y in mult:
    cnt = -1
    if x in hor:
        cnt += 1
    if y in ver:
        cnt += 1
    if x+y in dia1:
        cnt += 1
    if x-y in dia2:
        cnt += 1
    ans += cnt
print(ans)

ラインに含まれるとわかればcnt += 1をします。2つのラインに含まれると1引く、3つだと2引く、4つだと3引く。ということはcntの初期値は−1にしておくと都合がいいですね。交点だけを扱ってるので必ず2つ以上のラインに含まれます。

以上で問題が解けました。最後に各ステップをまとめたソースコード載せておきますね。

F.py
N, M = map(int,input().split())
queen = [tuple(map(int,input().split())) for _ in range(M)]
hor = set()
ver = set()
dia1 = set() # i+j
dia2 = set() # i-j
ans = N**2
for a, b in queen:
    if not a in hor:
        hor.add(a)
    if not b in ver:
        ver.add(b)
    if not a+b in dia1:
        dia1.add(a+b)
    if not a-b in dia2:
        dia2.add(a-b)

for x in hor:
    ans -= N
for y in ver:
    ans -= N
for xpy in dia1:
    ans -= N - abs(N+1-xpy)
for xmy in dia2:
    ans -= N - abs(xmy)

def ok(x,y):
    return 1<=x<=N and 1<=y<=N

mult = set()
for x in hor:
    for y in ver:
        cx = x
        cy = y
        if ok(cx,cy):
            mult.add((cx,cy))
for x in hor:
    for xpy in dia1:
        cx = x
        cy = xpy - x
        if ok(cx,cy):
            mult.add((cx,cy))
for x in hor:
    for xmy in dia2:
        cx = x
        cy = x - xmy
        if ok(cx,cy):
            mult.add((cx,cy))
for y in ver:
    for xpy in dia1:
        cx = xpy - y
        cy = y
        if ok(cx,cy):
            mult.add((cx,cy))
for y in ver:
    for xmy in dia2:
        cx = xmy + y
        cy = y
        if ok(cx,cy):
            mult.add((cx,cy))
for xpy in dia1:
    for xmy in dia2:
        cx = (xpy + xmy)//2
        cy = (xpy - xmy)//2
        if ok(cx,cy):
            mult.add((cx,cy))
        
for x, y in mult:
    cnt = -1
    if x in hor:
        cnt += 1
    if y in ver:
        cnt += 1
    if x+y in dia1:
        cnt += 1
    if x-y in dia2:
        cnt += 1
    ans += cnt
print(ans)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?