0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LLM4Decompileを使う

Last updated at Posted at 2024-12-23

要約

  • wsl上でllm4decompileを動作させた
  • RTX 3060 12GBではllm4decompile-6.7b-v2がギリギリ動作することを確認した
  • v1.5モデルはアセンブリ→Cに、v2モデルはデコンパイラ(Ghidra)の出力した擬似コード→Cに変換する
    • v2を試した

LLM4Decompileとは

https://github.com/albertan017/LLM4Decompile/tree/main
生成AIベースのデコンパイラ。
学習にはCをO0~3でコンパイルしたelf x86_64を使用。

検証環境

  • NVIDIA RTX 3060 12GB
  • WSL Ubuntu
  • RAM 64GB

環境構築

vllmを使用します。wslでないと動作しないので注意
他の方法としてはDockerを使うなど

wsl上にcudaドライバを入れる

リポジトリのクローン

ここから先は任意のディレクトリで操作してください。
クローンしなくても動くとは思う。

git clone https://github.com/albertan017/LLM4Decompile

Pythonのセットアップ

README.mdだとcondaを使っているようだが、venvでも動いた。

python -m venv llm4dec
source llm4dec/bin/activate

依存関係をインストール

cd LLM4Decompile
pip install -r requirements.txt

Ghidraのインストール

すでにある人はスキップでも可
今回はGhidra 11.1.1を使用。
ついでにJavaも入れる。Ghidra 11.1まではJDK 17、11.2からはJDK 21が必要なので注意。

sudo apt install openjdk-17-jdk

これで大体の準備が終わった。

LLM4Decompileを使う

LLM4Decompile/ghidraにあるスクリプトを実行するとサンプルのコードを試せる。
6.7Bのモデルを指定しているのでVRAMが少ない場合はllm4decompile-1.3b-v2に変更する or google colaboratoryを使う。
あとghidra_path(特にバージョン)を修正すること。
demo.pyではLLM4Decompile/samples/sample.cを各最適化オプションでコンパイルしてGhidraをHeadlessモードで起動し、decompile.pyを実行しデコンパイル結果を一時的なファイルに保存する。

cd ghidra
python demo.py
pseudo function:
# This is the assembly code:
undefined8 func0(float param_1,long param_2,int param_3)

{
  int local_10;
  int local_c;
  
  local_10 = 0;
  do {
    local_c = local_10;
    if (param_3 <= local_10) {
      return 0;
    }
    while (local_c = local_c + 1, local_c < param_3) {
      if ((float)(DAT_001020d0 &
                 (uint)(*(float *)(param_2 + (long)local_10 * 4) -
                       *(float *)(param_2 + (long)local_c * 4))) < param_1) {
        return 1;
      }
    }
    local_10 = local_10 + 1;
  } while( true );
}
# What is the source code?

refined function:
int func0(float *x, int n, float eps)
{
    int i, j;
    for (i = 0; i < n; i++)
        for (j = i + 1; j < n; j++)
            if (fabs(x[i] - x[j]) < eps)
                return 1;
    return 0;
}

見やすいですね

特定のファイルをデコンパイルする

ここらへんは正直おまけで困ったらLLM4Decompile/ghidra/demo.pyを見ながらコードを改造すれば大体やりたいことはできると思う。

executable_pathとtarget_funcを書き換えればおk。
すでにGhidraに読み込ませてある実行ファイルを使いたい場合はHeadlessモードのコマンドを変えればおk。

import tempfile
import os
import subprocess
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Timeout value in seconds. NOTE(review): nothing in the visible script
# reads this name -- confirm before relying on it.
timeout_duration = 10

# Path to Ghidra's headless analyzer launcher; adjust the directory (and in
# particular the version number) to match your installation.
ghidra_path = "./ghidra_11.1.1_PUBLIC/support/analyzeHeadless"

# Path to the Ghidra post-script that dumps the decompiled functions.
postscript = "./decompile.py"

# Folder intended for the temporary Ghidra project. NOTE(review): the visible
# script passes its own temporary directory to Ghidra instead of this value.
project_path = "."

# Name of the throw-away Ghidra project.
project_name = "tmp_ghidra_proj"

# Base name of the prompt file; ".pseudo" is appended when it is written/read.
fileName = "dec"

# Binary to decompile and the function inside it to extract.
executable_path = "../a.out"
target_func = "main"

# Hugging Face model id (V2 models). Larger alternatives:
#   "LLM4Binary/llm4decompile-9b-v2"
#   "LLM4Binary/llm4decompile-6.7b-v2"
model_path = "LLM4Binary/llm4decompile-1.3b-v2"

# Decompile `executable_path` with Ghidra in headless mode, cut out the
# pseudo-code of `target_func`, wrap it in the model prompt and save it to
# "<fileName>.pseudo" in the current directory.
with tempfile.TemporaryDirectory() as temp_dir:
    # The post-script writes its output to this file inside the temp dir;
    # the pid keeps the name unique per process.
    pid = os.getpid()
    output_path = os.path.join(temp_dir, f"{pid}_.c")
    command = [
        ghidra_path,
        temp_dir,
        project_name,
        "-import", executable_path,
        "-postScript", postscript, output_path,
        "-deleteProject",  # WARNING: This will delete the project after analysis
    ]
    # check=True: a failing Ghidra run raises subprocess.CalledProcessError.
    subprocess.run(command, text=True, capture_output=True, check=True)
    with open(output_path, 'r') as f:
        c_decompile = f.read()

    # Collect the lines belonging to the target function. Each function in the
    # dump is introduced by a line containing "// Function:".
    header = f"Function: {target_func}"
    c_func = []
    found = False
    for line in c_decompile.split('\n'):
        # Require the name to end right after the match (end of line or
        # whitespace) so that e.g. "main" does not also select "main.cold"
        # or "main_helper".
        pos = line.find(header)
        if pos != -1:
            tail = line[pos + len(header):pos + len(header) + 1]
            if tail == "" or tail.isspace():
                found = True
                c_func.append(line)
                continue
        if found:
            # The next function header ends the target function's body.
            if '// Function:' in line and len(c_func) > 1:
                break
            c_func.append(line)
    if not found:
        raise ValueError('bad case no function found')

    # Drop the header/comment lines: keep everything from the first line
    # after the header that mentions the function name (its signature).
    start = next(
        (i for i in range(1, len(c_func)) if target_func in c_func[i]),
        None,
    )
    if start is None:
        # Previously this case raised NameError or silently kept only the
        # last captured line.
        raise ValueError(
            f'no definition line for {target_func!r} found in the decompiled output'
        )
    input_asm = '\n'.join(c_func[start:]).strip()

    # Prompt wrapper placed around the pseudo-code.
    before = "# This is the assembly code:\n"
    after = "\n# What is the source code?\n"
    input_asm_prompt = before + input_asm + after
    with open(fileName + '.pseudo', 'w', encoding='utf-8') as f:
        f.write(input_asm_prompt)



# Load the tokenizer and the model (bfloat16 weights, moved to the GPU).
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda()

# Read the prompt back with the same encoding it was written with.
with open(fileName + '.pseudo', 'r', encoding='utf-8') as f:
    asm_func = f.read()
inputs = tokenizer(asm_func, return_tensors="pt").to(model.device)
# Number of prompt tokens, taken from the tensor itself so it does not depend
# on the tokenizer being a "fast" one.
prompt_len = inputs["input_ids"].shape[1]
with torch.no_grad():
    # The original comment gives a max length of 4096, so the prompt plus
    # max_new_tokens should stay below that.
    outputs = model.generate(**inputs, max_new_tokens=2048)
# Decode only the newly generated tokens. skip_special_tokens drops the
# end-of-sequence marker when present, without cutting a real token when
# generation stopped at max_new_tokens instead.
c_func_decompile = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

# Only the refined function is printed; the prompt that was sent to the model
# is available in `asm_func`.
print(f'refined function:\n{c_func_decompile}')

検証

元のコード
これをgcc -O2でコンパイルした。

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

/* Return the next rand() value reduced modulo 100. */
int gen_rand() {
    return rand() % 100;
}

/*
 * Fill two 3000x3000 int arrays with gen_rand() values, accumulate their
 * product into a third zero-initialised array, and print the CPU time
 * (clock() ticks / CLOCKS_PER_SEC) spent in the triple loop.
 *
 * NOTE(review): the malloc/calloc results are used without a NULL check.
 */
int main() {
    /* Three flat arrays of 3000 * 3000 ints; c starts zeroed (calloc). */
    int *a = (int *)malloc(3000 * 3000 * sizeof(int));
    int *b = (int *)malloc(3000 * 3000 * sizeof(int));
    int *c = (int *)calloc(3000 * 3000, sizeof(int));
    
    /* Fixed seed, so the rand() sequence is the same on every run. */
    srand(0);
    for (int i = 0; i < 3000 * 3000; i++) {
        a[i] = gen_rand();
        b[i] = gen_rand();
    }

    clock_t start_clock, end_clock;
    start_clock = clock();
    /*
     * i is the summation index and is the OUTERMOST loop; j selects the
     * row of a/c and k the column of b/c:
     *   c[j][k] += a[j][i] * b[i][k]
     */
    for (int i = 0; i < 3000; i++) {
        for (int j = 0; j < 3000; j++) {
            for (int k = 0; k < 3000; k++) {
                c[j * 3000 + k] += a[j * 3000 + i] * b[i * 3000 + k];
            }
        }
    }
    end_clock = clock();
    printf("Clock: %f\n", (double)(end_clock - start_clock) / CLOCKS_PER_SEC);

    free(a);
    free(b);
    free(c);

    return 0;
}

Before

undefined8 main(void)

{
  ulong *puVar1;
  int *piVar2;
  int *piVar3;
  uint uVar4;
  undefined auVar5 [16];
  ulong uVar6;
  ulong uVar7;
  int iVar8;
  int iVar9;
  ulong uVar10;
  int iVar11;
  uint *__ptr;
  void *__ptr_00;
  void *__ptr_01;
  clock_t cVar12;
  long lVar13;
  clock_t cVar14;
  void *pvVar15;
  uint *puVar16;
  uint *puVar17;
  long lVar18;
  
  lVar18 = 0;
  __ptr = (uint *)malloc(36000000);
  __ptr_00 = malloc(36000000);
  __ptr_01 = calloc(9000000,4);
  srand(0);
  do {
    iVar11 = rand();
    *(int *)((long)__ptr + lVar18) = iVar11 % 100;
    iVar11 = rand();
    *(int *)((long)__ptr_00 + lVar18) = iVar11 % 100;
    lVar18 = lVar18 + 4;
  } while (lVar18 != 36000000);
  cVar12 = clock();
  lVar18 = 0;
  puVar17 = __ptr;
  do {
    pvVar15 = __ptr_01;
    puVar16 = puVar17;
    do {
      uVar4 = *puVar16;
      lVar13 = 0;
      auVar5._4_4_ = uVar4;
      auVar5._0_4_ = uVar4;
      auVar5._8_4_ = uVar4;
      auVar5._12_4_ = uVar4;
      do {
        puVar1 = (ulong *)((long)__ptr_00 + lVar13 + lVar18 * 12000);
        uVar6 = *puVar1;
        uVar7 = puVar1[1];
        puVar1 = (ulong *)((long)__ptr_00 + lVar13 + lVar18 * 12000);
        uVar10 = puVar1[1];
        piVar2 = (int *)((long)pvVar15 + lVar13);
        iVar11 = piVar2[1];
        iVar8 = piVar2[2];
        iVar9 = piVar2[3];
        piVar3 = (int *)((long)pvVar15 + lVar13);
        *piVar3 = *piVar2 + (int)((*puVar1 & 0xffffffff) * (ulong)uVar4);
        piVar3[1] = iVar11 + (int)((uVar6 >> 0x20) * (ulong)uVar4);
        piVar3[2] = iVar8 + (int)((uVar10 & 0xffffffff) * (ulong)uVar4);
        piVar3[3] = iVar9 + (int)((uVar7 >> 0x20) * (auVar5._8_8_ >> 0x20));
        lVar13 = lVar13 + 0x10;
      } while (lVar13 != 12000);
      puVar16 = puVar16 + 3000;
      pvVar15 = (void *)((long)pvVar15 + 12000);
    } while (puVar16 != puVar17 + 9000000);
    lVar18 = lVar18 + 1;
    puVar17 = puVar17 + 1;
  } while (lVar18 != 3000);
  cVar14 = clock();
  __printf_chk((double)(cVar14 - cVar12) / _DAT_00102010,1,"Clock: %f\n");
  free(__ptr);
  free(__ptr_00);
  free(__ptr_01);
  return 0;
}

After

int main(int argc, char *argv[])
{
    int i, j, k;
    int *a = malloc(3000 * 3000 * sizeof(int));
    int *b = malloc(3000 * 3000 * sizeof(int));
    int *c = calloc(3000 * 3000, sizeof(int));

    srand(0);

    for (i = 0; i < 3000 * 3000; i++) {
        a[i] = rand() % 100;
        b[i] = rand() % 100;
    }

    clock_t start = clock();

    for (i = 0; i < 3000; i++) {
        for (j = 0; j < 3000; j++) {
            for (k = 0; k < 3000; k++) {
                c[i * 3000 + j] += a[i * 3000 + k] * b[k * 3000 + j];
            }
        }
    }

    __printf_chk((double)(clock() - start) / CLOCKS_PER_SEC, 1, "Clock: %f\n");

    free(a);
    free(b);
    free(c);

    return 0;
}

大体あってそうに見えるけど、ループ内の配列のインデックス計算(アクセスするアドレス)が元コードとは若干違った。
でもすごい!

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?