35
23

高速なビームサーチが欲しい!!!

Last updated at Posted at 2022-07-11

この記事はアルゴリズム強化月間の一環として書かれた記事です。

はじめに

こんにちは!
rhooというアカウント名で競技プログラミングをやっている者です。
半年ほど前からヒューリスティック系のコンテストにも手を出し始めました。

この記事では私がゲーム実況者xの挑戦という過去問を解くときに使った手法の解説記事となります。

要約

ビームサーチの状態の管理を差分を持った木構造で管理するとコピーコストが発生しなくなり高速になります。

前提知識

ビームサーチについての基礎的な知識と参照カウントベースのスマートポインタ(c++でのshared_ptr,weak_ptr)などについての知識を要求します。
この記事では詳しくは説明しないので、分からない場合はこれらの記事が参考になるかもしれません。

アイディア

ナイーブなビームサーチでは上位M個の状態をコピーして行きますが、この実装だとコピーコストは無視できないボトルネックとなります。
そのため、できるだけコピーコストを小さくするために、状態を使いまわしたり64bit整数を16bitにしたりといった工夫が有効になります。

しかし、問題の中にはコピーコストは大きいが状態を差分更新するコストは小さいといったようなものも存在しています。
その場合、状態をコピーするのではなく毎回知りたい状態をシュミレーションするといった方法がより高速になるかもしれません。
特に操作をこの図のように木構造で管理すればDFSを利用して効率的にシュミレーションすることができそうです。

 
しかし、ビームサーチを実装するのにはこの実装だけでは足りません。
例えばこのように木構造をビームサーチの挙動に合わせて更新していったときに探索しなくてもいいノードが現れてしまいます。

そのため、このようなノードを高速に判定する必要があります。

探索したいノードかどうかの判定を高速に

本題です。
探索したいノードかどうかの判定は普通に実装しようとするとかなり大変になってしまいます。
しかし、参照カウント方式のスマートポインタを用いるとほとんど実装せずに判定ができてしまいます。
具体的には木構造の親への参照を強参照、子への参照を弱参照で持つことにより、参照カウントを調べるだけで簡単に判定ができてしまいます。

最初はcur-listが現在のノードを持っており、参照カウントは図のようになります。

新しいノードがnext-listというところに置かれます。

next-listcur-listになるとcur-listが参照していたノードの参照カウントが一つ減り探索する必要のないノードはdropされます。
そして、探索するべきかどうかの判定はweak-ptrが有効かどうかを見るだけで終わってしまいます。

さらなる工夫

更に上記の説明に加えある工夫を2つすることによって、速度と多様性がより良くなります。

1.速度の工夫

このビームサーチはDFSが終わったときは状態を元に戻す必要がありますが、この図のようなことになっている場合は状態を元に戻す必要はありません。

ビームサーチではこの形になることが多々あるのでこの高速化はかなり効果があります。

2.多様性の工夫

ビームサーチではできるだけ局所的な状態ばかりを保持しているようなことは避けたいです。
とくに、スコアが同じ状態同士であれば、できるだけ離れた状態同士を選んだほうがパフォーマンスは良くなります。(逆に言えば何も考えずに安定ソートをすると悲惨なことになる。)
ナイーブなビームサーチではソートするときに乱数を持たせると言った工夫が使われますが、乱数では確実に離れている状態を選び取ることを保証はしてくれません。
そこで、私はビームサーチを実装するときに優先度というある親の何番目の子かを表す数を組み合わせて、それをもとにソートをします。
そうすることで、ソートしたときに離れた状態同士がまんべんなく分布するようになります。

この実装だと一つ前の親だけでなく、N個前の親がわかるので、優先度をN個前の親がという風に拡張することができ、より精度が上がります。

実装

上記の説明を実装をするとこうなります。

実装(rust)
use std::rc::*;
use std::cell::UnsafeCell;
use std::cmp::Reverse;


struct State{}
impl State{
    // 初期状態の生成
    fn new()->State{
        todo!();
    }

    fn score(&self)->usize{
        todo!();
    }

    fn hash(&self)->u64{
        todo!();
    }

    // スコアとハッシュの差分計算
    // 状態は更新しない
    fn try_apply(&self,op:usize,score:usize,hash:u64)->(usize,u64){
        todo!();
    }

    // 状態を更新する
    // 元の状態に戻すための情報を返す
    fn apply(&mut self,op:usize)->u128{
        todo!();
    }
    
    // applyから返された情報をもとに状態を元に戻す
    fn back(&mut self,backup:u128){
        todo!();
    }
}


struct Kouho{
    op:usize,
    parent:Rc<Node>,
    score:usize,
    hash:u64,
    p:usize // 優先度(複数もたせたほうが良い場合があるかもしれない。)
}


struct Node{
    parent:Option<(usize,Rc<Node>)>, // 操作、親への参照
    // 速度のためにUnsafeCellを使っているがRefCellのほうが安全
    child:UnsafeCell<Vec<(usize,Weak<Node>)>>, // 操作、子への参照
    score:usize,
    hash:u64
}


// 多スタート用に構造体にまとめておくと楽
struct Tree{
    state:State,
    node:Rc<Node>,
    rank:usize
}
impl Tree{
    // 注意: depthは深くなっていくごとに-1されていく
    fn dfs(&mut self,next_states:&mut Vec<Kouho>,one:bool,p:&mut usize,depth:usize){
        if depth==0{
            let score=self.node.score;
            let hash=self.node.hash;

            // 検算
            // assert_eq!(score,self.state.score());
            // assert_eq!(hash,self.state.hash());

            // 次の操作を列挙
            for op in 0..4{
                let (next_score,next_hash)=self.state.try_apply(op,score,hash);
                next_states.push(
                    Kouho{
                        op,
                        parent:self.node.clone(),
                        score:next_score,
                        hash:next_hash,
                        p:*p
                    }
                );
                *p+=1;
            }
        }
        else{
            let node=self.node.clone();
            let child=unsafe{&mut *node.child.get()};
            // 有効な子だけにする
            child.retain(|(_,x)|x.upgrade().is_some());

            let next_one=one&(child.len()==1);
            
            // 定数調整の必要あり
            if depth==5{
                *p=0;
            }
            self.rank+=1;
            
            for (op,ptr) in child{
                self.node=ptr.upgrade().unwrap();
                let backup=self.state.apply(*op);
                self.dfs(next_states,next_one,p,depth-1);
                
                if !next_one{
                    self.state.back(backup);
                }
            }
            
            if !next_one{
                self.node=node.clone();
                self.rank-=1;
            }
        }
    }
}


fn beam()->Vec<usize>{
    const TURN:usize=1000;
    const M:usize=100; // ビーム幅

    let mut tree={
        let state=State::new();
        let score=state.score();
        let hash=state.hash();
        Tree{
            state,
            node:Rc::new(
                Node{
                    parent:None,
                    child:UnsafeCell::new(vec![]),
                    score,
                    hash
                }
            ),
            rank:0
        }
    };

    let mut cur=vec![tree.node.clone()];
    let mut next_states=vec![];

    let mut set=rustc_hash::FxHashSet::default();
    
    for i in 0..TURN{
        next_states.clear();
        tree.dfs(&mut next_states,true,&mut 0,i-tree.rank);

        if i+1!=TURN{
            // 上位M個を残す
            if next_states.len()>M{
                next_states.select_nth_unstable_by_key(M,|Kouho{score,p,..}|(Reverse(*score),*p));
                next_states.truncate(M);
            }

            cur.clear();
            set.clear();
            for Kouho{op,parent,score,hash,..} in &next_states{
                // 重複除去
                if set.insert(*hash){
                    let child=unsafe{&mut *parent.child.get()};
                    let child_ptr=Rc::new(
                        Node{
                            parent:Some((*op,parent.clone())),
                            child:UnsafeCell::new(vec![]),
                            hash:*hash,
                            score:*score
                        }
                    );
                    child.push((*op,Rc::downgrade(&child_ptr)));
                    cur.push(child_ptr);
                }
            }
        }
    }

    // 最良の状態を選択
    let Kouho{op,parent:mut ptr,score,..}=next_states.into_iter()
        .max_by_key(|Kouho{score,..}|*score).unwrap();
    
    let mut ret=vec![op];
    eprintln!("score: {}",score);
    eprintln!("rank: {}",TURN-tree.rank);

    // 操作の復元
    while let Some((op,parent))=ptr.parent.clone(){
        ret.push(op);
        ptr=parent.clone();
    }

    ret.reverse();
    ret
}

c++での実装も置いておきますが、筆者はあまりc++が得意ではないので バグっている and 遅い 可能性があります。
ご容赦ください...

実装(c++)
#include "bits/stdc++.h"
using namespace std;

using ull=unsigned long long;


struct State{
    // 初期状態の生成
    State(){
        // TODO
    }

    ull score(){
        // TODO
    }

    ull hash(){
        // TODO
    }

    // スコアとハッシュの差分計算
    // 状態は更新しない
    pair<ull,ull> try_apply(ull op,ull score,ull hash){
        // TODO
    }

    // 状態を更新する
    // 元の状態に戻すための情報を返す
    ull apply(ull op){
        // TODO
    }
    
    // applyから返された情報をもとに状態を元に戻す
    ull back(ull backup){
        // TODO
    }
};

struct Node;

struct Kouho{
    ull op;
    shared_ptr<Node> parent;
    ull score;
    ull hash;
    ull p; // 優先度(複数もたせたほうがいい場合があるかもしれない。)
};

using Parent=optional<pair<ull,shared_ptr<Node>>>;
using Children=vector<pair<ull,weak_ptr<Node>>>;

struct Node{
    Parent parent; // 操作、親への参照
    Children child; // 操作、子への参照
    ull score;
    ull hash;
    
    Node(Parent parent,Children child,ull score,ull hash):
        parent(parent),child(child),score(score),hash(hash){}
};


// 多スタート用に構造体にまとめておくと楽
struct Tree{
    State state;
    shared_ptr<Node> node;
    ull rank;
    
    // 注意: depthは深くなっていくごとに-1されていく
    void dfs(vector<Kouho>& next_states,bool one,ull& p,ull depth){
        if(depth==0){
            ull score=node->score;
            ull hash=node->hash;

            // 検算
            // assert(score==state.score());
            // assert(hash==state.hash());

            // 次の操作を列挙
            for(ull op=0;op<4;++op){
                auto [next_score,next_hash]=state.try_apply(op,score,hash);
                next_states.emplace_back(Kouho{op,node,next_score,next_hash,p});
                p+=1;
            }
        }
        else{
            auto node_backup=node;
            auto child=&node_backup->child;
            // 有効な子だけにする
            child->erase(remove_if(child->begin(),child->end(),[](pair<ull,weak_ptr<Node>>& x){return x.second.expired();}),child->end());

            bool next_one=one&child->size()==1;
            
            // 定数調整の必要あり
            if(depth==5){
                p=0;
            }
            ++rank;

            for(const auto& [op,ptr]:*child){
                node=ptr.lock();
                ull backup=state.apply(op);
                dfs(next_states,next_one,p,depth-1);
                
                if(!next_one){
                    state.back(backup);
                }
            }
            
            if(!next_one){
                node=node_backup;
                --rank;
            }
        }
    }
};

vector<ull> beam(){
    constexpr ull TURN=1000;
    constexpr ull M=100; // ビーム幅

    State state;
    ull score=state.score();
    ull hash=state.hash();

    Tree tree{move(state),shared_ptr<Node>(new Node(Parent(),Children(),score,hash)),0};

    vector<shared_ptr<Node>> cur{tree.node};
    vector<Kouho> next_states;

    unordered_set<ull> set;
    
    for(ull i=0;i<TURN;++i){
        next_states.clear();
        ull tmp=0;
        tree.dfs(next_states,true,tmp,i-tree.rank);

        if(i+1!=TURN){
            // 上位M個を残す
            if(next_states.size()>M){
                nth_element(next_states.begin(),next_states.begin()+M,next_states.end(),[](Kouho& a,Kouho& b){
                    if(a.score==b.score){
                        return a.p>b.p;
                    }
                    else{
                        return a.score>b.score;
                    }
                });
                next_states.erase(next_states.begin()+M,next_states.end());
            }

            cur.clear();
            set.clear();
            for(const auto&[op,parent,next_score,next_hash,p]:next_states){
                // 重複除去
                if(set.count(hash)==0){
                    set.insert(hash);
                    auto child_ptr=shared_ptr<Node>(new Node(Parent({op,parent}),Children(),score,hash));
                    parent->child.emplace_back(op,weak_ptr<Node>(child_ptr));
                    cur.emplace_back(child_ptr);
                }
            }
        }
    }

    // 最良の状態を選択
    ull arg_max=-1;
    ull max=0;
    for(ull i=0;i<next_states.size();++i){
        if(next_states[i].score<=max){
            max=next_states[i].score;
            arg_max=i;
        }
    }
    auto [op,ptr,best_score,_,__]=next_states[arg_max];

    vector<ull> ret{op};
    cerr<<"score: "<<score<<endl;
    cerr<<"rank: "<<TURN-tree.rank<<endl;

    // 操作の復元
    while(ptr->parent.has_value()){
        auto [op,parent]=ptr->parent.value();
        ret.emplace_back(op);
        ptr=parent;
    }

    reverse(ret.begin(),ret.end());
    return ret;
}

追記(2023/2/16)
当時実装がだるそうだなと思って放置していたスマートポインタを使わない実装を最近になって実装しました。木構造を二重連鎖木で管理することでヒープアロケーションが定数回しか行われず、他の処理も単純になります。

use std::cmp::Reverse;

struct State{}
impl State{
    fn new()->State{
        todo!();
    }

    fn score(&self)->i64{
        todo!();
    }

    fn hash(&self)->usize{
        todo!();
    }

    fn try_apply(&self,op:usize,score:i64,hash:usize)->(i64,usize){
        todo!();
    }

    fn apply(&mut self,op:usize){
        todo!();
    }
    
    fn revert(&mut self,op:usize){
        todo!();
    }
}


#[derive(Clone)]
struct Cand{
    op:usize,
    parent:usize,
    score:i64,
    hash:usize
}


#[derive(Clone)]
struct Node{
    op:usize,
    parent:usize,
    child:usize,
    prev:usize,
    next:usize,
    score:i64,
    hash:usize
}


struct Tree{
    state:State,
    latest:usize,
    nodes:Vec<Node>, // nodes[latest..]が最新のノード
    cur_node:usize
}
impl Tree{
    fn add_node(&mut self,op:usize,parent:usize,score:i64,hash:usize){
        let next=self.nodes[parent].child;
        if next!=!0{
            self.nodes[next].prev=self.nodes.len();
        }
        self.nodes[parent].child=self.nodes.len();

        self.nodes.push(Node{op,parent,child:!0,prev:!0,next,score,hash});
    }

    fn del_node(&mut self,mut idx:usize){
        loop{
            let Node{prev,next,parent,..}=self.nodes[idx];
            assert_ne!(parent,!0);
            if prev&next==!0{
                idx=parent;
                continue;
            }

            if prev!=!0{
                self.nodes[prev].next=next;
            }
            else{
                self.nodes[parent].child=next;
            }
            if next!=!0{
                self.nodes[next].prev=prev;
            }
            
            break;
        }
    }
    
    fn resotre(&self,mut idx:usize)->Vec<usize>{
        let mut ret=vec![];

        loop{
            let Node{op,parent,..}=self.nodes[idx];
            if op==!0{
                break;
            }
            ret.push(op);
            idx=parent;
        }
        
        ret.reverse();
        ret
    }

    fn update<I:Iterator<Item=Cand>>(&mut self,cands:I){
        let len=self.nodes.len();
        for Cand{op,parent,score,hash,..} in cands{
            self.add_node(op,parent,score,hash);
        }

        for i in self.latest..len{
            if self.nodes[i].child==!0{
                self.del_node(i);
            }
        }
        self.latest=len;
    }

    fn dfs(&mut self,cands:&mut Vec<Cand>,single:bool){
        let node=&self.nodes[self.cur_node];
        if node.child==!0{
            // assert_eq!(node.score,self.state.score());
            // assert_eq!(node.hash,self.state.hash());
            
            for op in 0..4{
                let (score,hash)=self.state.try_apply(op,node.score,node.hash);
                cands.push(Cand{op,parent:self.cur_node,score,hash});
            }
        }
        else{
            let node=self.cur_node;
            let mut child=self.nodes[node].child;
            let next_single=single&(self.nodes[child].next==!0);

            loop{
                self.cur_node=child;
                self.state.apply(self.nodes[child].op);
                self.dfs(cands,next_single);

                if !next_single{
                    self.state.revert(self.nodes[child].op);
                }
                child=self.nodes[child].next;
                if child==!0{
                    break;
                }
            }
            
            if !next_single{
                self.cur_node=node;
            }
        }
    }
}


fn solve()->Vec<usize>{
    const T:usize=1000;
    const M:usize=100;

    let mut tree={
        let state=State::new();
        let mut nodes=Vec::with_capacity(M*T*2);
        nodes.push(Node{op:!0,parent:!0,child:!0,prev:!0,next:!0,score:state.score(),hash:state.hash()});

        Tree{
            state,
            latest:0,
            nodes,
            cur_node:0
        }
    };

    let mut cands=vec![];
    let mut set=rustc_hash::FxHashSet::default();

    for i in 0..T{
        cands.clear();
        tree.dfs(&mut cands,true);

        if i+1==T{
            break;
        }

        assert_ne!(cands.len(),0);
        if cands.len()>M{
            cands.select_nth_unstable_by_key(M,|s|Reverse(s.score));
            cands.truncate(M);
        }

        set.clear();
        tree.update(cands.iter().filter(|cand|set.insert(cand.hash)).cloned());
    }

    let Cand{op,parent,score,..}=cands.into_iter().max_by_key(|s|s.score).unwrap();
    
    eprintln!("score = {}",score);
    let mut ret=tree.resotre(parent);
    ret.push(op);

    ret
}
雑な検証では1.5倍ほど高速になりました。

比較

ゲーム実況者Xの挑戦でビーム幅を固定して速度を比較しました。

普通の実装 今回の実装
最小 3208ms 2655ms
最大 3308ms 2756ms
平均 3249ms 2714ms

スコアに差があるのは優先度の精度差のよるものだと思います。

実装についての補足

この問題のテストケースではコインが多い順にマップをgreedyに選んでも即死するケースがないので多スタートはしていません。
また、atcoderでは未だにrustのバーションが足りていないのでselect_nth_unstableの実装をドキュメントからコピーしてmod nthに貼り付けています。

終わりに

普段は特に記事などは書かないのですが、この方法はメモリに優しいchokudaiサーチなどのビームサーチ以外のゲーム木探索でも応用可能だと思うので、記事にしてみました。

ちなみにこの記事を書いている途中にcolunさんも同じようなビームサーチを以前に実装していたことに気が付きました。

実装方法、挙動は違いますが、状態集合を差分で管理するというアイディアは既出のようです。

35
23
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
35
23