こちらの続編です。
再びこちらを参考にさせていただいています。
ノートブックはこちら。
git-lfsのインストール
こちらをインストールしておかないと途中でエラーになります。
%sh
apt-get install git-lfs
ライブラリのインストール
%pip install transformers pytorch-lightning==1.7 deepspeed wandb ninja rwkv
トレーニングデータの取得
こちらを使わせていただいています。
Python
# 作業用ディレクトリ
rwkv_dir = "/tmp/takaaki.yayoi@databricks.com/rwkv"
rwkv_dir_local = "/dbfs/tmp/takaaki.yayoi@databricks.com/rwkv"
# データセットを移動
dbutils.fs.mv("dbfs:/FileStore/shared_uploads/takaaki.yayoi@databricks.com/dataset.txt", rwkv_dir)
モデルのダウンロード
%sh
git lfs clone https://github.com/blinkdl/RWKV-LM
カレントディレクトリを/databricks/driver/RWKV-LM/RWKV-v4
に変更しておきます。
Python
import os
os.chdir("/databricks/driver/RWKV-LM/RWKV-v4")
ベースモデルの取得
Python
base_model_name = "RWKV-4-Pile-169M"
base_model_url = f"https://huggingface.co/BlinkDL/{base_model_name.lower()}"
print(base_model_url)
os.environ['base_model_url'] = base_model_url
%sh
# This may take a while
git clone $base_model_url
Python
from glob import glob
base_model_path = glob(f"{base_model_name.lower()}/{base_model_name}*.pth")[0]
print(f"Using {base_model_path} as base")
Using rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023.pth as base
トレーニングデータの準備
Python
import numpy as np
from transformers import PreTrainedTokenizerFast
# 事前にトークンファイルをダウンロードしてください
tokenizer = PreTrainedTokenizerFast(tokenizer_file=f'{rwkv_dir_local}/20B_tokenizer.json')
input_file = f"{rwkv_dir_local}/dataset.txt"
output_file = 'train.npy'
print(f'Tokenizing {input_file} (VERY slow. please wait)')
data_raw = open(input_file, encoding="utf-8").read()
print(f'Raw length = {len(data_raw)}')
data_code = tokenizer.encode(data_raw)
print(f'Tokenized length = {len(data_code)}')
out = np.array(data_code, dtype='uint16')
np.save(output_file, out, allow_pickle=False)
トレーニング
Python
tuned_model_name = "tuned"
output_path = "rwkv-v4-rnn-pile-tuning"
os.mkdir(output_path)
Python
#@title Training Options { display-mode: "form" }
from shutil import copy
def training_options():
EXPRESS_PILE_MODE = True
EXPRESS_PILE_MODEL_NAME = base_model_path.split(".")[0]
EXPRESS_PILE_MODEL_TYPE = base_model_name
n_epoch = 100 #@param {type:"integer"}
epoch_save_frequency = 25 #@param {type:"integer"}
batch_size = 11#@param {type:"integer"}
ctx_len = 384 #@param {type:"integer"}
epoch_save_path = f"{output_path}/{tuned_model_name}"
return locals()
def model_options():
T_MAX = 384 #@param {type:"integer"}
return locals()
def env_vars():
RWKV_FLOAT_MODE = 'fp16' #@param ['fp16', 'bf16', 'bf32'] {type:"string"}
RWKV_DEEPSPEED = '1' #@param ['0', '1'] {type:"string"}
return {f"os.environ['{key}']": value for key, value in locals().items()}
def replace_lines(file_name, to_replace):
with open(file_name, 'r') as f:
lines = f.readlines()
with open(f'{file_name}.tmp', 'w') as f:
for line in lines:
key = line.split(" =")[0]
if key.strip() in to_replace:
value = to_replace[key.strip()]
if isinstance(value, str):
f.write(f'{key} = "{value}"\n')
else:
f.write(f'{key} = {value}\n')
else:
f.write(line)
copy(f'{file_name}.tmp', file_name)
os.remove(f'{file_name}.tmp')
values = training_options()
values.update(env_vars())
replace_lines('train.py', values)
replace_lines('src/model.py', model_options())
%sh
python train.py
n_epoch過ぎても止まらないので、十分学習できたら自分で停止します。
%sh
ls rwkv-v4-rnn-pile-tuning
tuned1.pth
推論
Python
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0'
Python
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
# モデルとパイプラインの準備
model = RWKV(
model="/databricks/driver/RWKV-LM/RWKV-v4/rwkv-v4-rnn-pile-tuning/tuned1",
strategy="cuda fp16")
pipeline = PIPELINE(model, "/dbfs/tmp/takaaki.yayoi@databricks.com/rwkv/20B_tokenizer.json")
Python
# パイプライン引数の準備
args = PIPELINE_ARGS(
temperature = 1.0,
top_p = 0.7,
top_k = 100,
alpha_frequency = 0.25,
alpha_presence = 0.25,
token_ban = [],
token_stop = [0],
chunk_len = 256)
Python
# Instructプロンプトの生成
def generate_prompt(instruction, input=None):
if input:
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
# Instruction:
{instruction}
# Input:
{input}
# Response:
"""
else:
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
# Instruction:
{instruction}
# Response:
"""
Python
# プロンプトの準備
prompt = "セガサターンほしいです"
print(prompt)
# Instructプロンプトの生成
prompt = generate_prompt(prompt)
print("--[prompt]--\n" + prompt + "----")
# パイプラインの実行
result = pipeline.generate(prompt, token_count=200, args=args)
print(result)
セガサターンほしいです
--[prompt]--
Below is an instruction that describes a task. Write a response that appropriately completes the request.
# Instruction:
セガサターンほしいです
# Response:
----
セガサターンは涼宮茜ちゃんに向かうよね?(セガサターンは茜学生にしろよね~)
セガサターンは萌えを期待しています。茜学生はより萌え的ところでよっぽど peings on it’s feelings.茜学生は喫茉の声優であった。茜学生はきょっicaさんへ移動…。
1エポックしかトレーニングしてないので精度はあれですが、元データが反映されていることがわかります
続きはこちら。