LoginSignup
0
2

More than 3 years have passed since last update.

yukicoder contest 254 参戦記

Last updated at Posted at 2020-06-26

yukicoder contest 254 参戦記

A 1095 Smallest Kadomatsu Subsequence

ナイーブに書くと $O(N^2)$ になってしまう. 最初はセグ木かなあと思ったけど、指定値以上の最小値の検索が出来る気がしなかったので、平衡二分探索木となりました. 中心の門松より左側と右側の平衡二分探索木をメンテしつつ、凸の場合は最小値、凹の場合は真ん中の値以上の最小値を求めて、門松列の要素の和の最小値を求めればいいだけ. $O(N \log N)$.

package main

import (
    "bufio"
    "fmt"
    "math"
    "os"
    "strconv"
    "strings"
)

// y is the state of the xorshift generator that supplies treap priorities.
// The seed is fixed, so every run builds identically shaped trees.
// NOTE: the seed needs more than 32 bits, so this assumes a 64-bit uint.
var (
    y uint = 88172645463325252
)

// xorshift advances y by one xorshift step (y ^= y<<7; y ^= y>>9) and
// returns the new state. It is only used to draw node priorities.
func xorshift() uint {
    y ^= y << 7
    y ^= y >> 9
    return y
}

// treapNode is one node of a treap: a binary search tree ordered by value
// that is at the same time a min-heap on priority (treapInsert rotates a
// child above its parent whenever the child's priority is smaller). Equal
// values share one node and are tallied in count.
type treapNode struct {
    value    int
    priority uint
    count    int // multiplicity of value; >= 1 for every node in a tree
    left     *treapNode
    right    *treapNode
}

// newTreapNode returns a leaf holding one copy of v with a fresh
// pseudo-random priority.
func newTreapNode(v int) *treapNode {
    return &treapNode{value: v, priority: xorshift(), count: 1}
}

// treapRotateRight lifts n's left child into n's place and returns it as the
// new subtree root. n.left must be non-nil.
func treapRotateRight(n *treapNode) *treapNode {
    l := n.left
    n.left = l.right
    l.right = n
    return l
}

// treapRotateLeft lifts n's right child into n's place and returns it as the
// new subtree root. n.right must be non-nil.
func treapRotateLeft(n *treapNode) *treapNode {
    r := n.right
    n.right = r.left
    r.left = n
    return r
}

// treapInsert adds one copy of v to the subtree rooted at n and returns the
// (possibly new) subtree root.
func treapInsert(n *treapNode, v int) *treapNode {
    if n == nil {
        return newTreapNode(v)
    }
    switch {
    case v == n.value:
        n.count++
    case v < n.value:
        n.left = treapInsert(n.left, v)
        // Restore the heap order if the child now outranks n.
        if n.priority > n.left.priority {
            n = treapRotateRight(n)
        }
    default:
        n.right = treapInsert(n.right, v)
        if n.priority > n.right.priority {
            n = treapRotateLeft(n)
        }
    }
    return n
}

// treapDelete removes one copy of v from the subtree rooted at n and returns
// the (possibly new) subtree root. It panics if v is not present.
func treapDelete(n *treapNode, v int) *treapNode {
    if n == nil {
        panic("node is not found!")
    }
    if v < n.value {
        n.left = treapDelete(n.left, v)
        return n
    }
    if v > n.value {
        n.right = treapDelete(n.right, v)
        return n
    }

    // n holds v.
    if n.count > 1 {
        n.count--
        return n
    }
    if n.left == nil && n.right == nil {
        return nil
    }

    // Rotate n below the child with the smaller priority (or below its only
    // child) so the heap order survives, then continue the deletion in the
    // subtree n has moved into.
    switch {
    case n.left == nil:
        n = treapRotateLeft(n)
    case n.right == nil:
        n = treapRotateRight(n)
    case n.left.priority < n.right.priority:
        n = treapRotateRight(n)
    default:
        n = treapRotateLeft(n)
    }
    return treapDelete(n, v)
}

// treapCount returns the number of stored values (duplicates included) in
// the subtree rooted at n by visiting every node.
func treapCount(n *treapNode) int {
    if n == nil {
        return 0
    }
    return n.count + treapCount(n.left) + treapCount(n.right)
}

// treapString renders the subtree rooted at n in ascending order as
// space-separated "value:count" pairs; an empty subtree renders as "".
func treapString(n *treapNode) string {
    if n == nil {
        return ""
    }
    parts := make([]string, 0, 3)
    if n.left != nil {
        parts = append(parts, treapString(n.left))
    }
    parts = append(parts, fmt.Sprintf("%d:%d", n.value, n.count))
    if n.right != nil {
        parts = append(parts, treapString(n.right))
    }
    return strings.Join(parts, " ")
}

// treapMin returns the smallest value in the subtree rooted at n.
// n must be non-nil.
func treapMin(n *treapNode) int {
    for n.left != nil {
        n = n.left
    }
    return n.value
}

// treapGEMin returns the smallest value >= v in the subtree rooted at n, or
// math.MaxInt64 when there is none (including when n is nil).
//
// Every node whose value exceeds v is a candidate, but its left subtree may
// hold a smaller one. The search therefore records the node before stepping
// left. Without that fallback, a left subtree holding only values < v would
// wrongly yield "none" even though the node itself qualifies.
func treapGEMin(n *treapNode, v int) int {
    best := math.MaxInt64
    for n != nil {
        switch {
        case n.value == v:
            return v
        case n.value > v:
            best = n.value
            n = n.left
        default:
            n = n.right
        }
    }
    return best
}

// treapMax returns the largest value in the subtree rooted at n.
// n must be non-nil.
func treapMax(n *treapNode) int {
    for n.right != nil {
        n = n.right
    }
    return n.value
}

// treap is a multiset of ints backed by a treap. The zero value is an empty
// multiset ready for use.
type treap struct {
    root *treapNode
    size int // number of stored values, duplicates included
}

// Insert adds one copy of v.
func (t *treap) Insert(v int) {
    t.root = treapInsert(t.root, v)
    t.size++
}

// Delete removes one copy of v. It panics if v is not present.
func (t *treap) Delete(v int) {
    t.root = treapDelete(t.root, v)
    t.size--
}

// String lists the contents in ascending order as "value:count" pairs.
func (t *treap) String() string {
    return treapString(t.root)
}

// Count returns the number of stored values, duplicates included, in O(1).
func (t *treap) Count() int {
    return t.size
}

// Min returns the smallest stored value. The treap must not be empty.
func (t *treap) Min() int {
    return treapMin(t.root)
}

// Max returns the largest stored value. The treap must not be empty.
func (t *treap) Max() int {
    return treapMax(t.root)
}

// GEMin returns the smallest stored value >= v, or math.MaxInt64 if none.
func (t *treap) GEMin(v int) int {
    return treapGEMin(t.root, v)
}

// min returns the smaller of x and y; on a tie it returns y.
func min(x, y int) int {
    if y <= x {
        return y
    }
    return x
}

// main reads N followed by the sequence A. It prints the smallest sum of a
// length-3 kadomatsu subsequence, or -1 when there is none.
//
// The middle element sweeps i = 1 .. N-2, with lt holding A[0..i-1] and rt
// holding A[i+1..N-1]. For each i two shapes are tried:
//   - convex (middle is the largest): the minimum on each side;
//   - concave (middle is the smallest): the smallest value >= A[i] on each
//     side.
//
// Each step performs a constant number of treap operations.
func main() {
    defer flush()

    n := readInt()
    seq := readInts(n)

    lt, rt := &treap{}, &treap{}
    lt.Insert(seq[0])
    for j := 2; j < n; j++ {
        rt.Insert(seq[j])
    }

    best := math.MaxInt64
    for i := 1; i+1 < n; i++ {
        mid := seq[i]

        // Convex shape: pair mid with the minimum on each side.
        if lMin, rMin := lt.Min(), rt.Min(); lMin <= mid && rMin <= mid {
            best = min(best, lMin+mid+rMin)
        }

        // Concave shape: pair mid with the nearest value at or above it on
        // each side (MaxInt64 means that side has none).
        lCeil, rCeil := lt.GEMin(mid), rt.GEMin(mid)
        if lCeil != math.MaxInt64 && rCeil != math.MaxInt64 && lCeil >= mid && rCeil >= mid {
            best = min(best, lCeil+mid+rCeil)
        }

        // Slide the window: mid joins the left side, seq[i+1] leaves the right.
        lt.Insert(mid)
        rt.Delete(seq[i+1])
    }

    if best == math.MaxInt64 {
        best = -1
    }
    println(best)
}

// ioBufferSize (1 MiB) is both the initial buffer size of stdinScanner and
// the largest single input token it accepts.
const (
    ioBufferSize = 1 << 20
)

// stdinScanner yields whitespace-separated words from standard input.
var stdinScanner = newStdinScanner()

// newStdinScanner builds the word-splitting scanner behind stdinScanner.
func newStdinScanner() *bufio.Scanner {
    sc := bufio.NewScanner(os.Stdin)
    sc.Buffer(make([]byte, ioBufferSize), ioBufferSize)
    sc.Split(bufio.ScanWords)
    return sc
}

// readString returns the next word from stdinScanner. The result of Scan is
// ignored, so callers get no separate signal for end of input.
func readString() string {
    _ = stdinScanner.Scan()
    return stdinScanner.Text()
}

// readInt parses the next word as a decimal int and panics if it is not one.
func readInt() int {
    v, err := strconv.Atoi(readString())
    if err == nil {
        return v
    }
    panic(err)
}

// readInts reads the next n words as ints, in order.
func readInts(n int) []int {
    out := make([]int, n)
    for i := range out {
        out[i] = readInt()
    }
    return out
}

// stdoutWriter buffers everything written by printf and println. Data
// reaches standard output when the buffer fills or flush is called.
var stdoutWriter = bufio.NewWriter(os.Stdout)

// flush writes any buffered output to standard output.
func flush() {
    stdoutWriter.Flush()
}

// printf formats like fmt.Printf but writes to stdoutWriter.
func printf(f string, args ...interface{}) (int, error) {
    return fmt.Fprintf(stdoutWriter, f, args...)
}

// println writes its operands like fmt.Println but to stdoutWriter. It
// shadows the builtin println within this package.
func println(args ...interface{}) (int, error) {
    return fmt.Fprintln(stdoutWriter, args...)
}

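追記: 以下でも通る. 左右からの累積 Min を使っている. 指定値以上の最小値の検索が出来ていないので $A_i$ を B としたときの 凹 タイプの門松列が作成できる場合でも、作成できないとしてしまうことがある. 嘘解法じゃないかと思いつつ、反例を考えているが思いつかない. O(N). (補足: 要素が相異なるなら、実はこれは正しい解法になっている. 左側の最小値 $m$ が $A_i$ より小さいのに凹型 $(x, A_i, y)$ ($x, y > A_i$) が作れる場合を考える. $m$ が $x$ より左にあれば凸型 $(m, x, A_i)$ が、右にあれば凹型 $(x, m, A_i)$ が作れる. $m < A_i < y$ なのでどちらも和がより小さい. 右側についても同様. したがって和が最小の門松列の真ん中を $A_i$ としたとき、両側の累積 Min がそのまま最適な相手になり、必ず数え上げられる.)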

from itertools import accumulate

INF = float('inf')

N, *A = map(int, open(0).read().split())

# prefix_min[i] = min(A[0..i]); suffix_min[i] = min(A[i..N-1]).
prefix_min = list(accumulate(A, min))
suffix_min = list(accumulate(reversed(A), min))[::-1]

# For each middle index only the minimum on each side is tried. Assuming the
# values are distinct, whenever a side minimum is below the middle value, a
# cheaper kadomatsu triple exists elsewhere, so nothing optimal is lost.
result = INF
for i in range(1, N - 1):
    left, mid, right = prefix_min[i - 1], A[i], suffix_min[i + 1]
    convex = left <= mid and right <= mid
    concave = mid <= left and mid <= right
    if convex or concave:
        result = min(result, left + mid + right)

print(-1 if result == INF else result)

B 1096 Range Sums

ナイーブに書くと $O(N^3)$ になってしまう. 累積和しても $O(N^2)$. 更にセグ木を投入することにより、$O(N \log N)$ になって解けた.

from operator import add
from itertools import accumulate


class SegmentTree:
    """Array-backed segment tree folding values with a binary function ``f``.

    Node ``i`` has children ``2 * i + 1`` and ``2 * i + 2``; the leaves start
    at index ``_offset``. ``f`` must be associative. ``query`` also combines
    pieces taken alternately from both ends of the range, so ``f`` must be
    commutative as well (``operator.add`` is both).
    """
    _f = None  # combining function
    _size = None  # number of real leaves
    _offset = None  # index of the first leaf in _data
    _data = None  # flat node array: internal nodes first, then the leaves

    def __init__(self, size, f):
        """Allocate a zero-filled tree for ``size`` leaves combined with ``f``.

        The leaf capacity ``t`` is the smallest power of two >= ``size``
        (1 when ``size <= 1``), so the tree holds ``2 * t - 1`` nodes.
        """
        self._f = f
        self._size = size
        t = 1
        while t < size:
            t *= 2
        self._offset = t - 1
        self._data = [0] * (t * 2 - 1)

    def build(self, iterable):
        """Load ``iterable`` into the leaves and recompute every internal node.

        Assumes ``iterable`` yields exactly ``size`` items; another length
        would resize ``_data`` through the slice assignment. Unused padding
        leaves keep their initial 0.
        """
        f = self._f
        data = self._data
        data[self._offset:self._offset + self._size] = iterable
        # Parents precede their children in the array, so sweeping backwards
        # computes each node after both of its children.
        for i in range(self._offset - 1, -1, -1):
            data[i] = f(data[i * 2 + 1], data[i * 2 + 2])

    def query(self, start, stop):
        """Fold ``f`` over the leaves in the half-open range ``[start, stop)``.

        Raises StopIteration when the range is empty (``start >= stop``),
        because ``next`` is called on a generator that yields nothing.
        """
        def iter_segments(data, l, r):
            # Climb one level per iteration over node indices [l, r). An even
            # index is a right child; its parent would also cover l - 1, so an
            # even l is emitted on its own. Likewise an even r means node
            # r - 1 is a left child whose parent would also cover r.
            while l < r:
                if l & 1 == 0:
                    yield data[l]
                if r & 1 == 0:
                    yield data[r - 1]
                l = l // 2
                r = (r - 1) // 2
        f = self._f
        it = iter_segments(self._data, start + self._offset,
                           stop + self._offset)
        result = next(it)
        for e in it:
            result = f(result, e)
        return result


N, *A = map(int, open(0).read().split())

# a[i] = A[0] + ... + A[i] (prefix sums).
a = list(accumulate(A))

st = SegmentTree(N, add)
st.build(a)

# sum(A[l..r]) = a[r] - a[l - 1], with a[-1] read as 0. For a fixed l, the
# total over r = l .. N - 1 is sum(a[l:]) - a[l - 1] * (N - l); l = 0 needs
# no subtraction.
result = 0
result += st.query(0, N)
for i in range(1, N):
    result += st.query(i, N) - a[i - 1] * (N - i)
print(result)

追記: (0-indexed で) $A_i$ が答えに足し込まれるのは l = 0, ..., i かつ r = i, ..., N - 1 のときなので、答えに足し込まれる回数は (i + 1) × (N - i) 回となると考えれば簡単だった. O(N).

N, *A = map(int, open(0).read().split())

# With 0-indexed positions, A[i] lies in every range [l, r] with
# 0 <= l <= i <= r <= N - 1: (i + 1) choices of l times (N - i) choices of r.
print(sum(A[i] * (i + 1) * (N - i) for i in range(N)))

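追々記: セグ木ではなく Sparse table で解いたと言う人がいたのだが、Disjoint sparse table の間違い? Sparse table で区間クエリに O(1) で答えるには、区間を重なりのある 2 区間で覆うので演算に冪等性が必要だが、Min とは違い Sum にはそれはないので. (ただし区間を重ならない 2 冪長の区間に分解すれば、Sparse table でも Sum を 1 クエリ O(log N) で求められる.) Disjoint sparse table の実装を持ってないので試せなかった.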

追々々記: Disjoint sparse table を実装した.

from itertools import accumulate
from operator import add


class DisjointSparseTable:
    """Disjoint sparse table: range folds in O(1) after an O(n log n) build.

    ``f`` must be associative but need not be idempotent or commutative:
    each query combines one left piece and one right piece, in order.
    Level ``i >= 1`` cuts the array into blocks of ``2 ** (i + 1)``
    elements. Within a block, the left half stores folds running up to the
    midpoint, and the right half stores folds starting at it. Level 0 is
    the input itself.
    """
    _f = None  # combining function
    _data = None  # _data[level][k]: precomputed folds, see class docstring
    _lookup = None  # _lookup[x] = floor(log2(x)) for x >= 1

    def __init__(self, a, f):
        """Precompute every level for the sequence ``a`` under ``f``.

        ``a`` must be non-empty: for an empty ``a`` no level is allocated and
        ``_data[0] = a[:]`` raises IndexError.
        """
        self._f = f
        # b = number of levels: the smallest b with 2 ** b > len(a).
        b = 0
        while (1 << b) <= len(a):
            b += 1
        _data = [[0] * len(a) for _ in range(b)]
        # Level 0 answers single-element ranges.
        _data[0] = a[:]
        for i in range(1, b):
            # Level i: blocks of 2 * shift elements split at midpoint t.
            shift = 1 << i
            for j in range(0, len(a), shift << 1):
                t = min(j + shift, len(a))
                # Left half: _data[i][k] = f(a[k], ..., a[t - 1]).
                _data[i][t - 1] = a[t - 1]
                for k in range(t - 2, j - 1, -1):
                    _data[i][k] = f(a[k], _data[i][k + 1])
                if t >= len(a):
                    # This block has no right half, and it is the last block.
                    break
                # Right half: _data[i][k] = f(a[t], ..., a[k]).
                _data[i][t] = a[t]
                r = min(t + shift, len(a))
                for k in range(t + 1, r):
                    _data[i][k] = f(_data[i][k - 1], a[k])
        self._data = _data
        _lookup = [0] * (1 << b)
        for i in range(2, len(_lookup)):
            _lookup[i] = _lookup[i >> 1] + 1
        self._lookup = _lookup

    def query(self, start, stop):
        """Return ``f`` folded over ``a[start:stop]`` (half-open).

        A one-element range returns ``a[start]``. An empty range
        (``start >= stop``) takes the same branch and returns ``a[start]``
        (or raises IndexError when ``start == len(a)``) instead of an
        identity value, so callers must not pass one.
        """
        stop -= 1
        if start >= stop:
            return self._data[0][start]
        # The highest bit where start and stop differ selects the level
        # whose block midpoint lies between them.
        p = self._lookup[start ^ stop]
        return self._f(self._data[p][start], self._data[p][stop])


N, *A = map(int, open(0).read().split())

# a[i] = A[0] + ... + A[i] (prefix sums).
a = list(accumulate(A))

st = DisjointSparseTable(a, add)

# sum(A[l..r]) = a[r] - a[l - 1], with a[-1] read as 0. For a fixed l, the
# total over r = l .. N - 1 is sum(a[l:]) - a[l - 1] * (N - l); l = 0 needs
# no subtraction.
result = 0
result += st.query(0, N)
for i in range(1, N):
    result += st.query(i, N) - a[i - 1] * (N - i)
print(result)

C 1097 Remainder Operation

終了2分前に解けて嬉しかった. 余りは N 回以内に同じ余りが出てきてループする. ループ検出して、1周期の長さと1ループ中の増分を求めればクエリに O(1) で回答できるようになるので O(N + Q) で解けた. ABC167D - Teleporter を思い出した.

from sys import stdin
readline = stdin.readline

N = int(readline())
A = list(map(int, readline().split()))

# Simulate X -> X + A[X % N] from X = 0 until a remainder repeats.
# rem[k] is the remainder used by step k (rem[0] is a -1 sentinel) and
# total[k] is X after k steps.
X = 0
rem = [-1]
total = [X]
seen = set()
while True:
    r = X % N
    X += A[r]
    rem.append(r)
    total.append(X)
    if r in seen:
        break
    seen.add(r)

# The cycle starts at the first step that used the repeated remainder.
loop_start = rem.index(rem[-1])
loop_len = len(rem) - loop_start - 1
loop_gain = total[-1] - total[loop_start]

Q = int(readline())
for _ in range(Q):
    K = int(readline())
    if K >= len(rem):
        cycles, offset = divmod(K - loop_start, loop_len)
        print(total[loop_start + offset] + cycles * loop_gain)
    else:
        print(total[K])
0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2