ZACKYこと山崎進です。
今までのコードだとネイティブコードを実行できない環境の時には panic
を発生させていたので,ネイティブコード実行可否判断をして初期設定するように改造します。
コード
早速コードを見ていきましょう。
lib/nif_llvm_2.ex
defmodule NifLlvm2 do
# Binds this module to the Rust crate `:llvm`; the `*_nif` / `initialize_*`
# stubs at the bottom are the functions the crate is expected to replace.
use Rustler, otp_app: :nif_llvm_2, crate: :llvm
# Brings in the `~>` operator used by `run_code/0`.
use OK.Pipe
@moduledoc """
Elixir front end for the LLVM NIFs implemented in the `:llvm` crate.

Before any code is generated or executed, the module checks whether the
native target and its asm printer can be initialized, caches the answer in
an OS environment variable, and returns `{:error, :error}` from the public
functions when native code is not supported.
"""
# Name of the OS environment variable that caches the support check
# ("true" / "false"; unset means the check has not run yet).
@does_support_native "SYSTEM_ELIXIR_DOES_SUPPORT_NATIVE"
# Runs both native initializers and records the outcome in the environment.
# Returns `{:ok, true}` when both report `:ok`, `{:ok, false}` otherwise
# (after printing a notice).
def init() do
case {initialize_native_target(), initialize_native_asm_printer()} do
{:ok, :ok} ->
System.put_env(@does_support_native, "true")
{:ok, true}
_ ->
System.put_env(@does_support_native, "false")
IO.puts "Target platform doesn't support native code."
{:ok, false}
end
end
# Returns `true` only when the cached flag is "true". When the flag is unset,
# `init/0` is run first and the flag is read again.
def does_support_native() do
case System.get_env(@does_support_native) do
nil ->
init()
does_support_native()
"true" -> true
_ -> false
end
end
# Generates the code and pipes the result into the executor via `~>`.
# Returns `{:error, :error}` when native code is not supported.
def run_code() do
case does_support_native() do
true ->
generate_code_nif()
~> execute_code_nif()
_ ->
{:error, :error}
end
end
# Guarded wrapper around `generate_code_nif/0`.
def generate_code() do
case does_support_native() do
true ->
generate_code_nif()
_ ->
{:error, :error}
end
end
# Guarded wrapper around `execute_code_nif/1`; `code` is passed through as is.
def execute_code(code) do
case does_support_native() do
true ->
execute_code_nif(code)
_ ->
{:error, :error}
end
end
# Fallback bodies: each exits with `:nif_not_loaded` if it is ever called
# while the native implementation is not in place.
defp generate_code_nif(), do: exit(:nif_not_loaded)
defp execute_code_nif(_code), do: exit(:nif_not_loaded)
defp initialize_native_target(), do: exit(:nif_not_loaded)
defp initialize_native_asm_printer(), do: exit(:nif_not_loaded)
end
順に説明します。
# Runs both native initializers, stores "true"/"false" in the environment
# variable named by @does_support_native, and returns {:ok, true | false}.
def init() do
case {initialize_native_target(), initialize_native_asm_printer()} do
{:ok, :ok} ->
System.put_env(@does_support_native, "true")
{:ok, true}
_ ->
System.put_env(@does_support_native, "false")
IO.puts "Target platform doesn't support native code."
{:ok, false}
end
end
パターンマッチを使って,2つの関数(initialize_native_target/0
と initialize_native_asm_printer/0
)の戻り値を同時に判定しています。ともに :ok
だった時に環境変数@does_support_native
を"true"
に設定し,そうでない時には"false"
に設定してメッセージを表示します。
一応 GenServer
の init
の書き方に準じて戻り値を設定しています。
# Reads the cached flag; when it is unset, runs init/0 and reads it again.
# Only the string "true" yields true.
def does_support_native() do
case System.get_env(@does_support_native) do
nil ->
init()
does_support_native()
"true" -> true
_ -> false
end
end
環境変数@does_support_native
の値を見て,設定されていなかった時(nil
)は初期化して再実行し,"true"
だった時には true
,それ以外だった時にはfalse
を返します。
# Generates the code and pipes the result into the executor via `~>` when
# native code is supported; otherwise returns {:error, :error}.
def run_code() do
case does_support_native() do
true ->
generate_code_nif()
~> execute_code_nif()
_ ->
{:error, :error}
end
end
does_support_native/0
の戻り値を見て true
ならばコード生成して実行し,それ以外ならばエラーを返します。タプルの2つ目の:error
は何かエラーコードにした方が親切ではありますね。後ほど検討したいと思います。
generate_code/0
もexecute_code/1
も同様にラップします。
本当はこれらについては,ガード節(when
)に does_support_native/0
を使えたらさらにスッキリ書けるのですけどね。
native/llvm/src/lib.rs
#[macro_use] extern crate rustler;
// #[macro_use] extern crate rustler_codegen;
#[macro_use] extern crate lazy_static;
extern crate llvm_sys;
use rustler::{Env, Term, NifResult, Encoder};
use llvm_sys::core::*;
use llvm_sys::target;
use llvm_sys::analysis::{LLVMVerifyModule, LLVMVerifierFailureAction};
use llvm_sys::execution_engine::*;
use llvm_sys::LLVMModule;
use std::ffi::CString;
use std::os::raw::c_char;
// Atoms used when encoding NIF return values for the Elixir side.
mod atoms {
rustler_atoms! {
// `ok` and `error` are the only atoms the NIFs in this file encode.
atom ok;
atom error;
// Disabled declarations for the boolean atoms, kept for reference.
//atom __true__ = "true";
//atom __false__ = "false";
}
}
// Registers the NIFs under the Elixir module `NifLlvm2`.
// Each entry is (Elixir function name, arity, Rust function).
rustler_export_nifs! {
"Elixir.NifLlvm2",
[("generate_code_nif", 0, generate_code_nif),
("execute_code_nif", 1, execute_code_nif),
("initialize_native_target", 0, initialize_native_target),
("initialize_native_asm_printer", 0, initialize_native_asm_printer)],
// No load callback is supplied.
None
}
/*
int main() {
int a = 32;
int b = 16;
return a + b;
}
define i32 @main() #0 {
%1 = alloca i32, align 4
%a = alloca i32, align 4
%b = alloca i32, align 4
store i32 0, i32* %1
store i32 32, i32* %a, align 4
store i32 16, i32* %b, align 4
%2 = load i32, i32* %a, align 4
%3 = load i32, i32* %b, align 4
%4 = add nsw i32 %2, %3
ret i32 %4
}
*/
/// Process-wide storage for the LLVM modules produced by the NIFs.
mod llvm {
    use llvm_sys::LLVMModule;
    use std::sync::RwLock;

    lazy_static! {
        /// Registry of generated modules, starting out empty.
        ///
        /// `write_vec_mut` appends to it and `read_vec` looks entries up by
        /// position, so a module's index doubles as its id.
        pub static ref VEC_MUT: RwLock<Vec<&'static LLVMModule>> = RwLock::new(Vec::new());
    }
}
fn write_vec_mut(module: &'static LLVMModule) -> Result<usize, String> {
let mut v = try!(llvm::VEC_MUT.write().map_err(|e| e.to_string()));
v.push(module);
Ok(v.len() - 1)
}
fn read_vec(id: usize) -> Result<&'static LLVMModule, String> {
let v = try!(llvm::VEC_MUT.read().map_err(|e| e.to_string()));
Ok(v[id])
}
fn initialize_native_target<'a>(env: Env<'a>, _args: &[Term<'a>]) -> NifResult<Term<'a>> {
match unsafe { target::LLVM_InitializeNativeTarget() } {
0 => Ok(atoms::ok().encode(env)),
_ => Ok(atoms::error().encode(env)),
}
}
fn initialize_native_asm_printer<'a>(env: Env<'a>, _args: &[Term<'a>]) -> NifResult<Term<'a>> {
match unsafe { target::LLVM_InitializeNativeAsmPrinter() } {
0 => Ok(atoms::ok().encode(env)),
_ => Ok(atoms::error().encode(env)),
}
}
fn generate_code_nif<'a>(env: Env<'a>, _args: &[Term<'a>]) -> NifResult<Term<'a>> {
let llvm_error = 1;
let val1 = 32;
let val2 = 16;
// setup our builder and module
let builder = unsafe { LLVMCreateBuilder() };
let mod_name = CString::new("my_module").unwrap();
let module = unsafe { LLVMModuleCreateWithName(mod_name.as_ptr()) };
// create our function prologue
let function_type = unsafe {
let mut param_types = [];
LLVMFunctionType(LLVMInt32Type(), param_types.as_mut_ptr(), param_types.len() as u32, 0)
};
let function_name = CString::new("main").unwrap();
let function = unsafe { LLVMAddFunction(module, function_name.as_ptr(), function_type)};
let entry_name = CString::new("entry").unwrap();
let entry_block = unsafe { LLVMAppendBasicBlock(function, entry_name.as_ptr())};
unsafe { LLVMPositionBuilderAtEnd(builder, entry_block); }
// int a = 32
let a_name = CString::new("a").unwrap();
let a = unsafe { LLVMBuildAlloca(builder, LLVMInt32Type(), a_name.as_ptr())};
unsafe { LLVMBuildStore(builder, LLVMConstInt(LLVMInt32Type(), val1, 0), a); }
// int b = 16
let b_name = CString::new("b").unwrap();
let b = unsafe { LLVMBuildAlloca(builder, LLVMInt32Type(), b_name.as_ptr())};
unsafe { LLVMBuildStore(builder, LLVMConstInt(LLVMInt32Type(), val2, 0), b); }
// return a + b
let b_val_name = CString::new("b_val").unwrap();
let b_val = unsafe { LLVMBuildLoad(builder, b, b_val_name.as_ptr()) };
let a_val_name = CString::new("a_val").unwrap();
let a_val = unsafe { LLVMBuildLoad(builder, a, a_val_name.as_ptr()) };
let ab_val_name = CString::new("ab_val").unwrap();
unsafe {
let res = LLVMBuildAdd(builder, a_val, b_val, ab_val_name.as_ptr());
LLVMBuildRet(builder, res);
}
// verify it's all good
let mut error: *mut c_char = 0 as *mut c_char;
let ok = unsafe {
let buf: *mut *mut c_char = &mut error;
LLVMVerifyModule(module, LLVMVerifierFailureAction::LLVMReturnStatusAction, buf)
};
if ok == llvm_error {
let err_msg = unsafe { CString::from_raw(error).into_string().unwrap() };
panic!("cannot verify module '{:?}.\nError: {}", mod_name, err_msg);
}
// Clean up the builder now that we are finished using it.
unsafe { LLVMDisposeBuilder(builder) }
// Dump the LLVM IR to stdout so we can see what we've created
unsafe { LLVMDumpModule(module) }
match unsafe { write_vec_mut(&*module) } {
Ok(r) => Ok((atoms::ok(), r).encode(env)),
Err(_) => Ok((atoms::error(), atoms::error()).encode(env)),
}
}
fn execute_code_nif<'a>(env: Env<'a>, args: &[Term<'a>]) -> NifResult<Term<'a>> {
let id: usize = try!(args[0].decode());
match read_vec(id) {
Ok(m) => {
let module = m as *const LLVMModule as *mut LLVMModule;
let llvm_error = 1;
let val1 = 32;
let val2 = 16;
// create our exe engine
let mut error: *mut c_char = 0 as *mut c_char;
let mut engine: LLVMExecutionEngineRef = 0 as LLVMExecutionEngineRef;
let ok = unsafe {
let buf: *mut *mut c_char = &mut error;
let engine_ref: *mut LLVMExecutionEngineRef = &mut engine;
LLVMLinkInInterpreter();
LLVMCreateInterpreterForModule(engine_ref, module, buf)
};
if ok == llvm_error {
let err_msg = unsafe { CString::from_raw(error).into_string().unwrap() };
println!("Execution error: {}", err_msg);
} else {
// run the function!
let func_name = CString::new("main").unwrap();
let named_function = unsafe { LLVMGetNamedFunction(module, func_name.as_ptr()) };
let mut params = [];
let func_result = unsafe { LLVMRunFunction(engine, named_function, params.len() as u32, params.as_mut_ptr()) };
let result = unsafe { LLVMGenericValueToInt(func_result, 0) };
println!("{} + {} = {}", val1, val2, result);
}
// Clean up the module after we're done with it.
unsafe { LLVMDisposeModule(module) }
Ok(atoms::ok().encode(env))
},
Err(_) => Ok((atoms::error(), atoms::error()).encode(env)),
}
}