LoginSignup
0
0

More than 1 year has passed since last update.

Rust で Intern されたASTを扱う方法

Last updated at Posted at 2023-06-11

TypeScriptのUnion型は便利ですが、型の同一系検査は結構重い処理だそうです。そこで1つ1つの型にIDをつけてIDが同じなら比較することで高速に処理するそうです。
Rustには同様の処理をする仕組みとしてArenaという仕組みがあります。
Arenaを使えば高速にASTを使えるので便利です。
しかしながらパターンマッチをする際にはID管理されたArenaだと多少面倒なコードを書く必要が出てきます。
そこでTypedArenaという型情報を持ったArenaを使うことでパターンマッチを行えるArenaが作れます。

TypedArenaは便利ですが関数を引き回す、あるいはself.を関数名につけなければならず面倒です。
そこでThreadLocalなTypedArenaが欲しくなります。
またTypedArenaは同じASTに同じアドレスが割り当てられるかもよくわかりません。

その変の解決がうまくされたライブラリがあるとよいのですが、様々なArenaやCacheのライブラリを見てもよく分からなかったので勉強がてらThreadLocalにあって引き回すことがなく、同じASTには同じアドレスを割り当てる同一判定が高速でパターンマッチもきれいに書けて、最終的にはアリーナの中身のデータは消去されるメモリリークを起こさなそうな、仕組みを作ってみました。

まず型安全に&'static Aな型にするためBox::leakを使いました。Boxで包んだポインタになるので効率が悪い気もしますが、とりあえず&'static Aな型に出来るので良しとしました。

fn leak<A>(a:A) -> &'static A {
    Box::leak(Box::new(a))
}

比較ですがaddr_ofを使ってアドレスを比較します。
今回作成したものはBoxで包まれたものを扱っているので、ポインタの中身のポインタを比較しています:

fn cmp<A>(e1: &A, e2: &A) -> bool {
    core::ptr::addr_of!(*e1)==core::ptr::addr_of!(*e2)
}

ASTはHashで使うためにHashとEqをDerivingしています。ネストした型は&'static Expを使いますがEXPという名前で扱うことにしました。

#[derive(Debug,Hash,Eq)]
enum Exp {
    Int(i8),
    Add(EXP,EXP),
}
type EXP = &'static Exp;// boxのアドレスになってる。

HashMapを使う際にはeqメソッドが使われますが、同じアドレスなら等価であり、ネストしたデータについては常に同一アドレスを使っている前提でアドレス比較のみをすることでアドレスが違うデータでHashMapを引く場合でも高速に比較ができるようにしました:

impl PartialEq for Exp {
    fn eq(&self,other:&Exp) -> bool {
        cmp(self,other) ||
        match (self,other) {
            (Exp::Int(i1),Exp::Int(i2)) => i1==i2,
            (Exp::Add(e1,e2),Exp::Add(e3,e4)) => cmp(*e1,*e3) && cmp(*e2,*e4),
            (_, _) => false
        }
    }
}

InternするアリーナはRustcでも使われているFxHashMapを使っています。
同じASTがHashMap内にあれば指定されたASTは削除してHashMap内のASTの借用ポインタを返し、HashMap内になければHashMap内に格納して借用ポインタを返却します。

struct EArena(FxHashMap<EXP, EXP>);
impl EArena {
    fn new()->EArena {
        EArena(FxHashMap::default())
    }
    fn intern(&mut self,e:Exp) -> EXP {
        let pe = leak(e);
        if let Some(x) = self.0.get(pe) {
            unsafe {core::ptr::drop_in_place(pe as *const _ as *mut Exp)}
            return *x;
        }
        self.0.insert(pe, pe);
        pe
    }
}

アリーナが削除される際には、保存されていたeを削除するdropメソッドを追加しました。
これはなくてもいいのかもしれませんが、よくわからないので加えました:

impl Drop for EArena {
    fn drop(&mut self) {
        for (e, _) in self.0.iter() {
            unsafe {core::ptr::drop_in_place(e as *const _ as *mut Exp)}
        }
    }
}

スレッドローカルなアリーナは以下のようにRefCellに格納することで実現しました:

thread_local!(static E_ARENA: RefCell<EArena> = RefCell::new(EArena::new()));
fn intern(e:Exp) -> EXP {
    E_ARENA.with(|arena|{arena.borrow_mut().intern(e)})
}

AST生成関数はinternを呼び出すことで同じASTには同じアドレスが割り当てられるようにしました:

fn int(i:i8) -> EXP {
    intern(Exp::Int(i))
}
fn add(e1:EXP,e2:EXP) -> EXP {
    intern(Exp::Add(e1,e2))
}

パターンマッチな以下のようにきれいに書くことが出来ます:

fn eval(e:EXP) -> i8{
    match e {
        Exp::Int(i) => *i,
        Exp::Add(e1,e2) => eval(e1)+eval(e2)
    }
}

使い方はこのようにinternを呼び出して使ったり、AST生成関数で短く生成できます:

#[test]
fn testa() {
    assert_eq!(3, eval(intern(Exp::Add(intern(Exp::Int(1)),intern(Exp::Int(2))))));
    assert_eq!(3, eval(add(int(1),int(2))));
    assert_eq!(true, cmp(int(1),int(1)));
    assert_eq!(true, cmp(add(int(1),int(2)),add(int(1),int(2))));
    assert_eq!(false, cmp(int(2),leak(Exp::Int(2))));

    assert_eq!(true, int(1)==int(1));
    assert_eq!(true, add(int(1),int(2))==add(int(1),int(2)));
    assert_eq!(true, int(2)==leak(Exp::Int(2)));
    assert_eq!(false, add(int(1),int(2))==add(int(1),leak(Exp::Int(2))));
}

このような仕組みを作ると生成コストは若干大きくなりますが、比較コストが低く、本質的な処理を簡潔に記述することが可能になります。

今後の課題としては、Boxを使わないようにできるのであればBoxをなくして高速化する。
Derivingを使って比較関数や生成関数を自動生成できるようにする。
アリーナの仕組みを一般化して使えるようにするなどができると思います。

すでにこういった仕組みがあるのであればぜひコメント欄などで教えてください。

rustc 1.69.0-nightly (bd39bbb4b 2023-02-07)

use rustc_hash::FxHashMap;
use std::cell::RefCell;
fn cmp<A>(e1: &A, e2: &A) -> bool {
    core::ptr::addr_of!(*e1)==core::ptr::addr_of!(*e2)
}
fn leak<A>(a:A) -> &'static A {
    Box::leak(Box::new(a))
}
#[derive(Debug,Hash,Eq)]
enum Exp {
    Int(i8),
    Add(EXP,EXP),
}
type EXP = &'static Exp;// boxのアドレスになってる。
impl PartialEq for Exp {
    fn eq(&self,other:&Exp) -> bool {
        cmp(self,other) ||
        match (self,other) {
            (Exp::Int(i1),Exp::Int(i2)) => i1==i2,
            (Exp::Add(e1,e2),Exp::Add(e3,e4)) => cmp(*e1,*e3) && cmp(*e2,*e4),
            (_, _) => false
        }
    }
}

struct EArena(FxHashMap<EXP, EXP>);
impl EArena {
    fn new()->EArena {
        EArena(FxHashMap::default())
    }
    fn intern(&mut self,e:Exp) -> EXP {
        let pe = leak(e);
        if let Some(x) = self.0.get(pe) {
            unsafe {core::ptr::drop_in_place(pe as *const _ as *mut Exp)}
            return *x;
        }
        self.0.insert(pe, pe);
        pe
    }
}
impl Drop for EArena {
    fn drop(&mut self) {
        for (e, _) in self.0.iter() {
            unsafe {core::ptr::drop_in_place(e as *const _ as *mut Exp)}
        }
    }
}

thread_local!(static E_ARENA: RefCell<EArena> = RefCell::new(EArena::new()));
fn intern(e:Exp) -> EXP {
    E_ARENA.with(|arena|{arena.borrow_mut().intern(e)})
}

fn int(i:i8) -> EXP {
    intern(Exp::Int(i))
}
fn add(e1:EXP,e2:EXP) -> EXP {
    intern(Exp::Add(e1,e2))
}
fn eval(e:EXP) -> i8{
    match e {
        Exp::Int(i) => *i,
        Exp::Add(e1,e2) => eval(e1)+eval(e2)
    }
}
#[test]
fn testa() {
    assert_eq!(3, eval(intern(Exp::Add(intern(Exp::Int(1)),intern(Exp::Int(2))))));
    assert_eq!(3, eval(add(int(1),int(2))));
    assert_eq!(true, cmp(int(1),int(1)));
    assert_eq!(true, cmp(add(int(1),int(2)),add(int(1),int(2))));
    assert_eq!(false, cmp(int(2),leak(Exp::Int(2))));

    assert_eq!(true, int(1)==int(1));
    assert_eq!(true, add(int(1),int(2))==add(int(1),int(2)));
    assert_eq!(true, int(2)==leak(Exp::Int(2)));
    assert_eq!(false, add(int(1),int(2))==add(int(1),leak(Exp::Int(2))));
}
fn main() {
    println!("{:?}", eval(intern(Exp::Add(intern(Exp::Int(1)),intern(Exp::Int(2))))));
    println!("{:?}", eval(add(int(1),int(2))));
    println!("{:?}", int(1));
    println!("{:?}", cmp(int(1),int(1)));
    println!("{:?}", cmp(add(int(1),int(2)),add(int(1),int(2))));
    println!("{:?}", cmp(int(2),leak(Exp::Int(2))));
}

続き:

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0