7
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

DatabricksでRWKVのファインチューニングを試す

Last updated at Posted at 2023-04-15

こちらの続編です。

再びこちらを参考にさせていただいています。

ノートブックはこちら。

git-lfsのインストール

こちらをインストールしておかないと途中でエラーになります。

%sh
apt-get install git-lfs

ライブラリのインストール

%pip install transformers pytorch-lightning==1.7 deepspeed wandb ninja rwkv

トレーニングデータの取得

こちらを使わせていただいています。

Python
# 作業用ディレクトリ
rwkv_dir = "/tmp/takaaki.yayoi@databricks.com/rwkv"
rwkv_dir_local = "/dbfs/tmp/takaaki.yayoi@databricks.com/rwkv"

# データセットを移動
dbutils.fs.mv("dbfs:/FileStore/shared_uploads/takaaki.yayoi@databricks.com/dataset.txt", rwkv_dir)

モデルのダウンロード

%sh
git lfs clone https://github.com/blinkdl/RWKV-LM

カレントディレクトリを/databricks/driver/RWKV-LM/RWKV-v4に変更しておきます。

Python
import os
os.chdir("/databricks/driver/RWKV-LM/RWKV-v4")

ベースモデルの取得

Python
base_model_name = "RWKV-4-Pile-169M"
base_model_url = f"https://huggingface.co/BlinkDL/{base_model_name.lower()}"

print(base_model_url)

os.environ['base_model_url'] = base_model_url
%sh
# This may take a while
git clone $base_model_url
Python
from glob import glob
base_model_path = glob(f"{base_model_name.lower()}/{base_model_name}*.pth")[0]

print(f"Using {base_model_path} as base")
Using rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023.pth as base

トレーニングデータの準備

Python
import numpy as np
from transformers import PreTrainedTokenizerFast

# 事前にトークンファイルをダウンロードしてください
tokenizer = PreTrainedTokenizerFast(tokenizer_file=f'{rwkv_dir_local}/20B_tokenizer.json')

input_file = f"{rwkv_dir_local}/dataset.txt"
output_file = 'train.npy'

print(f'Tokenizing {input_file} (VERY slow. please wait)')

data_raw = open(input_file, encoding="utf-8").read()
print(f'Raw length = {len(data_raw)}')

data_code = tokenizer.encode(data_raw)
print(f'Tokenized length = {len(data_code)}')

out = np.array(data_code, dtype='uint16')
np.save(output_file, out, allow_pickle=False)

トレーニング

Python
tuned_model_name = "tuned"
output_path = "rwkv-v4-rnn-pile-tuning"

os.mkdir(output_path)
Python
#@title Training Options { display-mode: "form" }
from shutil import copy

def training_options():
    EXPRESS_PILE_MODE = True
    EXPRESS_PILE_MODEL_NAME = base_model_path.split(".")[0]
    EXPRESS_PILE_MODEL_TYPE = base_model_name
    n_epoch = 100 #@param {type:"integer"}
    epoch_save_frequency = 25 #@param {type:"integer"}
    batch_size =  11#@param {type:"integer"} 
    ctx_len = 384 #@param {type:"integer"}
    epoch_save_path = f"{output_path}/{tuned_model_name}"
    return locals()

def model_options():
    T_MAX = 384 #@param {type:"integer"}
    return locals()

def env_vars():
    RWKV_FLOAT_MODE = 'fp16' #@param ['fp16', 'bf16', 'bf32'] {type:"string"}
    RWKV_DEEPSPEED = '1' #@param ['0', '1'] {type:"string"}
    return {f"os.environ['{key}']": value for key, value in locals().items()}

def replace_lines(file_name, to_replace):
    with open(file_name, 'r') as f:
        lines = f.readlines()
    with open(f'{file_name}.tmp', 'w') as f:
        for line in lines:
            key = line.split(" =")[0]
            if key.strip() in to_replace:
                value = to_replace[key.strip()]
                if isinstance(value, str):
                    f.write(f'{key} = "{value}"\n')
                else:
                    f.write(f'{key} = {value}\n')
            else:
                f.write(line)
    copy(f'{file_name}.tmp', file_name)
    os.remove(f'{file_name}.tmp')

values = training_options()
values.update(env_vars())
replace_lines('train.py', values)
replace_lines('src/model.py', model_options())
%sh
python train.py

n_epoch過ぎても止まらないので、十分学習できたら自分で停止します。

%sh
ls rwkv-v4-rnn-pile-tuning
tuned1.pth

推論

Python
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0'
Python
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# モデルとパイプラインの準備
model = RWKV(
    model="/databricks/driver/RWKV-LM/RWKV-v4/rwkv-v4-rnn-pile-tuning/tuned1", 
    strategy="cuda fp16")
pipeline = PIPELINE(model, "/dbfs/tmp/takaaki.yayoi@databricks.com/rwkv/20B_tokenizer.json")
Python
# パイプライン引数の準備
args = PIPELINE_ARGS(
    temperature = 1.0,
    top_p = 0.7, 
    top_k = 100, 
    alpha_frequency = 0.25, 
    alpha_presence = 0.25, 
    token_ban = [],
    token_stop = [0],
    chunk_len = 256) 
Python
# Instructプロンプトの生成
def generate_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

# Instruction:
{instruction}

# Input:
{input}

# Response:
"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

# Instruction:
{instruction}

# Response:
"""
Python
# プロンプトの準備
prompt = "セガサターンほしいです"
print(prompt)

# Instructプロンプトの生成
prompt = generate_prompt(prompt)
print("--[prompt]--\n" + prompt + "----")

# パイプラインの実行
result = pipeline.generate(prompt, token_count=200, args=args)
print(result)
セガサターンほしいです
--[prompt]--
Below is an instruction that describes a task. Write a response that appropriately completes the request.

# Instruction:
セガサターンほしいです

# Response:
----
セガサターンは涼宮茜ちゃんに向かうよね?(セガサターンは茜学生にしろよね~)

セガサターンは萌えを期待しています。茜学生はより萌え的ところでよっぽど peings on it’s feelings.茜学生は喫茉の声優であった。茜学生はきょっicaさんへ移動…。

1エポックしかトレーニングしてないので精度はあれですが、元データが反映されていることがわかります

続きはこちら。

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

7
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?