はじめに
HuggingFaceに公開された日本語LLMを使ってみて、ChatBotを作りたいな〜と思ったらそれなりに簡単にできたので、まとめてみます。
ダウンロード
以下のURLに詳細が記載されています。
以下のコマンドでインストールします。
git lfs install
git clone https://huggingface.co/cyberagent/calm2-7b-chat
クローンを実行したディレクトリにcalm2-7b-chatというディレクトリが作成されます。
動作環境
- OS Ubuntu 22.04
- CPU Intel® Xeon(R) CPU E5-1680 v3 @ 3.20GHz × 16
- RAM 120GB
- GPU NVIDIA Corporation GP102GL [Quadro P6000]
試しに動かしてみる
huggingfaceのページにサンプルコードがあるので、それを実行してみます。
sample.py
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
assert transformers.__version__ >= "4.34.1"
model = AutoModelForCausalLM.from_pretrained("./calm2-7b-chat", device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("./calm2-7b-chat")
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
prompt = """USER: AIによって私達の暮らしはどのように変わりますか?
ASSISTANT: """
token_ids = tokenizer.encode(prompt, return_tensors="pt")
output_ids = model.generate(
input_ids=token_ids.to(model.device),
max_new_tokens=300,
do_sample=True,
temperature=0.8,
streamer=streamer,
)
以下のコマンドで実行。
python3 sample.py
Flaskを使ってWebアプリケーションにする
必要なライブラリをインストール
sudo apt install -y python3-flask python3-flask-socketio
Flaskアプリケーションを作成。
app.py
from flask import Flask, render_template
from flask_socketio import SocketIO
from threading import Thread
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, TextIteratorStreamer
assert transformers.__version__ >= "4.34.1"
model = AutoModelForCausalLM.from_pretrained(
"./calm2-7b-chat",
device_map="auto",
torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained("./calm2-7b-chat")
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
app = Flask(__name__)
app.config['SECRET_KEY'] = '*****'
socketio = SocketIO(app)
@app.route('/')
def index():
return render_template('index.html')
@socketio.on('message')
def handle_message(message):
prompt = "USER: {}\nASSISTANT: ".format(message)
print('received message: ' + message)
socketio.emit('message', "<div>{}</div><div>".format(message))
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
generated_text = ""
token_ids = tokenizer.encode(prompt, return_tensors="pt")
generation_kwargs = dict(
input_ids=token_ids.to(model.device),
max_new_tokens=300,
do_sample=True,
temperature=0.8,
streamer=streamer,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
generated_text += new_text
socketio.emit('message', new_text.replace("\n","<br />"))
socketio.emit('message', "</div>")
if __name__ == '__main__':
socketio.run(app, host='0.0.0.0')
表示用テンプレートファイルを作成。
templates/index.html
<!DOCTYPE html>
<html>
<head>
<title>Chat App</title>
<script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
<script src="https://cdn.socket.io/4.7.2/socket.io.min.js" integrity="sha384-mZLF4UVrpi/QTWPA7BjNPEnkIfRFn4ZEO3Qt/HFklTJBj/gBOV8G3HcKn4NfQblz" crossorigin="anonymous"></script>
<script type="text/javascript">
$(document).ready(function() {
var socket = io.connect('http://' + document.domain + ':' + location.port);
socket.on('message', function(msg) {
$('#messages').append(msg);
});
$('#sendbutton').on('click', function() {
var message = $('#myMessage').val();
socket.emit('message', message);
$('#myMessage').val('');
});
});
</script>
</head>
<body>
<ul id="messages"></ul>
<input type="text" id="myMessage">
<button id="sendbutton">Send</button>
</body>
</html>
以下のコマンドで実行。
python3 app.py
実行したら「http://localhost:5000/」にアクセス。
できた!
楽しい!