LoginSignup
42
25

More than 3 years have passed since last update.

RustでJITコンパイルする電卓を実装してみる

Posted at

はじめに

JITコンパイルってなんでしょうか?インタプリタを高速化するためにネイティブコードにコンパイルしてから実行する、みたいな説明をよく見ますが、分かるようでよく分かりません。書いてみたら分かるかもしれないと思いやってみました。

念のため注意ですが、電卓程度をJITコンパイルで実装してもたぶんいいことはありません。検証とかはしてませんが、なんなら普通に作るより遅くなる可能性もあると思います。そんな感覚で読んで頂ければと思います。

インタプリタによる実装

JITコンパイラで実装する前に、素直にインタプリタで電卓を実装してみます。完成形はこんな感じになります。

> 1 + 1
2
> 1 + 2 * 3
7

結構よくあるプログラムだと思うので、さらっと解説していきます。

トークンへの変換

入力される文字列は数字、記号、空白などが入り混じっており、数字もつながりが表現されていない状態です。1 + 23だと['1', ' ', '+', ' ', '2', '3']みたいな感じになり、これでは後に続く処理で不便なので扱いやすい形に成形しておきます。

enum TokenKind {
    Number(u8),
    Plus,
    Minus,
    Asterisk,
    Slash,
}

このようなTokenKindを定義しVec<TokenKind>を生成します。この際、スペースなどの余分な文字は無視したり、予期しない文字が入力された場合はエラーにしたりします。すると上記は[Number(1), Plus, Number(23)]というすっきりとした形に変換されます。

ASTへの変換

トークンに変換したら次は、抽象構文木(Abstract Syntax Tree: AST)への変換を行います。

演算子には優先度があるので、例えば1 + 2 * 3みたいな式を先頭から評価してしまうと先に足し算が行われてしまい、正しい結果が得られません。そこで優先度を考慮した木構造として式を表現します。

enum BinOpKind {
    Add,
    Sub,
    Mul,
    Div,
}

enum NodeKind {
    Number(u8),
    BinOp {
        kind: BinOpKind,
        lhs: Box<NodeKind>,
        rhs: Box<NodeKind>,
    },
}

木構造のNodeは数値と2項演算子の2種類とし、2項演算子には加算、減算、乗算、除算の4種類があるとします。さらに2項演算子は左辺と右辺の情報をもちます。この定義を使って1 + 2 * 3を表現すると、ちょっと長いですがこんな感じになります。

BinOp {
    kind: Add,
    lhs: Number(1),
    rhs: BinOp {
        kind: Mul,
        lhs: Number(2),
        rhs: Number(3)
    }
}

この表現を内側から評価していけば、正しい順序で計算することができます。

式の評価

ASTを評価する関数を再帰的に呼ぶことで、内側から評価することができます。

fn eval(ast: NodeKind) -> u8 {
    match ast {
        Number(n) => n,
        BinOp { kind, lhs, rhs } => {
            match kind {
                Add => eval(*lhs) + eval(*rhs),
                Sub => eval(*lhs) - eval(*rhs),
                Mul => eval(*lhs) * eval(*rhs),
                Div => eval(*lhs) / eval(*rhs),
            }
        },
    }
}

先程の例では、まずAddBinOpastとして入力され、rhsを評価する際にMulBinOpが再帰的にastとして入力されることになります。MulBinOpの計算結果をlhsである1と足し合わせて計算完了です。このようにして最終的に出てきた値が、入力した式の答えになります。

JITコンパイラによる実装

ではJITコンパイルです。環境としてはLinux/x86_64を前提とします。今回はこのような手順でJITコンパイルに対応してみようと思います。

  1. コンパイルしたネイティブコードを書き込む領域を用意する。
  2. コンパイルしたネイティブコードを書き込む。
  3. ネイティブコードの領域を関数として実行する。

ここで、コンパイル処理の入力となるのはASTです。なので、AST生成までのコードはインタプリタによる実装と変わらないことになります。

書き込む領域の用意

文章で見てもいまいち分かりづらいと思うので、ここからはソースコードを見ていきます。まずは領域を用意するところです。

extern crate libc;
use libc::{c_void, c_int, size_t, PROT_READ, PROT_WRITE, PROT_EXEC};
use std::alloc::{alloc, dealloc, Layout};

extern "C" {
    fn mprotect(addr: *const c_void, len: size_t, prot: c_int) -> c_int;
}
...
struct Compiler {
    p_start: *mut u8,
    p_current: *mut u8,
}

const CODE_AREA_SIZE: usize = 1024;
const PAGE_SIZE: usize = 4096;

impl Compiler {
    unsafe fn new() -> Self {
        let layout = Layout::from_size_align(CODE_AREA_SIZE, PAGE_SIZE).unwrap();
        let p_start = alloc(layout);
        let r = mprotect(p_start as *const c_void, CODE_AREA_SIZE, PROT_READ|PROT_WRITE|PROT_EXEC);
        assert!(r == 0);
        Compiler {
            p_start,
            p_current: p_start,
        }
    }
...
}

mprotectというCの関数を呼びたいので、libcクレートをインポートして関数の宣言をしておきます。この関数はLinux環境での標準Cライブラリに含まれるシステムコールラッパーです。Rustは(Linux環境しか見てませんが)デフォルトで標準Cライブラリ(libc.so)をリンクして動くので、この関数を使うのにリンカへの指示は必要ありません。

通常メモリは安全性のために保護されています。保護というのはできることに制限があるということで、ファイルと同じように読み取り、書き込み、実行の3種類の権限をそれぞれ設定できます。通常であれば獲得したメモリ領域に実行可能フラグは付いていないため、領域内のネイティブコードを実行するには権限を変更しておく必要があります。それを可能にするのがmprotectというLinuxシステムコールです。システムコールの詳細はmprotectのマニュアルをご参照下さい。

さて、それでは関数本体を見ていきます。Layout::from_size_alignで獲得したいメモリのレイアウトを定義し、allocを使ってメモリを獲得します。領域のサイズは適当です。ページサイズについてはいろいろ細かい話もありますが、Linuxシステムなら4KiBにしておけば大体は大丈夫だと思います。メモリが獲得できたらmprotectを使って実行権限を与えます。

ネイティブコードの書き込み

領域が用意できたので、ネイティブコードを生成していきます。構造的には先ほどのインタプリタで式を評価する関数と変わりません。

impl Compiler {
...
    unsafe fn gen_code_ast(&mut self, ast: NodeKind) {
        match ast {
            Number(n) => {
                self.push_code(&[0x6a, n]); // push {}
            },
            BinOp { kind, lhs, rhs } => {
                self.gen_code_ast(*lhs);
                self.gen_code_ast(*rhs);
                self.push_code(&[0x5f]); // pop rdi
                self.push_code(&[0x58]); // pop rax
                match kind {
                    Add => {
                        self.push_code(&[0x48, 0x01, 0xf8]); // add rax, rdi
                    },
                    Sub => {
                        self.push_code(&[0x48, 0x29, 0xf8]); // sud rax, rdi
                    },
                    Mul => {
                        self.push_code(&[0x48, 0x0f, 0xaf, 0xc7]); // imul rax, rdi
                    },
                    Div => {
                        self.push_code(&[0x48, 0x99]); // cqo
                        self.push_code(&[0x48, 0xf7, 0xff]); // idiv rdi
                    },
                }
                self.push_code(&[0x50]); // push rax
            },
        }
    }
}

push_codeに与えている16進数の配列がネイティブコードになります。ここではx86_64の機械語です。一応対応するアセンブリ表現も横にコメントで付けてあります。(そうしないと自分でも分からなくなる。)アセンブリはIntel記法を採用しており、これ以降の説明も同様です。

ネイティブコードを実行するときにASTを解析して再帰的に、みたいなやり方は難しいので、ここではスタックマシンと呼ばれる方法でシーケンシャルに実行できる命令を生成していきます。今回作る電卓には2項演算子しか出てこないため、この演算を「スタックから2つ数値をPOPして答えをスタックにPUSHする」という動作で実装すればよさそうです。

再度1 + 2 * 3の例を考えてみます。アセンブリで表現すると以下のようなコードが生成されます。

push 1
push 2
push 3
pop rdi
pop rax
imul rax, rdi
push rax
pop rdi
pop rax
add rax, rdi
push rax

アセンブリの細かい解説をここではできませんが、そこまで難しい知識は必要ありません。pushpopはそのままスタック操作で、raxrdiはレジスタの名前(変数名のようなもの)です。少し注意として、addimulなどの計算は2つのレジスタを取り、1つ目のレジスタに結果が格納されます。つまりadd rax, rdiでは、raxの中身とrdiの中身を足し合わせた結果がraxに格納されます。

このコードが動く様子をスタックの状態で見てみます。まずは先頭の1がスタックにPUSHされます。

           
1

この1はそのままで2 * 3の計算が先に実行されます。左辺の2と右辺の3がそれぞれPUSHされます。

           
3
2
1

ここで2項演算子の乗算が行われます。スタックの頭にある2つの数値32をPOPし、乗算の結果である6がスタックにPUSHされます。

           
6
1

続いて2項演算子の加算が行われます。同様にスタックから61をPOPし、結果である7がPUSHされます。

           
7

ちなみにpush_codeはこのように実装しています。ループとかキャストとか、あまりいい実装じゃない気もしますね。

impl Compiler {
...
    unsafe fn push_code(&mut self, code: &[u8]) {
        for b in code.iter() {
            std::ptr::write(self.p_current, *b);
            self.p_current = (self.p_current as u64 + 1) as *mut u8;
        }
    }
...
}

ここまでで計算自体はできていますが、出来上がったネイティブコードの最後にpop raxretを追加します。

impl Compiler {
...
    unsafe fn gen_code(&mut self, ast: NodeKind) {
        self.gen_code_ast(ast);
        self.push_code(&[0x58]); // pop rax
        self.push_code(&[0xc3]); // ret
    }
}

これは、生成したネイティブコードを関数として実行させるための処理です。計算を実行すると、最終的な結果はスタックに残った状態になるため、この最終結果を取り出してraxに格納します。retは関数の呼び出し元に戻るための命令です。ここで重要なのは、関数を呼び出すときに呼び出し元は呼び出した関数の戻り値がraxに格納されていることを期待している、ということです。このような仕様はABI(Application Binary Interface)と呼ばれ、CPU、OS、言語などによって違ってきます。今回はRustのABIに従うことで、自分が組み立てたネイティブコードをあたかもRustの関数であるかのように使うことができます。

関数として実行

最後に生成したネイティブコードをRustのコードから呼び出します。ネイティブコードが格納されている領域はバイナリレベルでは実行形式における関数定義と変わりませんが、言語レベルでは関数として見えていないため、うまく呼び出すことができません。コード領域のポインタを関数ポインタに型変換する必要があります。このような型変換は、C言語などでは普通にキャストすればできますが、Rustでは*mut u8からfn() -> u8へのasによるキャストが認められていません。このような場合には、標準ライブラリにあるtransmuteという任意の型変換を行うことができる関数が使えます。あとは型変換した関数を実行し、戻り値を受け取れば計算結果が格納されているはずです。

pub fn interpret(line: &str, use_jit: bool) -> Result<u8, ()> {
...
        unsafe {
            let mut compiler = Compiler::new();
            compiler.gen_code(*ast);
            let code: fn() -> u8 = std::mem::transmute(compiler.p_start);

            // Run generated code!
            let r = code();

            compiler.free();
            Ok(r)
        }
...
}

今回、関数の型をfn() -> u8としましたが、これはABIとして「RustのABI」を使うことを意味しています。Rustでは関数のABIをexternにより指定可能です。今回はシンプルな関数なのでそこまで意識せずに動作しましたが、もう少し複雑なことをする場合は、RustのABIとして正しく構築できているか、異なるABIを使うなら適切なexternを指定できているか、などABIへの意識が必要になると思います。

ソースコード

ソースコード全体はこちらに上げていますので、よかったら見てみて下さい。

感想

JITコンパイルを理解するために実装してみましたが、書いているといろいろと疑問が出てきて、既存のJITコンパイラの実装などもいろいろ見てみたくなりました。あと、冒頭にも書いた通り電卓程度を実装してもあんまり効果が分からないので、もう少し頑張って性能検証とかができるレベルまでもっていきたいと思いました。

42
25
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
42
25