rust
LRUキャッシュ
RustDay 2

Rust でLRUキャッシュ用のコレクションを自作してみる

これは Rust Advent Calendar 2018 2日目の記事です。

はじめに

Rust 書き始めた頃は所有権に苦しめられてきた(個人の感想)。でも、 Rust を書きなれてくると大抵のアルゴリズムやデータ構造は実装できる気になってくる。

そこで(?)、Rust を素直に使うだけだと書き難いことで有名(*要出典)な双方向連結リストと連想配列を使って実装されることが多いLRUキャッシュを作ってみる。

今から作るLRUキャッシュは実用的ではないし(APIが少ない)、テストも不十分だと思う。
ここでの目的はそういうことではなく「unsafeってちょっとなー」とか思っている人向けの、 unsafe や Rust での生ポインタの扱いの練習の意味合いが強い。

作ったコードはここにある: https://github.com/hinohi/lru_collections

作るコレクションの名前は LruMap とする。

LruMapのAPIを決める

  • 最大サイズを指定できる
  • エントリーを追加したときに最大サイズを超えていたら、最近最も使われていないエントリーを破棄する
  • APIは Rust HashMap から最低限だけ選択
    • エントリーの追加
    • エントリーの取得
  • エントリーの追加・取得が O(1) で行えること
  • struct としては struct LruMap<K, V>

(今回はエントリーの内容による順序には無頓着なので) HashMap のメソッド一覧を眺めて、最低限のメソッドとして以下を実装することにした。

新しいエントリーを追加するメソッド

HashMap.insertと同じで、すでに存在する key に対して再度 insert した場合には以前に入っていた value を Some(V) として返す。

fn insert(&mut self, k: K, v: V) -> Option<V>;

値を取得するメソッド

key に対応するエントリーが存在する場合は対応する value を Some(&V) として返す。エントリーが存在しないときは None
LRU を実現するために、値にアクセスするだけでもコレクションへの変更が必要なので、 &mut self が要求される。

fn get<Q>(&mut self, k: &Q) -> Option<&V>;
where
    K: Borrow<Q>,
    Q: Hash + Eq + ?Sized,

その他のメソッド

最低限だけ用意する。

fn len(&self) -> usize;
fn is_empty(&self) -> bool;

データ構造を決める

LRUキャッシュは以下の2要素からなる。

  • エントリーが利用された順番を保持するための双方向連結リスト
  • 値の取得を高速に行うためのリストのノードへのポインタを保持する連想配列

連想配列は Rust 標準ライブラリの HashMap を使うことにする。
双方向連結リストは Rust の標準ライブラリに LinkedList というものがある。しかし LinkedList の実装の詳細はプライベートなので、「リストのノードへのポインタから辿ってリストを操作する」ことができない。今回は完全に自作することにする。

次にデータの保持場所を考える。
双方向連結リストに対して実際に行う操作を列挙すると以下の操作が必要だとわかった。

  • リストの先頭にノードを追加
  • リストの末尾のノードを削除して、削除されたノードに対応する key を返却
  • ノードへのポインタを指定して、リスト中の任意の場所のノードをリストの先頭に移動

エントリー K, V に対して、 K は連想配列と連結リストの双方で保持する必要があり、 V はどちらが保持しても良いことがわかる。
今回は K: Hash + Eq + Clone とすることで K を両方が保持、 V を連結リストが保持するように実装した。

実際の定義は以下

use std::ptr::NonNull;

struct LinkedList<K, V> {
    head: Option<NonNull<Node<K, V>>>,
    tail: Option<NonNull<Node<K, V>>>,
}

struct Node<K, V> {
    next: Option<NonNull<Node<K, V>>>,
    prev: Option<NonNull<Node<K, V>>>,
    key: K,
    value: V,
}

pub struct LruMap<K, V>
where
    K: Hash + Eq + Clone,
{
    max_size: usize,
    map: HashMap<K, NonNull<Node<K, V>>>,
    list: MyLinkedList<K, V>,
}

(K への参照を連想配列のキーにしたいが間に合わなかった(できるのかな?))

LinkedList は標準ライブラリのものと名前が被っているけどまあいいよね。
Option<NonNull<Node<_>>> というのは標準ライブラリの LinkedList のコードを参考にした。

実装していく

カスタマイズド双方向連結リスト

標準ライブラリのコードが大変に参考になる。
標準の LinkedList のコードを眺めつつ今回の LinkedList を作成していく(名前が紛らわしい)。

例えば「リストの先頭にノードを追加」は以下のように実装した。

fn push_front(&mut self, mut node: Box<Node<K, V>>) -> NonNull<Node<K, V>> {
    node.next = self.head;
    node.prev = None;
    unsafe {
        let node = Some(Box::into_raw_non_null(node));
        match self.head {
            None => self.tail = node,
            Some(mut head) => head.as_mut().prev = node,
        }
        self.head = node;
    }
    self.head.unwrap()  // Some(NonNull<_>) を NonNull<_> にしている
}
  1. Node<_> の所有権を受け取り
  2. リストの先頭に追加し(この時に Box<_>NonNull に変換している)
  3. 追加したノードへのポインタを返す

ポインターを作成するだけならセーフだが参照解決を行うのは unsafe なので大部分が unsafe で囲われている。
当初は push_front にノードへのポインタを渡していたが、こうするとなんか色々うまくいかなかったので、Box<_> を渡してノードのポインタを返すようになった。返されたノードへのポインタは連想配列の value として登録される。

ちなみに、push_front を使う側は、何も unsafe なことをしなくていいので、 関数ボディの半分以上が unsafe だが push_front 自身は unsafe ではない。

次に「ノードへのポインタを指定して、リスト中の任意の場所のノードをリストの先頭に移動」のメソッドを示す。

unsafe fn unlink_and_push_front(&mut self, node: *mut Node<K, V>) {
    let node = node.as_mut().unwrap();
    match node.prev {
        Some(mut prev) => prev.as_mut().next = node.next,
        // this node is the head node
        // nothing to do
        None => return,
    }

    match node.next {
        Some(mut next) => next.as_mut().prev = node.prev,
        // this node is the tail node
        // node.prev is Some<_> in this branch
        None => self.tail = node.prev,
    };

    node.next = self.head;
    node.prev = None;

    let node = Some(node.into());
    self.head.unwrap().as_mut().prev = node;
    self.head = node;
}

生ポインタを引数にとり、生ポインタを参照解決してただの &mut にしている。
残りの部分は多分 C でも Rust でも同じような感じになり、特に面白さはない(のに紙面はとる)。

というわけで、 Rust でも双方向連結リストを書くことはできた。

LruMap

insert のコードを以下に示す。

use std::mem;

pub fn insert(&mut self, k: K, v: V) -> Option<V> {
    // TODO use entry API
    if self.map.contains_key(&k) {
        unsafe {
            let mut node = self.map[&k];
            self.list.unlink_and_push_front(node.as_ptr());
            let node = node.as_mut();
            return Some(mem::replace(&mut node.value, v));
        }
    }

    // insert new node
    let node = Node::new(k.clone(), v);
    let ptr = self.list.push_front(Box::new(node));
    self.map.insert(k, ptr);

    // check size
    if self.max_size == 0 || self.map.len() <= self.max_size {
        return None;
    }

    // drop oldest node
    let tail = self.list.pop_back().unwrap();
    self.map.remove(&tail.key);
    None
}

まずはエントリーの存在確認を行い、結果で分岐する。

エントリーがすでに存在した場合は既存のエントリーを上書きする。
今回のように 1. 存在確認 2. 対応する値の取得と更新 はまとめて HashMap.entry.and_modified などで行うのが普通だが、今回は vmove するために move || とすると self まで move されてしまい、後続がうまく書けなかったので諦めた。無念。
おもむろに unsafe を開始して unlink_and_push_front で連結リストの順序を入れ替え、 mem::replace でノードが保持している value を書き換え、古い結果を返している。

エントリーが未登録だったときは

  1. Node<K, V> を作成し
  2. リストの先頭にノードを追加して、ノードへのポインタをもらい
  3. ポインタを HashMap に登録する

また、エントリーを追加したことによりコレクションのサイズが指定された最大サイズを超えていれば

  1. リストの末尾をpopして
  2. popされたノードから key を取得し
  3. HashMap からもエントリーを削除する

最後に、get のコードは以下の通り。

fn get<Q>(&mut self, k: &Q) -> Option<&V>
where
    K: Borrow<Q>,
    Q: Hash + Eq + ?Sized,
{
    let ptr = match self.map.get(k) {
        None => return None,
        Some(ptr) => ptr,
    };
    unsafe {
        self.list.unlink_and_push_front(ptr.as_ptr());
        Some(&ptr.as_ref().value)
    }
}

途中の ptr の型は &NonNull<Node<K, V>> になっている、ちょっと面白い(?)。
指定されてエントリーが存在していたら、unsafe を開始して、連結リストの順序を入れ替え、 value への参照を返している。

(実は関数から参照を返せる条件がよくわかってなかったりするので、 &ptr.as_ref().value でいいのかーってなってたりする。誰か教えて)

これで LruMap もできた。

軽くベンチ

せっかくなので。

#[bench]
fn bench_get_from_e6(b: &mut Bencher) {
    let mut m = LruMap::new(1_000_000);
    for i in 0..1_000_000 {
        m.insert(i, i);
    }
    b.iter(move || {
        for i in (0..100).rev() {
            m.get(&i);
        }
        for i in 0..100 {
            m.get(&i);
        }
    });
}

こんなのを e2 e4 e6 と作った。
(内部のアクセス順序微妙や)

結果

running 3 tests
test bench_get_from_e2 ... bench:       5,289 ns/iter (+/- 1,408)
test bench_get_from_e4 ... bench:       4,494 ns/iter (+/- 229)
test bench_get_from_e6 ... bench:       4,456 ns/iter (+/- 230)

エントリーの数に関係なく O(1) で取得はできてる。

まとめと感想

  • LRUキャッシュ用のコレクション LruMap を作った
  • 標準ライブラリの LinkedList を参考にすれば自作の双方向連結リストもそんなに難しくない
  • LruMap は内部で key を複製して二重に保持しているの無駄なので直したい
  • 作成中はコンパイル通ったのに実行したら SIGBUSSIGSEGV が頻発して、unsafe 嫌いってなってた
    • BoxNonNull 行き来していると C よりもデバッグしにくい印象?