30
22

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

生成AIプロダクトAdvent Calendar 2024

Day 2

自分のPCオンリーでキャラクターと音声対話がしたい!ローカルで動くspeech-to-speechサーバーを作る方法

Last updated at Posted at 2024-12-02

こんにちは!逆瀬川 ( https://x.com/gyakuse ) です!
アドベントカレンダー2日目です!これがあと23日続くのか…?
今日はこれ (の裏側部分) を作っていきます。表側は ChatVRMという pixiv さんのめちゃ神アプリケーションです。

GbAivADbQAIHJPn.jpeg

今回の記事について

今日はspeech-to-speechサーバーについて紹介します。

さいきんOpenAIのRealtime APIが公開された通り、世は音声対話システム時代です。ですが、Realtime APIは高額なので、安くできたらうれしいです。あと、自分のPCで動いたら、便利。

ということで、今日はspeech-to-speechサーバーを作り、ChatVRMと連結させて自分専用かつ自分のPCのみで動くアシスタントAIを作っていきます。なお、Realtime APIとは違い、end-to-endなモデルではなく、昔ながらのASR->LLM->TTS連結機構です。

作ったもの

今回の記事で得られる知識

  • 言語モデル、音声認識モデル、音声合成モデルのかんたんな使い方

何をするのか

ChatVRMはNextJS + OpenAI APIで構築されたキャラクターと音声対話が行えるアプリケーションです。
今回は、OpenAI API部分をそっくり消して、Realtime API的な独自実装したサーバーと連結させます。

嬉しさについて

  • Realtime APIは高いので、やすくできて便利
  • サーバー側の実装自体は 341行 なので、誰でも触って遊べる
  • 今回は工夫せずシンプル推論ですがちゃんとやると、ASR, LLM, TTS, プロンプトエンジニアリング, 等々, 非常に幅広い領域を理解できてお得 (プログラミングやりだしたばかりの人におすすめ!)

実装

それではちゃちゃっと実装していきます。

Webhookサーバー

今回の核心部分です。
めちゃくちゃコード汚くてあれですが、ここでwebhookサーバーを立ち上げ、webrtcvadを使ってVADを行っています。webrtcvadは10ms, 20ms, 30msのいずれかの大きさの音声チャンクをもとに発話かどうかを判定します。
一方で、1回だけこれが発話判定を行っても、それは誤検知の可能性を有します。
誤検知には以下のようなものがあります

  • 非発話
    • せき、くしゃみ、拍手
      • 300ms超えることがある
  • 同じデバイスで再生しているYouTubeの人間の発話
    • エコーキャンセルをしていてもちょくちょく100ms程度の長さでTrueが継続するが、すぐFalseになる
  • 違うデバイスで再生しているYouTubeの人間の発話
    • 誤って音声認識する

人の発話時間は「うん」などでは400msであるため「うん」などの相槌も検出したい場合、うまい具合にくしゃみ等をノイズキャンセルする必要があります。

ここでやっていること

  • webhookサーバーの立ち上げ
  • 音声の発話検出
  • 発話の検出が一定時間以上続いた場合、実際にユーザーが発話しているとみなす
  • 同じように一定時間以上非発話が続いた場合、発話の終了とみなす (Realtime APIなどではここのしきい値が200msになっています)
  • 発話終了後、ASRサーバーで書き起こしにし、LLMサーバーで返答を生成し、TTSサーバーで音声合成

具体的な実装は以下です。

import asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import webrtcvad
import math
import aiohttp
import logging
import json

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

app = FastAPI()

class Session:
    def __init__(self, websocket):
        self.websocket = websocket
        self.session_options = None
        self.vad = webrtcvad.Vad(3)
        self.sample_rate = 16000
        self.frame_duration = 20
        self.frame_size = int(self.sample_rate * self.frame_duration / 1000)
        self.frame_bytes = self.frame_size * 2
        self.min_speech_frames = 5
        self.buffer = bytes()
        self.audio_frames = []
        self.silence_counter = 0
        self.silence_duration = 200
        self.silence_frames = math.ceil(self.silence_duration / self.frame_duration)
        self.user_speech_started = False
        self.user_speech_to_assistant_speech_task = None  # stt (audio to text) -> llm (user text to assistant text) -> tts (assistant text to speech)
        self.user_text_to_assistant_speech_task = None  # llm (user text to assistant text) -> tts (assistant text to speech)
        self.stt_model_id = "openai/whisper-large-v3-turbo"
        self.llm_model_id = "google/gemma-2-2b-jpn-it"
        self.tts_model_id = "litagin/style_bert_vits2_jvnv"
        self.messages = []

    async def handle_text_message(self, data):
        try:
            message = json.loads(data)
            event = message.get('event')
            if event == 'startSession':
                self.session_options = message
                self.stt_model_id = message.get('sttModelId')
                self.llm_model_id = message.get('llmModelId')
                self.tts_model_id = message.get('ttsModelId')
                self.messages = message.get('messages', [])
            elif event == 'userMessage':
                message_text = message.get('message')
                # 実行中のタスクをキャンセル
                if self.user_text_to_assistant_speech_task and not self.user_text_to_assistant_speech_task.done():
                    self.user_text_to_assistant_speech_task.cancel()
                    try:
                        await self.user_text_to_assistant_speech_task
                    except asyncio.CancelledError:
                        pass
                if message_text:
                    self.user_text_to_assistant_speech_task = asyncio.create_task(self.user_message_to_assistant_speech(message_text))
        except json.JSONDecodeError:
            logger.error('JSONメッセージの解析に失敗しました')

    async def add_audio(self, audio_chunk):
        self.buffer += audio_chunk
        while len(self.buffer) >= self.frame_bytes:
            frame = self.buffer[: self.frame_bytes]
            self.buffer = self.buffer[self.frame_bytes :]
            is_speech = self.vad.is_speech(frame, self.sample_rate)
            if is_speech:
                self.audio_frames.append(frame)
                self.silence_counter = 0
                if not self.user_speech_started and len(self.audio_frames) > self.min_speech_frames:
                    self.user_speech_started = True
                    asyncio.create_task(self.websocket.send_json({"event": "userSpeechStart"}))
                    """
                    # 既存のタスクが実行中の場合、キャンセル
                    if self.user_speech_to_assistant_speech_task and not self.user_speech_to_assistant_speech_task.done():
                        self.user_speech_to_assistant_speech_task.cancel()
                        try:
                            await self.user_speech_to_assistant_speech_task
                        except asyncio.CancelledError:
                            pass
                    """
            elif self.user_speech_started:
                self.silence_counter += 1
                if self.silence_counter >= self.silence_frames:
                    logger.info(f"len(audio_frames) is {len(self.audio_frames)}")
                    audio_data = b"".join(self.audio_frames)
                    self.user_speech_to_assistant_speech_task = asyncio.create_task(self.user_speech_to_assistant_speech(audio_data))
                    asyncio.create_task(self.websocket.send_json({"event": "userSpeechEnd"}))
                    self.audio_frames = []
                    self.user_speech_started = False
            else:
                self.silence_counter = 0

    async def process_stt(self, audio_data):
        async with aiohttp.ClientSession() as session:
            async with session.post("http://localhost:8001/stt/bytes", data=audio_data) as resp:
                if resp.status == 200:
                    result = await resp.json()
                    text = result.get("text")
                    if text:
                        # 無視するフレーズの処理
                        ignore_phrases = {"ご視聴ありがとうございました", "ご視聴ありがとうございました。", "ありがとうございました", "ありがとうございました。", "はい", "", "2", ""}
                        if text not in ignore_phrases:
                            return text
                else:
                    logger.error(f"STTサーバーのリクエストがステータス{resp.status}で失敗しました")
                    return None

    async def process_llm(self, text):
        async with aiohttp.ClientSession() as session:
            send_messages = self.messages + [{"role": "user", "content": text}]
            async with session.post("http://localhost:8002/llm", json=send_messages) as resp:
                if resp.status == 200:
                    llm_response = await resp.json()
                    llm_text = llm_response["message"]
                    return llm_text
                else:
                    logger.error(f"LLMリクエストがステータス{resp.status}で失敗しました")
                    return None

    async def process_tts(self, text):
        async with aiohttp.ClientSession() as session:
            async with session.post("http://localhost:8003/tts", json={"text": text}) as resp:
                if resp.status == 200:
                    audio_bytes = await resp.read()
                    return audio_bytes
                else:
                    logger.error(f"TTSリクエストがステータス{resp.status}で失敗しました")
                    return None

    async def user_speech_to_assistant_speech(self, audio_data):
        logger.info('user_speech_to_assistant_speech')
        transcript = await self.process_stt(audio_data)
        if not transcript:
            return
        await self.websocket.send_json({"event": "userSpeechTranscript", "transcript": transcript})
        llm_text = await self.process_llm(transcript)
        if not llm_text:
            return
        await self.websocket.send_json({"event": "assistantMessageGenerated", "generatedMessageContent": llm_text})
        audio_bytes = await self.process_tts(llm_text)
        if not audio_bytes:
            return
        await self.websocket.send_json({"event": "assistantSpeechGenerated", "audioData": list(audio_bytes)})
        self.user_speech_to_assistant_speech_task = None

    async def user_message_to_assistant_speech(self, message_text):
        llm_text = await self.process_llm(message_text)
        if not llm_text:
            return
        await self.websocket.send_json({"event": "assistantMessageGenerated", "generatedMessageContent": llm_text})
        audio_bytes = await self.process_tts(llm_text)
        if not audio_bytes:
            return
        await self.websocket.send_json({"event": "assistantSpeechGenerated", "audioData": list(audio_bytes)})
        self.user_text_to_assistant_speech_task = None

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    session = Session(websocket)
    try:
        while True:
            message = await websocket.receive()
            if 'text' in message:
                data = message['text']
                await session.handle_text_message(data)
            elif 'bytes' in message:
                data = message['bytes']
                await session.add_audio(data)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocketでエラーが発生しました: {e}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8003)

ASRサーバー

whisper-v3-turboを使っています。
なお、よりリアルタイム向けの音声認識モデルもあります。

具体的な推論部分はこちらです。

model_id = "openai/whisper-large-v3-turbo"
# model_id = "openai/whisper-tiny"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
).to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

def get_generate_kwargs(language: str):
    return {
        "num_beams": 1,
        "return_timestamps": False,
        "language": language,
    }

result = pipe(audio, generate_kwargs=get_generate_kwargs(language))

全体の実装はこちらです。

from fastapi import FastAPI, UploadFile, HTTPException, Request
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import torch
import numpy as np
import librosa
import io

app = FastAPI()

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# TODO: Modelを動的に選択し、reloadするように修正
model_id = "openai/whisper-large-v3-turbo"
# model_id = "openai/whisper-tiny"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
).to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

def get_generate_kwargs(language: str):
    return {
        "num_beams": 1,
        "return_timestamps": False,
        "language": language,
    }

@app.post("/stt/file")
async def stt_from_file(file: UploadFile, language: str = "japanese"):
    try:
        contents = await file.read()
        audio, sr = librosa.load(io.BytesIO(contents), sr=16000)

        result = pipe(audio, generate_kwargs=get_generate_kwargs(language))
        return {"text": result["text"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"STT error: {e}")

@app.post("/stt/bytes")
async def stt_from_bytes(request: Request, language: str = "japanese"):
    try:
        audio_data = await request.body()
        audio = np.frombuffer(audio_data, np.int16).astype(np.float32) / 32768.0
        result = pipe(audio, generate_kwargs=get_generate_kwargs(language))
        return {"text": result["text"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"STT error: {e}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)

LLMサーバー

google/gemma-2-2b-jpn-itで返答を生成しています。
gemma-2-2b-jpn-itは神です。

model_id = "google/gemma-2-2b-jpn-it"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
)
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, return_dict=True).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
generated_text = tokenizer.batch_decode(outputs[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True)[0]

以下が全体のコードです。

from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from typing import List, Dict

app = FastAPI()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# TODO: Modelを動的に選択し、reloadするように修正
model_id = "google/gemma-2-2b-jpn-it"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
)

@app.post("/llm/stream")
async def generate_stream(messages: List[Dict[str, str]]):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(model.device)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=256)
    model.generate(**generation_kwargs)
    return streamer

@app.post("/llm")
async def generate(messages: List[Dict[str, str]], stream: bool = False):
    if stream:
        return await generate_stream(messages)
    else:
        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, return_dict=True).to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=128)
        generated_text = tokenizer.batch_decode(outputs[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True)[0]
        return {"message": generated_text}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8002)

TTSサーバー

litaginさんのstyle-bert-vits2で音声合成をしています。
音声合成部分は以下となります。


# bert model の load
bert_models.load_model(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
bert_models.load_tokenizer(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")

# TTS model の download
model_file = "jvnv-F1-jp/jvnv-F1-jp_e160_s14000.safetensors"
config_file = "jvnv-F1-jp/config.json"
style_file = "jvnv-F1-jp/style_vectors.npy"

for file in [model_file, config_file, style_file]:
    print(file)
    hf_hub_download("litagin/style_bert_vits2_jvnv", file, local_dir="models")

# TTS model の load
assets_root = Path("models")
model = TTSModel(
    model_path=assets_root / model_file,
    config_path=assets_root / config_file,
    style_vec_path=assets_root / style_file,
    device=device,
)
sr, audio = model.infer(text=request.text)

全体はこちら。

from fastapi import FastAPI, HTTPException
from style_bert_vits2.tts_model import TTSModel
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.constants import Languages
from pathlib import Path
import torch
from fastapi.responses import Response
import soundfile as sf
from io import BytesIO
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

app = FastAPI()

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# bert model の load
bert_models.load_model(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
bert_models.load_tokenizer(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")

# TTS model の download
model_file = "jvnv-F1-jp/jvnv-F1-jp_e160_s14000.safetensors"
config_file = "jvnv-F1-jp/config.json"
style_file = "jvnv-F1-jp/style_vectors.npy"

for file in [model_file, config_file, style_file]:
    print(file)
    hf_hub_download("litagin/style_bert_vits2_jvnv", file, local_dir="models")

# TTS model の load
assets_root = Path("models")
model = TTSModel(
    model_path=assets_root / model_file,
    config_path=assets_root / config_file,
    style_vec_path=assets_root / style_file,
    device=device,
)

class TTSRequest(BaseModel):
    text: str

@app.post("/tts")
async def generate_tts(request: TTSRequest):
    try:
        sr, audio = model.infer(text=request.text)

        # 音声をWAV形式でエンコード
        wav_io = BytesIO()
        sf.write(wav_io, audio, sr, format='WAV')
        wav_io.seek(0)

        # レスポンスとしてWAVファイルを返す
        return Response(content=wav_io.read(), media_type="audio/wav")

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS error: {e}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8003)

クライアント側の実装

意外といろんなことやっているので、詳細は以下

まとめ

  • speech-to-speechをシンプルにやるならわりと簡単な時代になった
  • 細かいけど大事な部分 (発話認識等) がわりとむずかしい
  • みんなも自分専用のアシスタントAIキャラクターを作っていこう
30
22
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
30
22

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?