Rust
LLVM

[Rust]LLVMで引数を取得する

LLVMで引数をとる関数の書き方がよくわからなかったのでメモ。

ここでは、引数を2つ取り加算するadd関数を考える。

LLVM IR コード(LLVM アセンブリ)

アセンブリで書くと以下のような感じ。

add.ll
define i32 @add(i32, i32) {
entry:
  %x = alloca i32        ; 変数xをアロケート
  %y = alloca i32        ; 変数yをアロケート
  store i32 %0, i32* %x  ; 1番目の引数をxにセット
  store i32 %1, i32* %y  ; 2番目の引数をyにセット
  %2 = load i32, i32* %x
  %3 = load i32, i32* %y
  %4 = add i32 %2, %3
  ret i32 %4
}

関数の宣言部分は define i32 @add(i32, i32) となっていて、引数を2つ取ることはわかるが、引数名はなんなのかが書かれていない。(LLVM APIを使ってなんらかの言語から出力するとこの形になる。)

アセンブリの場合は簡単で、引数名は'%'のあとに0から始まる数値がつく形で表される。
つまり、引数が2つの関数の場合、1番目の引数名は'%0'、2番目の引数名は'%1'となる。

これは気づくと簡単なのだが、知らないとハマる。

LLVM API

以下の例では、llvm-sysを使用しています。

RustでLLVM APIを使って書く場合、引数を得るには LLVMGetParam を使う。

1番目の引数を変数xに得るには以下のように書ける。

    let arg1 = unsafe { LLVMGetParam(function, 0) };  // 1番目の引数を得る。
    let x_name = CString::new("x").unwrap();
    let x = unsafe { LLVMBuildAlloca(builder, LLVMInt32Type(), x_name.as_ptr()) };
    unsafe { LLVMBuildStore(builder, arg1, x); }

2番目の引数の場合には、LLVMGetParam(function, 1)のようになる。
もし3番目の引数があるならLLVMGetParam(function, 2)のようになる、以下同様。

add関数全体を出力するコードは以下。

add.rs
extern crate llvm_sys;

use llvm_sys::core::*;
use llvm_sys::target;
use llvm_sys::analysis::{LLVMVerifyModule, LLVMVerifierFailureAction};
use std::ffi::CString;
use std::os::raw::c_char;

/// Initialise LLVM
///
/// Makes sure that the parts of LLVM we are going to use are
/// initialised before we do anything with them.
fn initialise_llvm() {
    unsafe {
        if target::LLVM_InitializeNativeTarget() != 0 {
            panic!("Could not initialise target");
        }
        if target::LLVM_InitializeNativeAsmPrinter() != 0 {
            panic!("Could not initialise ASM Printer");
        }
    }    
}

fn main(){
    let llvm_error = 1;

    initialise_llvm();

    // setup our builder and module
    let builder = unsafe { LLVMCreateBuilder() };
    let mod_name = CString::new("my_module").unwrap();
    let module = unsafe { LLVMModuleCreateWithName(mod_name.as_ptr()) };

    // create our function prologue
    let function_type = unsafe {
        let mut param_types = [LLVMInt32Type(), LLVMInt32Type()];
        LLVMFunctionType(LLVMInt32Type(), param_types.as_mut_ptr(), param_types.len() as u32, 0)
    };
    let function_name = CString::new("add").unwrap();
    let function = unsafe { LLVMAddFunction(module, function_name.as_ptr(), function_type) };
    let entry_name = CString::new("entry").unwrap();
    let entry_block = unsafe { LLVMAppendBasicBlock(function, entry_name.as_ptr()) };
    unsafe { LLVMPositionBuilderAtEnd(builder, entry_block); }

    // int x = arg1
    let arg1 = unsafe { LLVMGetParam(function, 0) };  // 1番目の引数を得る。
    let x_name = CString::new("x").unwrap();
    let x = unsafe { LLVMBuildAlloca(builder, LLVMInt32Type(), x_name.as_ptr()) };
    unsafe { LLVMBuildStore(builder, arg1, x); }

    // int y = arg2
    let arg2 = unsafe { LLVMGetParam(function, 1) };  // 2番目の引数を得る。
    let y_name = CString::new("y").unwrap();
    let y = unsafe { LLVMBuildAlloca(builder, LLVMInt32Type(), y_name.as_ptr()) };
    unsafe { LLVMBuildStore(builder, arg2, y); }


    // return x + y
    let x_val_name = CString::new("x_val").unwrap();
    let x_val = unsafe { LLVMBuildLoad(builder, x, x_val_name.as_ptr()) };
    let y_val_name = CString::new("y_val").unwrap();
    let y_val = unsafe { LLVMBuildLoad(builder, y, y_val_name.as_ptr()) };
    let xy_val_name = CString::new("xy_val").unwrap();
    unsafe {
        let res = LLVMBuildAdd(builder, x_val, y_val, xy_val_name.as_ptr());
        LLVMBuildRet(builder, res);
    }

    // verify it's all good
    let mut error: *mut c_char = 0 as *mut c_char;
    let ok = unsafe {
        let buf: *mut *mut c_char = &mut error;
        LLVMVerifyModule(module, LLVMVerifierFailureAction::LLVMReturnStatusAction, buf)
    };
    if ok == llvm_error {
        let err_msg = unsafe { CString::from_raw(error).into_string().unwrap() };
        panic!("cannot verify module '{:?}'.\nError: {}", mod_name, err_msg);
    }

    // Clean up the builder now that we are finished using it.
    unsafe { LLVMDisposeBuilder(builder) }

    // Dump the LLVM IR to stdout so we can see what we've created
    unsafe { LLVMDumpModule(module) }

    // Clean up the module after we're done with it.
    unsafe { LLVMDisposeModule(module) }
}