前回はとりあえずIntern出来るような仕組みを作りました。
しかし、どこか使い勝手の悪い違和感がありました。
そのアドレスを比較する際にどうもas *const _ とするとすべてのアドレスが違ってしまう場合があって悩んでいたのです。
色々いじってみてわかったのですが、問題となったのはa:&&Exp,b:&&Expだった場合、cmp(a,b)と書いた場合の型が&Expと&Expの比較をしてほしいのに&&Expと&&Expの比較をしてしまうことがあるため、明示的にcmp(*a,*a)と書かないと期待した比較ができなくなるという問題があるため、間違いを起こしやすい設計になっていました。
同じようなミスは繰り返したくありませんので、cmp(&&Exp,&&Exp)と受け取られないようにしたいです。
そこで今回はミスを減らすために、トレイトを使って比較する方法を考えてみました。
trait AddressEq {
fn addr_eq(&self, e2: &Self) -> bool {
self as *const _ == e2 as *const _
}
}
まずこのようなアドレスを比較するトレイトを作ります。これだけで動かすことは出来ませんのでimplします:
impl AddressEq for Exp {}
トレイトにすでに実装はあるので中身はなくても大丈夫です。
これを前回のcmp関数のかわりに使うようにします。
これで比較するときは必ず&Expのアドレスで比較されるようになるので、おかしなエラーで悩まず済むようになります。
以下全ソースになります:
use rustc_hash::FxHashMap;
use std::cell::RefCell;
fn leak<A>(a:A) -> &'static A {
Box::leak(Box::new(a))
}
#[derive(Debug,Hash,Eq)]
enum Exp {
Int(i8),
Add(EXP,EXP),
}
trait AddressEq {
fn addr_eq(&self, e2: &Self) -> bool {
self as *const _ == e2 as *const _
}
}
impl AddressEq for Exp {}
type EXP = &'static Exp;// boxのアドレスになってる。
impl PartialEq for Exp {
fn eq(&self,other:&Exp) -> bool {
self.addr_eq(other) ||
match (self,other) {
(Exp::Int(i1),Exp::Int(i2)) => i1==i2,
(Exp::Add(e1,e2),Exp::Add(e3,e4)) => e1.addr_eq(e3) && e2.addr_eq(e4),
(_, _) => false
}
}
}
struct EArena(FxHashMap<EXP, EXP>);
impl EArena {
fn new()->EArena {
EArena(FxHashMap::default())
}
fn intern(&mut self,e:Exp) -> EXP {
let pe = leak(e);
if let Some(x) = self.0.get(pe) {
unsafe {core::ptr::drop_in_place(pe as *const _ as *mut Exp)}
return *x;
}
self.0.insert(pe, pe);
pe
}
}
impl Drop for EArena {
fn drop(&mut self) {
for (e, _) in self.0.iter() {
unsafe {core::ptr::drop_in_place(e as *const _ as *mut Exp)}
}
}
}
thread_local!(static E_ARENA: RefCell<EArena> = RefCell::new(EArena::new()));
fn intern(e:Exp) -> EXP {
E_ARENA.with(|arena|{arena.borrow_mut().intern(e)})
}
fn int(i:i8) -> EXP {
intern(Exp::Int(i))
}
fn add(e1:EXP,e2:EXP) -> EXP {
intern(Exp::Add(e1,e2))
}
fn eval(e:EXP) -> i8{
match e {
Exp::Int(i) => *i,
Exp::Add(e1,e2) => eval(*e1)+eval(*e2)
}
}
#[test]
fn testa() {
assert_eq!(3, eval(intern(Exp::Add(intern(Exp::Int(1)),intern(Exp::Int(2))))));
assert_eq!(3, eval(add(int(1),int(2))));
assert_eq!(true, int(1).addr_eq(int(1)));
assert_eq!(true, add(int(1),int(2)).addr_eq(add(int(1),int(2))));
assert_eq!(false, int(2).addr_eq(leak(Exp::Int(2))));
assert_eq!(true, int(1)==int(1));
assert_eq!(true, add(int(1),int(2))==add(int(1),int(2)));
assert_eq!(true, int(2)==leak(Exp::Int(2)));
assert_eq!(false, add(int(1),int(2))==add(int(1),leak(Exp::Int(2))));
}
fn main() {
println!("{:?}", eval(intern(Exp::Add(intern(Exp::Int(1)),intern(Exp::Int(2))))));
println!("{:?}", eval(add(int(1),int(2))));
println!("{:?}", int(1));
println!("{:?}", int(1).addr_eq(int(1)));
println!("{:?}", add(int(1),int(2)).addr_eq(add(int(1),int(2))));
println!("{:?}", int(2).addr_eq(leak(Exp::Int(2))));
}
続き: