要約
- wsl上でllm4decompileを動作させた
- RTX 3060 12GBで動作確認できたのはllm4decompile-6.7b-v2がギリギリ動作
- v1.5モデルではアセンブリ→cに v2モデルではデコンパイラ→cに変換する
- v2を試した
LLM4Decompileとは
https://github.com/albertan017/LLM4Decompile/tree/main
生成AIベースのデコンパイラ。
学習にはCをO0~3でコンパイルしたelf x86_64を使用。
検証環境
- NVIDIA RTX 3060 12GB
- WSL Ubuntu
- RAM 64GB
環境構築
vllmを使用します。wslでないと動作しないので注意
他の方法としてはDockerを使うなど
wsl上にcudaドライバを入れる
リポジトリのクローン
ここから先は任意のディレクトリで操作してください。
クローンしなくても動くとは思う。
git clone https://github.com/albertan017/LLM4Decompile
Pythonのセットアップ
README.mdだとcondaを使っているようだが、venvでも動いた。
python -m venv llm4dec
source llm4dec/bin/activate
依存関係をインストール
cd LLM4Decompile
pip install -r requirements.txt
Ghidraのインストール
すでにある人はスキップでも化
今回はGhidra 11.1.1を使用。
ついでにjavaも入れる。11.1までjdk17。11.2からはjdk21なので注意。
sudo apt install openjdk-17-jdk
これで大体の準備が終わった。
LLM4Decompileを使う
LLM4Decompile/ghidraにあるスクリプトを実行するとサンプルのコードを試せる。
6.7Bのモデルを指定しているのでVRAMが少ない場合はllm4decompile-1.3b-v2に変更する or google colaboratoryを使う。
あとghidra_path(特にバージョン)を修正すること。
demo.pyではLLM4Decompile/samples/sample.cを各最適化オプションでコンパイルしてGhidraをHeadlessモードで起動し、decompile.pyを実行しデコンパイル結果を一時的なファイルに保存する。
cd ghidra
python demo.py
pseudo function:
# This is the assembly code:
undefined8 func0(float param_1,long param_2,int param_3)
{
int local_10;
int local_c;
local_10 = 0;
do {
local_c = local_10;
if (param_3 <= local_10) {
return 0;
}
while (local_c = local_c + 1, local_c < param_3) {
if ((float)(DAT_001020d0 &
(uint)(*(float *)(param_2 + (long)local_10 * 4) -
*(float *)(param_2 + (long)local_c * 4))) < param_1) {
return 1;
}
}
local_10 = local_10 + 1;
} while( true );
}
# What is the source code?
refined function:
int func0(float *x, int n, float eps)
{
int i, j;
for (i = 0; i < n; i++)
for (j = i + 1; j < n; j++)
if (fabs(x[i] - x[j]) < eps)
return 1;
return 0;
}
見やすいですね
特定のファイルをデコンパイルする
ここらへんは正直おまけで困ったらLLM4Decompile/ghidra/demo.pyを見ながらコードを改造すれば大体やりたいことはできると思う。
executable_pathとtarget_funcを書き換えればおk。
すでにGhidraに読み込ませてある実行ファイルを使いたい場合はHeadlessモードのコマンドを変えればおk。
import tempfile
import os
import subprocess
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
timeout_duration = 10
ghidra_path = "./ghidra_11.1.1_PUBLIC/support/analyzeHeadless"#path to the headless analyzer, change the path accordingly
postscript = "./decompile.py"#path to the decompiler helper function, change the path accordingly
project_path = "."#path to temp folder for analysis, change the path accordingly
project_name = "tmp_ghidra_proj"
fileName = "dec"
executable_path = "../a.out"
target_func = "main"
# model_path = 'LLM4Binary/llm4decompile-9b-v2' # V2 Model
# model_path = 'LLM4Binary/llm4decompile-6.7b-v2' # V2 Model
model_path = 'LLM4Binary/llm4decompile-1.3b-v2' # V2 Model
with tempfile.TemporaryDirectory() as temp_dir:
pid = os.getpid()
asm_all = {}
output_path = os.path.join(temp_dir, f"{pid}_.c")
command = [
ghidra_path,
temp_dir,
project_name,
"-import", executable_path,
"-postScript", postscript, output_path,
"-deleteProject", # WARNING: This will delete the project after analysis
]
result = subprocess.run(command, text=True, capture_output=True, check=True)
with open(output_path,'r') as f:
c_decompile = f.read()
c_func = []
flag = 0
for line in c_decompile.split('\n'):
if f"Function: {target_func}" in line:#**Replace** func0 with the function name you want to decompile.
flag = 1
c_func.append(line)
continue
if flag:
if '// Function:' in line:
if len(c_func) > 1:
break
c_func.append(line)
if flag == 0:
raise ValueError('bad case no function found')
for idx_tmp in range(1,len(c_func)):##########remove the comments
if target_func in c_func[idx_tmp]:
break
c_func = c_func[idx_tmp:]
input_asm = '\n'.join(c_func).strip()
before = f"# This is the assembly code:\n"#prompt
after = "\n# What is the source code?\n"#prompt
input_asm_prompt = before+input_asm.strip()+after
with open(fileName + '.pseudo','w',encoding='utf-8') as f:
f.write(input_asm_prompt)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda()
with open(fileName + '.pseudo','r') as f:#optimization level O0
asm_func = f.read()
inputs = tokenizer(asm_func, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=2048)### max length to 4096, max new tokens should be below the range
c_func_decompile = tokenizer.decode(outputs[0][len(inputs[0]):-1])
# with open(fileName + '.pseudo','r') as f:#original file
# func = f.read()
# print(f'pseudo function:\n{func}')# Note we only decompile one function, where the original file may contain multiple functions
print(f'refined function:\n{c_func_decompile}')
検証
元のコード
これをgcc -O2でコンパイルした。
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
int gen_rand() {
return rand() % 100;
}
int main() {
int *a = (int *)malloc(3000 * 3000 * sizeof(int));
int *b = (int *)malloc(3000 * 3000 * sizeof(int));
int *c = (int *)calloc(3000 * 3000, sizeof(int));
srand(0);
for (int i = 0; i < 3000 * 3000; i++) {
a[i] = gen_rand();
b[i] = gen_rand();
}
clock_t start_clock, end_clock;
start_clock = clock();
for (int i = 0; i < 3000; i++) {
for (int j = 0; j < 3000; j++) {
for (int k = 0; k < 3000; k++) {
c[j * 3000 + k] += a[j * 3000 + i] * b[i * 3000 + k];
}
}
}
end_clock = clock();
printf("Clock: %f\n", (double)(end_clock - start_clock) / CLOCKS_PER_SEC);
free(a);
free(b);
free(c);
return 0;
}
Before
undefined8 main(void)
{
ulong *puVar1;
int *piVar2;
int *piVar3;
uint uVar4;
undefined auVar5 [16];
ulong uVar6;
ulong uVar7;
int iVar8;
int iVar9;
ulong uVar10;
int iVar11;
uint *__ptr;
void *__ptr_00;
void *__ptr_01;
clock_t cVar12;
long lVar13;
clock_t cVar14;
void *pvVar15;
uint *puVar16;
uint *puVar17;
long lVar18;
lVar18 = 0;
__ptr = (uint *)malloc(36000000);
__ptr_00 = malloc(36000000);
__ptr_01 = calloc(9000000,4);
srand(0);
do {
iVar11 = rand();
*(int *)((long)__ptr + lVar18) = iVar11 % 100;
iVar11 = rand();
*(int *)((long)__ptr_00 + lVar18) = iVar11 % 100;
lVar18 = lVar18 + 4;
} while (lVar18 != 36000000);
cVar12 = clock();
lVar18 = 0;
puVar17 = __ptr;
do {
pvVar15 = __ptr_01;
puVar16 = puVar17;
do {
uVar4 = *puVar16;
lVar13 = 0;
auVar5._4_4_ = uVar4;
auVar5._0_4_ = uVar4;
auVar5._8_4_ = uVar4;
auVar5._12_4_ = uVar4;
do {
puVar1 = (ulong *)((long)__ptr_00 + lVar13 + lVar18 * 12000);
uVar6 = *puVar1;
uVar7 = puVar1[1];
puVar1 = (ulong *)((long)__ptr_00 + lVar13 + lVar18 * 12000);
uVar10 = puVar1[1];
piVar2 = (int *)((long)pvVar15 + lVar13);
iVar11 = piVar2[1];
iVar8 = piVar2[2];
iVar9 = piVar2[3];
piVar3 = (int *)((long)pvVar15 + lVar13);
*piVar3 = *piVar2 + (int)((*puVar1 & 0xffffffff) * (ulong)uVar4);
piVar3[1] = iVar11 + (int)((uVar6 >> 0x20) * (ulong)uVar4);
piVar3[2] = iVar8 + (int)((uVar10 & 0xffffffff) * (ulong)uVar4);
piVar3[3] = iVar9 + (int)((uVar7 >> 0x20) * (auVar5._8_8_ >> 0x20));
lVar13 = lVar13 + 0x10;
} while (lVar13 != 12000);
puVar16 = puVar16 + 3000;
pvVar15 = (void *)((long)pvVar15 + 12000);
} while (puVar16 != puVar17 + 9000000);
lVar18 = lVar18 + 1;
puVar17 = puVar17 + 1;
} while (lVar18 != 3000);
cVar14 = clock();
__printf_chk((double)(cVar14 - cVar12) / _DAT_00102010,1,"Clock: %f\n");
free(__ptr);
free(__ptr_00);
free(__ptr_01);
return 0;
}
After
int main(int argc, char *argv[])
{
int i, j, k;
int *a = malloc(3000 * 3000 * sizeof(int));
int *b = malloc(3000 * 3000 * sizeof(int));
int *c = calloc(3000 * 3000, sizeof(int));
srand(0);
for (i = 0; i < 3000 * 3000; i++) {
a[i] = rand() % 100;
b[i] = rand() % 100;
}
clock_t start = clock();
for (i = 0; i < 3000; i++) {
for (j = 0; j < 3000; j++) {
for (k = 0; k < 3000; k++) {
c[i * 3000 + j] += a[i * 3000 + k] * b[k * 3000 + j];
}
}
}
__printf_chk((double)(clock() - start) / CLOCKS_PER_SEC, 1, "Clock: %f\n");
free(a);
free(b);
free(c);
return 0;
}
大体あってそうに見えるけどループ内の配列のアドレスが元コードとは若干違った。
でもすごい!