0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PythonでSortedListを作ってみた

Last updated at Posted at 2021-09-22

はじめに

ABC217 D - Cutting Woodsでいいデータ構造がみつからず困ったので作ってみた。

作ったもの

同じ要素を複数入れられる。要素を常に昇順に保つ。追加した順は保存しない。
$E$ : 要素の数。
$N (=2^n (1 \leq n \leq 20))$ : 要素の種類数。要素は $0 \leq x \leq N-1$ の整数。
挿入 $Ο(\log_2 N)$
削除 $Ο(\log_2 N)$
xより大きい要素の最小値取得 $Ο(\log_2 N)$
xより小さい要素の最大地値取得 $Ο(\log_2 N)$
存在確認 $Ο(1)$
要素数取得 $Ο(1)$
イテレータ $Ο(E \log_2 N)$
昇順でx番目の要素取得 $Ο(\log_2 N)$
xを挿入する位置取得 $Ο(\log_2 N)$

区間にわけてそれぞれの区間内の要素数を持つことで計算している。これは何というデータ構造になるのだろう?

使い方

コンストラクタ

sl=SortedList(18)

引数は $1 \leq n \leq 20$ でこの範囲でないとエラー。指定しなかった場合は20になる。
計算量は $Ο(N)$

挿入

sl.insert(x)

$0 \leq x < N$ でないとエラー。
計算量は $Ο(\log_2 N)$

要素数

len(sl)

組み込み関数のlenに対応している。
計算量は $Ο(1)$

x番目の要素取得

sl[0]

listと同様に指定できる。一番目の要素は0になることに注意。
計算量は $Ο(\log_2 N)$

存在確認

x in sl

inに対応している。
計算量は $Ο(1)$

イテレータ

list(sl)

イテレータに対応している。
計算量は $Ο(E \log_2 N)$

削除

sl.eject(x)

要素のxを削除。xが存在しない場合はエラー。
計算量は $Ο(\log_2 N)$


sl.remove(x)

要素のxを全て消す。
計算量は $Ο(\log_2 N)$


p=sl.pop(x)

x番目の要素を削除して返す。指定しなかった場合は0番目になる。x番目が存在しない場合はエラー。
計算量は $Ο(\log_2 N)$

xより大きい要素の最小値取得

a=sl.min(x)

xより大きい要素が存在しない場合はNoneを返す。
指定しなかった場合は単に最小値を返す。min(sl)と書いてもいいがこの場合は計算量は $Ο(E \log_2 N)$ になる。
計算量は $Ο(\log_2 N)$

xより小さい要素の最大値取得

a=sl.max(x)

上と同様に扱える。
計算量は $Ο(\log_2 N)$

xを挿入する位置取得

a=sl.bisect_left(x)

$0 \leq x < N$ でないとエラー。標準ライブラリのbisect.bisect_leftと同様の動きであるため説明は省く。
挿入する位置を返すが、挿入するときにこの値が必要なわけではない。
計算量は $Ο(\log_2 N)$


a=sl.bisect_left(x)

上と同様に扱える。
計算量は $Ο(\log_2 N)$

コード

import sys
class SortedList:
    def __init__(self,n=20):
        if not 1<=n<=20:#0~1,048,575
            print('Error: not 1<=n<=20', file=sys.stderr)
            sys.exit(1)
        self.n=n
        self.b=[]
        for i in range(self.n+1):
            self.b.append([0 for j in range(2**i)])
    def __len__(self):
        return self.b[0][0]
    def __contains__(self,x):
        if not 0<=x<2**self.n:
            print('Error: not 0<=x<2**n', file=sys.stderr)
            sys.exit(1)
        return 0!=self.b[self.n][x]
    def __getitem__(self,x):
        if not 0<=x<self.b[0][0]:
            print('Error: not 0<=x<len', file=sys.stderr)
            sys.exit(1)
        s=0
        a=0
        for i in range(1,self.n+1):
            if s+self.b[i][2*a]<=x:
                s+=self.b[i][2*a]
                a=2*a+1
            else:
                a=2*a
        return a
    def __iter__(self):
        self._a=-1
        return self
    def __next__(self):
        self._a+=1
        if self._a>=self.b[0][0]:
            raise StopIteration
        return self.__getitem__(self._a)
    def insert(self,x):
        if not 0<=x<2**self.n:
            print('Error: not 0<=x<2**n', file=sys.stderr)
            sys.exit(1)
        for i in range(self.n+1):
            self.b[i][x//(2**(self.n-i))]+=1
    def eject(self,x):
        if not 0<=x<2**self.n:
            print('Error: not 0<=x<2**n', file=sys.stderr)
            sys.exit(1)
        if not self.__contains__(x):
            print('Error: not x in', file=sys.stderr)
            sys.exit(1)
        for i in range(self.n+1):
            self.b[i][x//(2**(self.n-i))]-=1
    def pop(self,x=0):
        p=self.__getitem__(x)
        self.eject(p)
        return p
    def remove(self,x):
        a=self.b[self.n][x]
        for i in range(self.n+1):
            self.b[i][x//(2**(self.n-i))]-=a
    def min(self,x=None):
        if x==None:
            x=-1
        elif x>=2**self.n-1:
            return None
        elif x<0:
            x=-1
        for i in range(self.n):
            if self.b[self.n-i][x//(2**i)+1]:
                a=x//(2**i)+1
                for j in reversed(range(i)):
                    if self.b[self.n-j][2*a]:
                        a=2*a
                    else:
                        a=2*a+1
                return a
    def max(self,x=None):
        if x==None:
            x=2**self.n
        elif x<=0:
            return None
        elif 2**self.n<=x:
            x=2**self.n
        for i in range(self.n):
            if self.b[self.n-i][x//(2**i)-1]:
                a=x//(2**i)-1
                for j in reversed(range(i)):
                    if self.b[self.n-j][2*a+1]:
                        a=2*a+1
                    else:
                        a=2*a
                return a
    def bisect_left(self,x):
        if not 0<=x<2**self.n:
            print('Error: not 0<=x<2**n', file=sys.stderr)
            sys.exit(1)
        a=0
        for i in range(self.n):
            if x%2==1:
                a+=self.b[self.n-i][x-1]
            x=x//2
        return a
    def bisect_right(self,x):
        if not 0<=x<2**self.n:
            print('Error: not 0<=x<2**n', file=sys.stderr)
            sys.exit(1)
        return self.bisect_left(x)+self.b[self.n][x]

コードの具体的な解説については、質問があれば追記する。

実用例

ABC217 D - Cutting Woods
上のコードに下のコードを付け加えればよい。クエリ先読み+座標圧縮が必要になる。

L,Q=map(int,input().split())
cx=[tuple(map(int,input().split())) for i in range(Q)]
l=[0]+sorted(list(set([cx[i][1] for i in range(Q)])))+[L]
d=dict([(l[i],i) for i in range(len(l))])
sl=SortedList(18)
sl.insert(d[0])
sl.insert(d[L])
for c,x in cx:
    if c==1:
        sl.insert(d[x])
    else:
        print(l[sl.min(d[x])]-l[sl.max(d[x])])
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?