LoginSignup
22
16

爆速ビームサーチライブラリを作る

Last updated at Posted at 2023-07-16

この記事では非常に高速なビームサーチライブラリの実装と、そのいくつかの具体的な使用例を紹介します。ビームサーチはDPを拡張し、上位M個を保持するようにしたもので、AHC系のコンテストで主に使われる非常に汎用的な手法の一つなのですが、高速なライブラリ設計に関する知見はあまり整備されていないと認識しています。そこで本記事では何種類かの典型的なケースに対応する高速なビームサーチの実装方法とそれらのライブラリ化の手法を紹介します。この方法では、問題固有のコードとライブラリ側のコードの大部分が分離されため、本質の探索以外を実装する必要がなくなり、しかも非常に高速な動作させることが可能になります。

本記事ではビームサーチの存在を知っていることが前提となっています。そのためビームサーチの入門のような内容は含みません。知らない方はまずビームサーチに関する記事を読むか実際にビームサーチが有効な問題を解いてみることをおすすめします。
また、一部こちらの記事の内容も含みますが、読まなくても問題のない構成にしてあります。
完成形だけを見たい人はこちらをご参照ください。

なお、本記事に記載されたコード並びに完成形のコードは自由に使用・改変して頂いて構いませんが、それに伴ういかなる結果にも筆者は責任を負いません。


ビームサーチの定義

最初に本記事の目標となるビームサーチがどのような動作をするのか定義します。
本記事ではビームサーチを1次元DPを抽象化した形で扱い、ビームサーチでの1ステップを

  1. dp[i]の状態集合の中から使うもののみを選択する
    • これは上位M個の選択以外にも、重複除去や、playerの位置で区別し別々に上位k個を取るなどを含みます
  2. dp[i]の状態集合からdp[i+1],dp[i+2]へ状態を遷移させる

という動作をすることとします。


一般的なビームサーチ

一般的な(だと筆者が思っている)ビームサーチの実装が以下になります:

長いので折りたたみ
// 入力や前計算を管理する構造体
struct Input{}


#[derive(Clone)]
// 確定した状態を管理する構造体
struct Node{
    track_id:usize, // 履歴のid
}
impl Node{
    // selfからcandの遷移をしたときの新たなNodeを返す
    fn new_node(&self,cand:&Cand)->Node{
        todo!();
    }
}


#[derive(Clone)]
// Nodeと違い採用するかまだ確定していない状態を管理する構造体
struct Cand{
    op:u8, // Nodeからの操作
    parent:usize, // 親のindex
    eval_score:i64, // この大小で上位Mが決定される
    hash:u64, // 重複除去に用いられる
}
impl Cand {
    // 問題の生スコア、これを最終的に最大化したい
    fn raw_score(&self,input:&Input)->i64{
        todo!();
    }
}


const MAX_WIDTH:usize=1000;
const TURN:usize=100;

struct BeamSearch{
    track:Vec<(usize,u8)>, // 履歴を管理
    nodes:Vec<Node>,
    next_nodes:Vec<Node>,
}
impl BeamSearch{
    fn new(node:Node)->BeamSearch{
        BeamSearch{
            nodes:vec![node],
            track:vec![],
            next_nodes:vec![],
        }
    }
    
    // 候補の列挙
    fn enum_cands(&self,input:&Input,cands:&mut Vec<Cand>){
        for i in 0..self.nodes.len(){
            self.append_cands(input,i,cands);
        }
    }
    
    // 候補の採用を反映する
    fn update<I:Iterator<Item=Cand>>(&mut self,cands:I){
        self.next_nodes.clear();
        for cand in cands{
            let mut new=self.nodes[cand.parent].new_node(&cand);
            self.track.push((new.track_id,cand.op));
            new.track_id=self.track.len()-1;
            self.next_nodes.push(new);
        }
        
        std::mem::swap(&mut self.nodes,&mut self.next_nodes);
    }
    
    // 復元
    fn restore(&self,mut idx:usize)->Vec<u8>{
        idx=self.nodes[idx].track_id;
        let mut ret=vec![];
        while idx!=!0{
            ret.push(self.track[idx].1);
            idx=self.track[idx].0;
        }
        ret.reverse();
        ret
    }

    // self.nodes[idx]からの候補をcandsに積む
    fn append_cands(&self,input:&Input,idx:usize,cands:&mut Vec<Cand>){
        let node=&self.nodes[idx];
        todo!();
    }

    // ビームサーチを実行する
    fn solve(&mut self,input:&Input)->Vec<u8>{
        use std::cmp::Reverse;
        let M=MAX_WIDTH;
        
        let mut cands=Vec::<Cand>::new();
        let mut set=std::collections::HashSet::new();
        for t in 0..TURN{
            if t!=0{
                cands.sort_unstable_by_key(|a|Reverse(a.eval_score));
                set.clear();
                // 重複除去をして上位M個を取り出す
                self.update(cands.iter().filter(|cand|
                    set.insert(cand.hash)
                ).take(M).cloned());
            }
            cands.clear();
            self.enum_cands(input, &mut cands);
            assert!(!cands.is_empty(),"次の合法手が存在しないよ");
        }

        // 一番いい状態を取り出す
        let best=cands.iter().max_by_key(|a|a.raw_score(input)).unwrap();
        eprintln!("score = {}",best.raw_score(input));

        let mut ret=self.restore(best.parent);
        ret.push(best.op);

        ret
    }
}

特徴は、update関数で選択された状態集合をIteratorで受け取り内部状態を更新します。
これによって状態の列挙や重複除去を簡潔に記述することが可能になり、経路復元などもライブラリ側で自動で管理することが可能になります。
以後のインターフェースはこれ寄せることとします。


1手のみ差分更新したビームサーチ

上述の実装はどこが律速となるのでしょうか。実はupdate関数が非常に重くなっています。
これは、AHC系の問題は最適解を出すことが困難ということもあり、今までの行動の結果として持つべき状態量が大きく、Nodeのコピーにかかる時間が無視できなくなるためです。
そのため、実際の問題を解く際はできるだけ省メモリになるようbitsetなどで状態管理が行われます。
しかし、遷移の結果として状態がすべて入れ替わるようなことは珍しく、変更される領域は全体の$1/N$や$1/N^2$であることがほとんどです。
この事実から前回の領域を使い回し差分更新することで高速化する方針が考えられます。

折りたたみ
#[derive(Clone,Default)]
struct Node{
    track_id:usize,
    refs:u8, // 何個のCandに参照されているか
}
impl Node {
    // selfからcandの遷移をしたときの新たなNode
    fn new_node(&self,cand:&Cand)->Node{
        todo!();
    }

    // selfにcandの遷移を差分更新で適用する
    fn apply(&mut self,cand:&Cand){
        todo!();
    }
}


struct BeamSearch{
    track:Vec<(usize,u8)>,
    nodes:Vec<Node>,
    free:Vec<usize>, // free[..at]:=nodesで現在使われているindex free[at..]:=使われていないindex
    at:usize,
    cands:Vec<Cand>,
}
impl BeamSearch{
    fn new(node:Node)->BeamSearch{
        let mut nodes=vec![Node::default();MAX_WIDTH*2];
        nodes[0]=node;
        
        BeamSearch{
            free:(0..nodes.len()).collect(),
            nodes,
            at:1,
            track:vec![],
            cands:vec![],
        }
    }
    
    fn enum_cands(&self,input:&Input,cands:&mut Vec<Cand>){
        for &i in &self.free[..self.at]{
            self.append_cands(input,i,cands);
        }
    }
    
    fn update<I:Iterator<Item=Cand>>(&mut self,cands:I){
        self.cands.clear();
        for cand in cands{
            self.nodes[cand.parent as usize].refs+=1;
            self.cands.push(cand);
        }

        // 使われなくなったindexを更新する
        for i in (0..self.at).rev(){
            if self.nodes[self.free[i]].refs==0{
                self.at-=1;
                self.free.swap(i,self.at);
            }
        }

        for cand in &self.cands{
            let node=&mut self.nodes[cand.parent as usize];
            node.refs-=1;
            let prev=node.track_id;
            
            let new=if node.refs==0{
                // そのNodeを必要とするCandがないなら領域を使いまわして差分更新
                node.apply(cand);
                node
            }
            else{
                // コピーして適用
                let mut new=node.new_node(cand);
                new.refs=0;
                let idx=self.free[self.at];
                self.at+=1;
                self.nodes[idx]=new;
                &mut self.nodes[idx]
            };

            self.track.push((prev,cand.op));
            new.track_id=self.track.len() as usize-1;
        }
    }
}
このように、現在使われている領域と使われていない領域を free[..at] とfree[at..]で扱うことで簡潔、かつ高速に扱うことが可能になります。

木上のビームサーチ (遷移がdp[i]からdp[i+1]へ限定される場合)

先程の方針では、一部の状態を差分更新することで高速化しましたが、特に遷移の差分更新が定数時間で済むような場合は一度もコピーをせず、毎回状態をシュミレートすることで高速になる場合があります。つまり、遷移を木構造で管理し、採用した頂点を追加、要らなくなった頂点を削除した後、木上をトラバースし候補を列挙するというアルゴリズムを用いればよいです。
その為には、木に対して頂点の追加、削除、そして全頂点の走査が定数倍高速にできるデータ構造が必要になりますが、これは二重連鎖木を用いることで実現できます1

具体的には、ノードに親、前と次の兄弟、代表の子供の4つのポインタを持たせて管理をすればよいです。(Rustはポインタを扱いづらいので配列のindexで実装しています。)

ノードのサイズが定数になるのでゼロアロケーションの実装が可能になり定数倍が高速になります。

長いので折りたたみ
#[derive(Clone,PartialEq)]
struct State{}
impl State{
    // 初期状態の生成
    fn new(input:&Input)->State{
        todo!();
    }

    // 差分更新で適用する
    fn apply(&mut self,node:&Node){
        todo!();
    }

    // 差分更新で元に戻す
    fn revert(&mut self,node:&Node){
        todo!();
    }
}


#[derive(Clone)]
struct Cand{
    op:u8,
    parent:usize,
    eval_score:i64,
    hash:u64,
}
impl Cand{
    fn raw_score(&self,input:&Input)->i64{
        todo!();
    }
    
    fn to_node(&self)->Node{
        Node{
            child:!0,
            prev:!0,
            next:!0,
            op:self.op,
            parent:self.parent,
        }
    }
}


#[derive(Clone,Default)]
struct Node{
    op:u8,
    parent:usize, // 親Node
    child:usize, // 代表の子Node
    prev:usize, // 前の兄弟Node
    next:usize, // 次の兄弟Node
}


const MAX_WIDTH:usize=1000;
const TURN:usize=100;

struct BeamSearch{
    state:State,
    leaf:Vec<usize>, // 子が存在しないNodeのindex
    next_leaf:Vec<usize>,
    nodes:Vec<Node>,
    cur_node:usize,
    free:Vec<usize>, // nodesのうち使われていないindex
}
impl BeamSearch{
    fn new(state:State,node:Node)->BeamSearch{
        const MAX_NODES:usize=MAX_WIDTH*5;
        let mut nodes=vec![Node::default();MAX_NODES];
        nodes[0]=node;
        let free=(1..MAX_NODES as usize).rev().collect();
        
        BeamSearch{
            state,nodes,free,
            leaf:vec![0],
            next_leaf:vec![],
            cur_node:0,
        }
    }
    
    // 頂点を新たに追加する
    // 代表の子Nodeの前に挿入する形で実装
    fn add_node(&mut self,cand:Cand){
        let next=self.nodes[cand.parent as usize].child;
        let new=self.free.pop().expect("MAX_NODEが足りないよ") as usize;
        if next!=!0{
            self.nodes[next as usize].prev=new;
        }
        self.nodes[cand.parent as usize].child=new;
        
        self.next_leaf.push(new);
        self.nodes[new as usize]=Node{next,..cand.to_node()};
    }

    // 既に探索済みのノードで葉のノードを再帰的に消していく
    fn del_node(&mut self,mut idx:usize){
        loop{
            self.free.push(idx);
            let Node{prev,next,parent,..}=self.nodes[idx as usize];
            assert_ne!(parent,!0,"全てのノードを消そうとしています");
            // 兄弟がいないなら親を消しに行く
            if prev&next==!0{
                idx=parent;
                continue;
            }

            if prev!=!0{
                self.nodes[prev as usize].next=next;
            }
            else{
                self.nodes[parent as usize].child=next;
            }
            if next!=!0{
                self.nodes[next as usize].prev=prev;
            }
            
            break;
        }
    }

    // dfsで木を走査
    // 一本道の場合戻る必要はないのでそれをsingleで管理
    fn dfs(&mut self,input:&Input,cands:&mut Vec<Cand>,single:bool){
        if self.nodes[self.cur_node].child==!0{
            self.append_cands(input,self.cur_node,cands);
            return;
        }

        let node=self.cur_node;
        let mut child=self.nodes[node].child;
        let next_single=single&(self.nodes[child as usize].next==!0);

        // let prev_state=self.state.clone();
        loop{
            self.cur_node=child as usize;
            self.state.apply(&self.nodes[child as usize]);
            self.dfs(input,cands,next_single);

            if !next_single{
                self.state.revert(&self.nodes[child as usize]);
                // assert!(prev_state==self.state);
            }
            child=self.nodes[child as usize].next;
            if child==!0{
                break;
            }
        }
        
        if !next_single{
            self.cur_node=node;
        }
    }

    // 走査の非再帰実装
    fn no_dfs(&mut self,input:&Input,cands:&mut Vec<Cand>){
        // 1本道でなくなるまで潜る
        loop{
            let Node{next,child,..}=self.nodes[self.cur_node];
            if next==!0 || child==!0{
                break;
            }
            self.cur_node=child as usize;
            self.state.apply(&self.nodes[self.cur_node]);
        }

        let root=self.cur_node;
        loop{
            let child=self.nodes[self.cur_node].child;
            if child==!0{
                self.append_cands(input,self.cur_node,cands);
                loop{
                    if self.cur_node==root{
                        return;
                    }
                    let node=&self.nodes[self.cur_node];
                    self.state.revert(&node);
                    if node.next!=!0{
                        self.cur_node=node.next as usize;
                        self.state.apply(&self.nodes[self.cur_node]);
                        break;
                    }
                    self.cur_node=node.parent as usize;
                }
            }
            else{
                self.cur_node=child as usize;
                self.state.apply(&self.nodes[self.cur_node]);
            }
        }
    }

    fn enum_cands(&mut self,input:&Input,cands:&mut Vec<Cand>){
        // self.dfs(input,cands,true);
        self.no_dfs(input,cands);
    }

    fn update<I:Iterator<Item=Cand>>(&mut self,cands:I){
        self.next_leaf.clear();
        for cand in cands{
            self.add_node(cand);
        }

        for i in 0..self.leaf.len(){
            let n=self.leaf[i];
            // 子が存在しないノードは無駄なので消す
            if self.nodes[n as usize].child==!0{
                self.del_node(n);
            }
        }

        std::mem::swap(&mut self.leaf,&mut self.next_leaf);
    }

    fn restore(&self,mut idx:usize)->Vec<u8>{
        let mut ret=vec![];
        loop{
            let Node{op,parent,..}=self.nodes[idx as usize];
            if op==!0{
                break;
            }
            ret.push(op);
            idx=parent;
        }
        
        ret.reverse();
        ret
    }

    // self.stateがself.nodes[idx]のノードが表す状態になっている
    // self.nodes[idx]からのCandをcandsに積む
    fn append_cands(&self,input:&Input,idx:usize,cands:&mut Vec<Cand>){
        let node=&self.nodes[idx];
        assert_eq!(node.child,!0);

        todo!();
    }
}

木上のビームサーチ (遷移先が限定されない場合)

今までの実装は全てdp[i]の遷移先がdp[i+1]のみである前提の設計となっていました。そのため、dp[i]からdp[i+2]やdp[i+3]への遷移を考える場合は定数倍を犠牲にした別の実装を用いる必要があります。
流石にすべての場合を解説するのは大変なのでここでは実装が一番大変な木上のビームサーチの遷移先が限定されない実装のみを解説します。

限定された場合の実装ではノードは一生使わないノード、次走査するノードの二種類のみでしたが、さらに次の走査には使わないが今後使うかもしれないノードが新たに増えます。

またそれに伴い、一生使わないノードの条件も少し変わり、子供が存在せず、かつ、参照している候補も存在しないことが条件となります。

また、無駄な走査を無くすために木の走査の前に次に走査したいノードを選択するよう変更しました。

長いので折りたたみ
#[derive(Clone,Default)]
struct Node{
    op:u8,
    parent:uint,
    child:uint,
    prev:uint,
    next:uint,
    refs:u8, // 何個のCandに参照されているか
    valid:u16, // このノードが有効かどうか、BeamSearch::atと等しいと有効
}


const MAX_WIDTH:usize=1000;
const TURN:usize=100;

struct BeamSearch{
    state:State,
    nodes:Vec<Node>,
    que:Vec<uint>, // 消してもいいノード
    cur_node:usize,
    free:Vec<uint>,
    at:u16,
}
impl BeamSearch{
    fn new(state:State,node:Node)->BeamSearch{
        const MAX_NODES:usize=MAX_WIDTH*5;
        assert!(MAX_NODES<uint::MAX as usize,"uintのサイズが足りないよ");
        let mut nodes=vec![Node::default();MAX_NODES];
        nodes[0]=node;
        let free=(1..MAX_NODES as uint).rev().collect();

        BeamSearch{
            state,nodes,free,
            que:Vec::with_capacity(MAX_WIDTH),
            cur_node:0,
            at:0,
        }
    }
    
    fn add_node(&mut self,cand:Cand){
        let next=self.nodes[cand.parent as usize].child;
        let new=self.free.pop().expect("MAX_NODEが足りないよ") as uint;
        if next!=!0{
            self.nodes[next as usize].prev=new;
        }
        self.nodes[cand.parent as usize].child=new;
        
        self.nodes[new as usize]=Node{next,..cand.to_node()};
        self.retarget(new);
    }

    fn del_node(&mut self,mut idx:uint){
        assert_eq!(self.nodes[idx as usize].refs,0);
        loop{
            self.free.push(idx);
            let Node{prev,next,parent,..}=self.nodes[idx as usize];
            assert_ne!(parent,!0,"全てのノードを消そうとしています");
            
            self.nodes[parent as usize].refs-=1;

            if prev&next==!0 && self.nodes[parent as usize].refs==0{
                idx=parent;
                continue;
            }

            if prev!=!0{
                self.nodes[prev as usize].next=next;
            }
            else{
                self.nodes[parent as usize].child=next;
            }
            if next!=!0{
                self.nodes[next as usize].prev=prev;
            }
            
            break;
        }
    }

    fn dfs(&mut self,input:&Input,turn:usize,cands:&mut Vec<Vec<Cand>>,single:bool){
        if self.nodes[self.cur_node].child==!0{
            let cnt=self.append_cands(input,turn,self.cur_node,cands);
            if cnt==0{
                self.que.push(self.cur_node as uint);
            }
            self.nodes[self.cur_node].refs+=cnt;
            return;
        }

        let node=self.cur_node;
        let mut child=self.nodes[node].child;
        let next_single=single&(self.nodes[child as usize].next==!0);

        // let prev_state=self.state.clone();
        'a: loop{
            // 次の有効な子供を見つける
            while self.nodes[child as usize].valid!=self.at{
                child=self.nodes[child as usize].next;
                if child==!0{
                    break 'a;
                }
            }
            
            self.cur_node=child as usize;
            self.state.apply(&self.nodes[child as usize]);
            self.dfs(input,turn,cands,next_single);

            if !next_single{
                self.state.revert(&self.nodes[child as usize]);
                // assert!(prev_state==self.state);
            }

            child=self.nodes[child as usize].next;
            if child==!0{
                break;
            }
        }
        
        if !next_single{
            self.cur_node=node;
        }
    }

    // dfsの非再帰実装
    fn no_dfs(&mut self,input:&Input,turn:usize,cands:&mut Vec<Vec<Cand>>){
        loop{
            let Node{next,child,..}=self.nodes[self.cur_node];
            if next==!0 || child==!0{
                break;
            }
            self.cur_node=child as usize;
            self.state.apply(&self.nodes[self.cur_node]);
        }

        assert!(self.nodes[self.cur_node].valid==self.at);
        let root=self.cur_node;

        loop{
            assert!(self.nodes[self.cur_node].valid==self.at);
            let mut child=self.nodes[self.cur_node].child;

            if child==!0{
                let cnt=self.append_cands(input,turn,self.cur_node,cands);
                if cnt==0{
                    self.que.push(self.cur_node as uint);
                }
                self.nodes[self.cur_node].refs+=cnt;

                'a: loop{
                    if self.cur_node==root{
                        return;
                    }
                    let node=&self.nodes[self.cur_node];
                    self.state.revert(&node);
                    let mut next=node.next;
                    loop{
                        if next==!0{
                            self.cur_node=node.parent as usize;
                            break;
                        }
                        if self.nodes[next as usize].valid==self.at{
                            self.cur_node=next as usize;
                            self.state.apply(&self.nodes[self.cur_node]);
                            break 'a;
                        }
                        next=self.nodes[next as usize].next;
                    }
                }
            }
            else{
                while self.nodes[child as usize].valid!=self.at{
                    child=self.nodes[child as usize].next;
                }
                self.cur_node=child as usize;
                self.state.apply(&self.nodes[self.cur_node]);
            }
        }
    }

    fn enum_cands(&mut self,input:&Input,turn:usize,cands:&mut Vec<Vec<Cand>>){
        assert_eq!(self.nodes[self.cur_node].valid,self.at);
        self.que.clear();
        // self.dfs(input,turn,cands,true);
        self.no_dfs(input,turn,cands);
    }

    // 走査するべきノードを更新する
    fn retarget(&mut self,mut idx:uint){
        while self.nodes[idx as usize].valid!=self.at{
            self.nodes[idx as usize].valid=self.at;
            if idx as usize==self.cur_node{
                break;
            }
            idx=self.nodes[idx as usize].parent;
        }
    }

    fn update<I:Iterator<Item=(Cand,bool)>>(&mut self,cands:I){
        self.at+=1;
        for i in 0..self.que.len(){
            self.del_node(self.que[i]);
        }
        
        for (cand,f) in cands{
            let node=&mut self.nodes[cand.parent as usize];
            if f{
                self.add_node(cand);
            }
            else{
                node.refs-=1;
                if node.refs==0{
                    self.del_node(cand.parent);
                }
            }
        }
    }

    fn restore(&self,mut idx:usize)->Vec<u8>{
        let mut ret=vec![];
        loop{
            let Node{op,parent,..}=self.nodes[idx];
            if parent==!0{
                break;
            }
            ret.push(op);
            idx=parent as usize;
        }
        
        ret.reverse();
        ret
    }

    // 子供の個数を返す
    fn append_cands(&self,input:&Input,turn:usize,idx:usize,cands:&mut Vec<Vec<Cand>>)->u8{
        let node=&self.nodes[idx];
        assert_eq!(node.child,!0);
        assert_eq!(node.valid,self.at);

        todo!();
    }

    fn solve(&mut self,input:&Input)->Vec<u8>{
        use std::cmp::Reverse;
        let M=MAX_WIDTH;
    
        let mut cands=(0..=TURN).map(|_|Vec::<Cand>::with_capacity(MAX_WIDTH*4)).collect::<Vec<_>>();
        let mut set=rustc_hash::FxHashSet::default();
        for t in 0..TURN{
            if t!=0{
                let M0=(M as f64*2.).round() as usize;
                let cands=&mut cands[t];
                if cands.len()>M0{
                    cands.select_nth_unstable_by_key(M0,|a|Reverse(a.eval_score));
                }
                
                let len=M0.min(cands.len());
                cands[..len].sort_unstable_by_key(|a|Reverse(a.eval_score));

                set.clear();
                let mut total=0;

                self.update(cands.drain(..).map(|cand|{
                    let f=total<M && set.insert(cand.hash);
                    total+=f as usize;
                    (cand,f)
                }));
            }
            
            self.enum_cands(input,t,&mut cands);
        }
    
        let best=cands.last().unwrap().iter().max_by_key(|a|a.raw_score(input)).unwrap();
        eprintln!("score = {}",best.raw_score(input));
        let mut ret=self.restore(best.parent as usize);
        ret.push(best.op);
    
        ret
    }
}


struct Input{}

ノードの選択は初期化配列の要領で実装して高速化しました。


その他の高速化

ソートや枝刈り

上位M個を選択する手法は、全体の候補数に対するMが十分に小さいのであればpartial_sort(HeapSort)2、でなければselect_nth_unstableが高速です。

また、重複除去のため、スコアを予め降順に並べたい場合でも上位2M個程を選択してからsort_unstableをするとほぼ挙動を変化させずにより高速な選択が可能になります。

他にも、枝刈りのために採用されるスコアの下界を見積もりたいときは候補が2M溜まるごとに下位半分を消していけばいいです。
ヒープの構築は線形時間なので毎回厳密な判定もオーダーを変えずにやることは可能ではあるのですが、ヒープは定数倍が重いので...。

重複除去

ビームサーチで重複除去のために用いるhashは主にZobristHashingで生成されます。そのため、適切な乱数が用いられていれば既にhash値がランダムな値になっていることが期待できHashSet内部で更にhash関数を適用する必要性がなくなります。幸い、RustやC++のHashSetは自前のhash関数を指定することができ、何もしないhash関数を指定することが容易となっております。

折りたたみ
use std::collections::{HashMap,HashSet};
use core::hash::BuildHasherDefault;
use core::hash::Hasher;

#[derive(Default)]
pub struct NopHasher {
    hash: u64,
}
impl Hasher for NopHasher {
    fn write(&mut self, _: &[u8]) {
        panic!();
    }

    #[inline]
    fn write_u64(&mut self, n: u64) {
        self.hash=n;
    }

    #[inline]
    fn finish(&self) -> u64 {
        self.hash
    }
}

pub type NopHashMap<K, V> = HashMap<K, V, BuildHasherDefault<NopHasher>>;
pub type NopHashSet<V> = HashSet<V, BuildHasherDefault<NopHasher>>;
上述の実装が使えないような場合でも、Rustの標準のhash関数は重いので`rustc_hash::FxHashSet`などを用いたほうがいいです。

また、hash値をu16で管理すればHashSetを使うまでもなくbool[65536]で管理可能なのですが、自分の経験では衝突がスコアに影響が出る範囲で発生することが多かったので使っておりません、が、たくさん衝突しても構わない場合は有用だと思われます。

省メモリ化

indexはu16で持つようにします。結構効きます。
一度に持つ状態数が65536を超える場合はu32を指定してください。


使用例

実際にこのビームサーチライブラリを用いてAtCoder上の問題を解いてみます。解説はかなりお気持ちが多いです(ヒューリスティックだから多少は...)。

競技プログラミングの鉄則 演習問題集 A49 "Heuristic 2"

以下に問題文を引用します:

長さ20の配列があり、最初すべての要素は0です。
あなたは配列に対して100回以下の操作のどちらかを行います。

操作A: $X_{P_i}$,$X_{Q_i}$,$X_{R_i}$に$+1$を加算する
操作B: $X_{P_i}$,$X_{Q_i}$,$X_{R_i}$に$-1$を加算する

各操作が終わった後、「$X_j=0$となる$j$の個数」だけスコアが加算されます。
スコアを最大化してください。

一般に、どんな順番で操作しても結局同じ状態になる(別の言い方をすると独立な変数が多い)のであれば、多様性を確保するのが難しく、ビームサーチは有効になりにくいです。
今回の問題は毎ターンでスコアが変化していくので独立な変数は多くないです。(スコアに影響するのが最終形だけなら独立な変数が多いので焼き鈍しましょう。)

評価関数は安直にやると現在までのスコアですが、最大化したいのは目先のスコアではなく100ターン目のスコアなのでスコアに先読みを入れます。
table[i][j][k] := iターン目でj番目の要素の絶対値がkであるときそこからj番目の要素だけを考えて最適に動かしたときの今後得られるスコアの最大値
state[i] := 現在の状態のi番目の値
として、ターンtでのスコアを今までのスコア + Σ table[t][k][abs(state[k])]としました。

また、重複除去は乱数列randを定義しておいて、Σ rand[k]*state[k]をハッシュ値として使いました。
すべての要素の符号が反転されているのも同一の局面なのですが、面倒なので最初の操作をAに限定するだけで特に対策はしていません。

また、細かいことですが、入力を固定長配列に入れておけばイテレータがループアンローリングがされるので楽に高速化できます。

i8[20]で状態が決まるのでコピーコストは非常に小さいです。そのため、実装は普通のビームサーチを使っても問題ないでしょう。
以上のことを踏まえて実装の提出が以下です。

実装(Rust、M=100000、776ms)

1手だけ差分更新したビームサーチを使うと遅くなりました。この程度だったら何も考えずコピーしたほうが速いということだと思います。

実装(Rust、M=100000、948ms)

ちなみにビーム幅を10万に設定していますが、この方法だとビーム幅2000程度で最適解(だと思われる値)が出るので完全にオーバースペックです。(何ならハッシュの衝突のリスクが増えるのでビーム幅を増やしたほうが最適解が出にくい...。)

第2回 RCO日本橋ハーフマラソン 予選 A "ゲーム実況者Xの挑戦"

問題概要です:

50*50のマップで各マスが壁、罠、コイン、スタートのいずれかであるマップが100個与えられる。
この中から8個のマップを選び2500ターン上下左右のいずれかにプレイヤーを動かす。
また、壁の方向に移動したもしくは、罠にかかったプレイヤーの位置は変わらない。
ただし、8個のマップのプレイヤーは連動して動く。
コインは一度取ったらなくなるので得られるコインを最大化してください。

簡単のために8個のマップは100個のマップから到達可能なコインの個数が多い順に貪欲に選ぶことにします。

行動は罠にかからないもののみに限定します。途中で罠にかかるというのがあまりいい行動ではなく、また、罠にかからないと限定すればより高速になるため、このようにしました。

評価関数は先読みを考えると、今までの取ったコインの他に、進んでいる先にどれくらいコインがあるのかを知りたい気分になります。そこでマスごとにそのマスに何回訪れたことがあるのかをカウントしていけば訪れた回数が少ないほうがその先にコインがありそうです。ということで、評価関数にはプレイヤーの位置に対し、その位置に訪れた回数分だけペナルティを入れました。

重複除去はプレイヤーの位置のみでのZobristHashingと、選択した状態でプレイヤーの位置が偏らないよう総和が閾値を超えないなら採用という方法で多様性を確保しました。

let set=NopHashSet::default();
let th=M/25;
let mut cnt=[0u16;N*N*K];
self.update(cands.iter().filter(|cand|{
    (0..K).map(|i|cnt[cand.pos[i]]).sum::<usize>()<=th
    && set.insert(cand.hash)
    && {
        for i in 0..K{
            cnt[cand.pos[i]]+=1;
        }
        true
    }
}).take(M).cloned());

また本問はターン数が決まっておりビーム幅をこのままにしたとき終了まで必要な時間が予測しやすく、時間を使い切るためにこのような式で$n$ターンに一回ビーム幅を更新しています。
$$ M_{t+1} = clamp(M_t\sqrt {\frac{nr}{ae}} , M_{min}, M_{max}) $$
ただし、$a$は残りのターン数、$r$は残り時間、$e$は$n$ターンにかかった時間です。
平方根、clampは急激な変化を抑えるために入れています。

この問題で使うビームサーチですが、選ばれるK個のマップに存在するコインの合計は平均で14000個程です。一応bitsetで管理すればコピーコストは小さくなりますが、その分ランダムアクセスは遅くなります。しかも、そのマスに何回到達したことがあるかの情報も入れるとなるとコピーコストはかなり重くなります。

そこで実装は木上のビームサーチを用いました。

提出(Rust、M=5000(動的調節あり)、3997ms)

AtCoder Heuristic Contest 021 "Pyramid Sorting"

問題概要です:

30段のピラミッドがあり、各マスには1から465までの数が重複なく書き込まれている。
隣接するマスで書かれている数を交換することができる。
すべてのマスについて上のマスにかかれている数が自身に書かれている数よりも小さくなる状態にするための操作回数を最小化してください。

すべての操作を考慮すると大変なので、揃っていない数で一番小さいものを揃えることを考えます。ですが、一番小さい数を動かすだけだと考えられる行動がかなり少なくなるので、一番小さい数から距離が2以下の数を一番小さい数に向かって動かすことを考えます。
つまり、2回の行動も1手と考えることにします。3
前回と同じ行動は元に戻ってしまうので飛ばします。

評価関数は1200*揃った数+Σ段数*ボールの値を使い、重複除去は揃っているマスと現在揃えようとしている数があるマスの位置で区別したZobristHasingを使います。

また、細かいヒューリスティックですがスコアにターンが進むごとに徐々に減少する微小な乱数を加え多スタートをしました。
乱数は多様性確保のためでターンが進むごとに減少させるのは序盤はより多様な探索をしたほうが良いということでしょうか。
多スタートは構造体にまとめておくと非常に楽で、SuccessiveHalvingなどへの対応も簡単に実装できるようになるのでオススメです。

木上のビームサーチで遷移先が限定されないやつを用いて実装したのが以下です:

提出(Rust、M=300(回多スタート)、1979ms)


おわりに

本記事では非常に高速なビームサーチライブラリの実装を紹介しました。本記事の方法では個々の問題のビームサーチに共通する部分を抜き出し、問題固有の状態遷移を記述するだけで非常に高速な探索が可能になります。これによるメリットは短期、長期ともに極めて大きいと感じています。

自分も当面の間ビームサーチはこれを使っていこうと思っているので、定期的に整備していく予定です。

この記事を作るにあたって多くの先人達の記事、実装、tweetなどを参考にしています。ありがとうございました。
誤字指摘、マサカリ、質問などあればコメントか@rho__oまでお願いします。




C++実装は、Topcoderで必要になったら作ろうと思っています。(その前に誰かが作ってくれそう)

  1. 他にもEularTourをLinkedListで管理するという手法も存在します(Rafbillさんによる実装)。筆者が試した限りでは二重連鎖木による実装と時間空間ともに同等程度の性能でした。ただ、EularTourによる実装は途中の木の変形を考えにくい都合上、無駄なノードの検索や探索順の変更などが難しく一部の問題では二重連鎖木実装よりもパフォーマンスが落ちてしまったので今回のライブラリでは採用していません。無論、EularTour実装の方が良い場合もありそうなので知っておくに越したことはないと思います。

  2. 現状RustにはHeapSort実装のpartial_sortが存在しないため作りました(ここです)。in-placeのためstd::collections::BinaryHeapは使用していませんが、内部実装は同様のものを拝借しているため高速に動作します。

  3. 実は操作を数手だけまとめる遷移はあんまり良くなかったりします。(ビームサーチによる枝刈りが上手く刺ささらず低速になるので。)体力とかMPとかであれば素直にdp[i]からdp[i+2]への遷移が存在しますが、問題設定が単純なやつで都合のいいやつがなかったので....。

22
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
22
16