前回の続き。前回はCyberAgent M2-7B-Chatを動かしただけで終わってしまったので、今回は以下のような機能を追加して遊んでみる。
- CyberAgent M2-7B-Chatからの応答をストリームの形で返すサーバーを作る。(LLMサーバーと呼ぶ)
- 上記のサーバーとのやり取りを仲介して、一連のコンテキストとして扱えるサーバーを作る。(Chatサーバーと呼ぶ)
- ChatサーバーにLLMサーバーからの出力をVOICEVOXに送信して喋らせる。
VOICEVOX ENGINEの導入
Ubuntu(WSL2)にどうやって導入しようかと思っていたら、Dockerイメージが用意されているのでそれを使う。
なおVOICEVOX ENGINEはFast APIで動いているため、/docsにアクセスすることでAPIの使い方を調べたり、実際にAPIにクエリやJSONを投げて動作を確かめたりできる。
LLMサーバーの実装
Fast APIが手軽に使えそうだったので、こちらもFast APIで実装してみる。まず。元のコードでは、出力に使うstreamerにTextStreamerを利用していた。これは標準出力に出力するStreamerなのだが、ChatGPTのように文字列をストリーミングで返すにはTextIteratorStreamerを使うようだ。
また、FastAPIでこのTextIteratorStreamerのようなものの結果をストリームして返すには、StreamingResponseを用いればよいようだ。
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import asyncio
from threading import Thread
from typing import AsyncIterator
from pydantic import BaseModel
app = FastAPI()
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
assert transformers.__version__ >= "4.34.1"
class ChatInput(BaseModel):
message: str
model = AutoModelForCausalLM.from_pretrained("cyberagent/calm2-7b-chat", device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm2-7b-chat")
async def generate(prompt: str):
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
token_ids = tokenizer.encode(prompt, return_tensors="pt")
thread = Thread(target=model.generate, kwargs=dict(
input_ids=token_ids.to(model.device),
max_new_tokens=300,
do_sample=True,
temperature=0.8,
streamer=streamer,
))
thread.start()
for output in streamer:
if not output:
continue
await asyncio.sleep(0)
yield output
@app.post("/chat")
async def chat(message: ChatInput):
return StreamingResponse(generate(message.message), media_type="text/plain")
FastAPIのドキュメントではuvicornで起動していたので、それに従って起動する。
uvicorn llmserver:app --reload
試しにcurlからリクエストを送信してみる。
curl -X POST \
-H "Content-Type: application/json" \
-d "{ \"message\": \"USER: Generative AIについて簡潔に説明してください。\" }" \
http://localhost:8000/chat
ASSISTANT: 生成的AIは、従来のように人間が問題を定義し、コンピューターが答えを見つけるのではなく、コンピューターがデータから自動的に新しい知識を創り出す技術です。
生成的AIは、データから学習し、パターンを見つけ、新しい知識を生成することで、自然言語生成、画像生成、音楽生成、データマイニングなどの分野で広く利用されています。
生成的AIは、人間が介入することなく自動的に生成されるため、人間が思いつかないような新しいアイデアや解決策を生み出すことができます。また、大量のデータを扱うことができるため、データマイニングやWebスクレイピングなどの分野での活用が期待されています。
実際に動かしてみると、徐々にレスポンスが返ってきているのが判る。
Chatサーバーの実装
メッセージにIDを持たせて対話が続くようにする事もできるが、サンプルとして載せるには長すぎるのでその辺は割愛。単純にLLMサーバーからストリーミングされたデータをクライアントにストリームする一方で、VOICEVOXに送って喋らせているだけの実装。
import requests
import json
import simpleaudio as sa
import numpy as np
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import asyncio
from threading import Thread
from typing import AsyncIterator
from pydantic import BaseModel
app = FastAPI()
class ChatInput(BaseModel):
message: str
def apply_fade_in(audio_data, sample_rate, fade_duration=0.1):
fade_in_samples = int(sample_rate * fade_duration)
fade_in = np.linspace(0, 1, fade_in_samples)
audio_length = len(audio_data)
if fade_in_samples > audio_length:
raise ValueError("Fade duration is longer than the audio length.")
audio_data[:fade_in_samples] *= fade_in
return audio_data.astype(np.int16)
def speak(text: str):
res = requests.post(
'http://127.0.0.1:50021/audio_query',
params={'text': text, 'style_id': 1})
query = res.json()
query['speedScale'] = 1.3
res = requests.post(
'http://127.0.0.1:50021/synthesis',
params={'style_id': 1},
data=json.dumps(query))
audio_data = res.content
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float64).copy()
audio = apply_fade_in(audio_array, 24000)
play_obj = sa.play_buffer(audio, 1, 2, 24000)
play_obj.wait_done()
return res
async def llm(message: str):
try:
res = requests.post(
'http://127.0.0.1:8000/chat',
stream=True,
json={ "message": message})
for line in res.iter_lines():
msg = line.decode('utf-8')
yield msg
except asyncio.CancelledError:
print("caught cancelled error")
async def chat_main(input: str):
async for msg in llm(input):
yield msg + "\n"
speak(msg)
yield "\n"
@app.post("/chat")
async def chat(input: ChatInput):
return StreamingResponse(chat_main(input.message))
ちょっと長くなってしまった。本筋とはズレるが、WSL2上でsimpleaudioライブラリを利用して音声を再生しようとしたところ、クリックノイズが必ず入るようになってしまった。解決策はないか調べたが、どうも頭に極々短いフェードインを入れることで回避するのが簡単なようだったので、その方法を取った。
なんかFast API超便利みたいな内容になってしまったが、そこは気にしない。