すでにトレーニングされている方がいらっしゃいました。
1.3Bパラメータの日本語GPTモデルを使用した指示に応答するAIです。VRAM 7GB または RAM 7GB が必要で、問題なく動作すると思われます。
rinna社の「japanese-gpt-1b」を、日本語データセット「databricks-dolly-15k-ja」を使用して学習させました。
sentencepieceをインストールします。
%pip install sentencepiece
Python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("inu-ai/dolly-japanese-gpt-1b", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("inu-ai/dolly-japanese-gpt-1b").to(device)
Python
MAX_ASSISTANT_LENGTH = 100
MAX_INPUT_LENGTH = 1024
INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n'
NO_INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n応答:\n'
def prepare_input(instruction, input_text):
if input_text != "":
prompt = INPUT_PROMPT.format(instruction=instruction, input=input_text)
else:
prompt = NO_INPUT_PROMPT.format(instruction=instruction)
return prompt
def format_output(output):
output = output.lstrip("<s>").rstrip("</s>").replace("[SEP]", "").replace("\\n", "\n")
return output
def generate_response(instruction, input_text):
prompt = prepare_input(instruction, input_text)
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
n = len(token_ids[0])
with torch.no_grad():
output_ids = model.generate(
token_ids.to(model.device),
min_length=n,
max_length=min(MAX_INPUT_LENGTH, n + MAX_ASSISTANT_LENGTH),
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
bad_words_ids=[[tokenizer.unk_token_id]]
)
output = tokenizer.decode(output_ids.tolist()[0])
formatted_output_all = format_output(output)
response = f"Assistant:{formatted_output_all.split('応答:')[-1].strip()}"
return formatted_output_all, response
instruction = "あなたは何でも正確に答えられるAIです。"
questions = [
"日本で一番高い山は?",
"日本で一番広い湖は?",
"世界で一番高い山は?",
"世界で一番広い湖は?",
"冗談を言ってください。",
]
# 各質問に対して応答を生成して表示
for question in questions:
formatted_output_all, response = generate_response(instruction, question)
print(response)
Assistant:富士山
Assistant:琵琶湖
Assistant:エベレストは、標高8,848メートルです。ネパールにある。
Assistant:バイカル湖は、世界で一番広い湖です。
Assistant:ええ、そうですね、私は何でも正確に答えられるAIです。
おおー、すごい、きちんと返ってきます。
Python
formatted_output_all, response = generate_response(instruction, "人工知能とは")
print(response)
Assistant:AIは、人間よりも正確で速いデータ分析を行い、より多くの問題を解決し、多くのタスクを自動化することができます。
Python
formatted_output_all, response = generate_response(instruction, "人工知能のメリットを教えてください")
print(response)
Assistant:AIは、人間よりも正確で速いデータ分析を行い、より多くの問題を解決し、多くのタスクを自動化することができます。
Python
formatted_output_all, response = generate_response(instruction, "人工知能のリスクを教えてください")
print(response)
Assistant:AIは、人間の能力を超えるかもしれない、と広く考えられています。AIは、人間の能力を超えると考えられています。しかし、AIが人間の能力を超えるかどうかは、人間が決定することではありません。現在のところ、AIは人間の能力を超えているのでしょうか?
オープンソース化することのメリットをしみじみと感じた週末でした。