はじめに
ChatGPT盛り上がっていますね。
本記事では、MicroSoft社が公開している 「ChatGPT と複数のVisual Foundation Models(大量かつ多様なデータで訓練され多様な下流タスクに適応できるモデル)を組み合わせ様々な画像生成を対話形式で簡単に行える!Visual ChatGPT 」の実装コードを著者の勉強も兼ねて整理してみました。
ReACTを使ってどんどん機能を拡張できそうですし、コード自体は比較的シンプルでこれからChatGPTを使ったアプリを作流のにも参考になりそうです。
コードは2023年4月1日時点でGitUHBに公開されているものを使用しました。
Visual ChatGPT
まだ勉強不足のところもあり、認識違いや誤りがあれば本記事を順次ブラッシュアップしていきますのでご指摘があればお願いします。
1. 実装されている機能
2023年4月1日で実装されているモデル(Visual Foundation Models)は以下の通りです。
クラス | メソッド | 説明 |
---|---|---|
ImageEditing | inference_remove() | 写真の中から指示した何かを取り除く |
inference_replace() | 写真の中から指示した何かを置き換える | |
InstructPix2Pix | inference() | テキストを使用して画像スタイルを変更する |
Text2Image | inference() | ユーザーが入力したテキストから画像を生成する |
ImageCaptioning | inference() | 画像キャプション(画像の説明)を行う |
Image2Canny | inference() | 画像のCannyエッジ(輪郭線)検出する |
CannyText2Image | inference() | Cannyエッジ(輪郭線)画像とテキストから画像を生成する |
Image2Line | inference() | 画像上のライン検出する |
LineText2Image | inference() | ライン画像とテキストから画像を生成する |
Image2Hed | inference() | 画像上のHEDを検出する |
HedText2Image | inference() | ソフトHED境界の画像から画像を生成する |
Image2Scribble | inference() | 画像からスケッチを生成する |
ScribbleText2Image | inference() | スケッチ画像から画像を生成する |
Image2Pose | inference() | 画像からポーズを検出する |
PoseText2Image | inference() | ポーズ画像とテキストから画像を生成する |
Image2Seg | inference() | 画像のセグメンテーションを行う |
SegText2Image | inference() | セグメンテーション画像とテキストから画像を生成する |
Image2Depth | inference() | 画像の奥行きを予測する |
DepthText2Image | inference() | 奥行画像とテキストから画像を生成する |
Image2Normal | inference() | 画像の法線マップを予測する |
NormalText2Image | inference() | 法線マップとテキストから画像を生成する |
VisualQuestionAnswering | inference() | 画像に関する質問に回答する |
InfinityOutPainting | inference() | 画像を拡大する |
公式のHP上では、Google ColabではGPUリソースの簡単から「ImageCaptioning」と「Text2Image」のみを使用することが推奨されています。画像の変換をするにはInstructPix2Pixモデルを入れた方が楽しく遊べます。筆者のGoogle Colab ProではInstructPix2Pixを追加しても問題なく動作しました。
# Advice for 1 Tesla T4 15GB (Google Colab)
python visual_chatgpt.py --load "ImageCaptioning_cuda:0,Text2Image_cuda:0,InstructPix2Pix_cuda:0"
2. 実装コードの説明
それでは実装されているコードを順に説明していきます。
主な使用ライブラリィ
Transformersライブラリィから主に画像のセグメンテーションを行う各種モジュールをimportします。
14 from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
15 from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
16 from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
続いて、Hugging Face のDiffusersライブラリ(画像生成)から各種モジュールをimportします。
Diffusersライブラリは、Stable Diffusionをはじめとするモデルを、共通インタフェースで簡単に利用するためのパッケージです。
18 from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, StableDiffusionInstructPix2PixPipeline
19 from diffusers import EulerAncestralDiscreteScheduler
20 from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
21 from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector
LangChainライブラリィから各種モジュールをimportします。
LangChainは、ChatGPTなどの大規模言語モデル(LLM)を使ったアプリケーションやサービスの開発を支援するための強力なライブラリです。例えば、Webページやテキストを読み込ませてChatGPTに質問や要約をさせることが簡単に行えます。
24 from langchain.agents.initialize import initialize_agent
25 from langchain.agents.tools import Tool
26 from langchain.chains.conversation.memory import ConversationBufferMemory
27 from langchain.llms.openai import OpenAI
モデル(Visual Foundation Models)定義
213行目から971行目で、上記「1. 実装されている機能」 で一覧表示している各モデルが定義されています。
213 class ImageEditing:
214 def __init__(self, device):
・・・
971 return updated_image_path
処理の流れ
ここから、1046行目以降の処理の流れに沿って説明します。
- 1046〜1050行目
visual_chatgpt.pyを起動すると、まず引数の処理が行われます。
「--load」のパラメータにデフォルトの「ImageCaptioning」と「Text2Image」モジュールが強制的に指定され、他の指定と合わせてload_dictディクショナリに保存されます。
1046 if __name__ == '__main__':
1047 parser = argparse.ArgumentParser()
1048 parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0")
1049 args = parser.parse_args()
1050 load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
- 1051行目
bot(ConversationBot)の生成します(詳細は後述)。
1051 bot = ConversationBot(load_dict=load_dict)
- 1052行目から1070行目
gradio:gr を使ったチャットbotのWeb画面の定義になります。ここではWeb上のボタンを押されたら先ほど生成したbotの処理を呼び出して画像処理が行われます。例えばテキストエリア(「Enter text and press enter, or upload an imaget」と初期表示されたエリア)に指示テキストを入れてリターンキーを押すと1064行目のbot.run_text
により、bot内のrun_text()
メソッドが呼び出されます。
1052 with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
1053 chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT")
1054 state = gr.State([])
1055 with gr.Row():
1056 with gr.Column(scale=0.7):
1057 txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
1058 container=False)
1059 with gr.Column(scale=0.15, min_width=0):
1060 clear = gr.Button("Clear")
1061 with gr.Column(scale=0.15, min_width=0):
1062 btn = gr.UploadButton("Upload", file_types=["image"])
1063
1064 txt.submit(bot.run_text, [txt, state], [chatbot, state])
1065 txt.submit(lambda: "", None, txt)
1066 btn.upload(bot.run_image, [btn, state, txt], [chatbot, state, txt])
1067 clear.click(bot.memory.clear)
1068 clear.click(lambda: [], None, chatbot)
1069 clear.click(lambda: [], None, state)
1070 demo.launch(server_name="0.0.0.0", server_port=1015)
1070行目でgradioのインスタンスを読み出してチャットbotの起動を行います。
Google Colabで実行する場合、1070行目のlaunch()を以下のように修正し、ローカルでなく share=True
でシェアして利用するようにし、ログイン時のID、パスワードを適当に設定してください。
demo.launch(server_name="0.0.0.0", server_port=1015)
↓
demo.launch(auth=("MyID", "MyPassward"), share=True)
visual_chatgpt.pyを起動後に、コンソールに
Running on public URL: https://xxx.xxx.xxx.xxx
とリンクが表示れたらクリックして、上記で設定したID、パスワードを入れてアクセスします。
ConversationBotの実装
続いて肝となる ConversationBot
の実装を見てみましょう。
- ConversationBotクラス生成時に行われる
__init__()
メソッドでは、981行目のOpenAI()
でLLMの生成を行います。続いて過去の会話ログをすべてプロンプトに追加するメモリであるConversationBufferMemory
を982行目で生成します。
974 class ConversationBot:
975 def __init__(self, load_dict):
976 # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
977 print(f"Initializing VisualChatGPT, load_dict={load_dict}")
978 if 'ImageCaptioning' not in load_dict:
979 raise ValueError("You have to load ImageCaptioning as a basic function for VisualChatGPT")
980
981 self.llm = OpenAI(temperature=0)
982 self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
- Load Basic Foundation Models
985行目からの#Load Basic Foundation Models
では、起動時のパラメータで与えられた使用するモデルと使用デバイス(GPU or CPU)情報を取得します。
985 # Load Basic Foundation Models
986 for class_name, device in load_dict.items():
987 self.models[class_name] = globals()[class_name](device=device)
988
- Load Template Foundation Models
989行目からの# Load Template Foundation Models
では、テンプレートモデル(クラス)とそこで内部で使われているモデルを取得しています。現在の実装では、876行目のInfinityOutPainting
クラスのみがテンプレート定義されており、そこで使用するImageCaptioning
,ImageEditing
,VisualQuestionAnswering
モデルが取得されます。
989 # Load Template Foundation Models
990 for class_name, module in globals().items():
991 if getattr(module, 'template_model', False):
992 template_required_names = {k for k in inspect.signature(module.__init__).parameters.keys() if k!='self'}
993 loaded_names = set([type(e).__name__ for e in self.models.values()])
994 if template_required_names.issubset(loaded_names):
995 self.models[class_name] = globals()[class_name](
996 **{name: self.models[name] for name in template_required_names})
- Toolの生成
998行目のfor
文では、対象モデルから「inference」で開始するメソッドを探してTool
に登録します。このとき、inference
で開始するメソッドには@prompts
デコレータを使ってdescription
にAgentこのモデルを使ってもらうためのヒントが与えられています。例えば、ImageCaptioningでは以下のようになっています。
@prompts(name="Get Photo Description",
description="useful when you want to know what is inside the >photo. receives image_path as input. "
"The input to this tool should be a string, >representing the image_path. ")
997 self.tools = []
998 for instance in self.models.values():
999 for e in dir(instance):
1000 if e.startswith('inference'):
1001 func = getattr(instance, e)
1002 self.tools.append(Tool(name=func.name, description=func.description, func=func))
- Agentの生成
1003行目のinitialize_agent()
で、LLMインスタンスやToolインスタスを指定してAgentの生成を行います。 なお、agent
パラメータに指定されている「conversational-react-description」は、会話用に最適化されたエージェントです。
prefix
は、promptの最初に常にくっつける文章、suffix
は{agent_scratchpad}は絶対に含んでいる必要がありAgentがここにいろいろな文章を代入します。format_instructions
は回答のフォーマットを示す文章になります。それぞれ28行目から71行目に定義されています(どのような文章を指定するのかはノウハウなのでしょうか、、、)。
1003 self.agent = initialize_agent(
1004 self.tools,
1005 self.llm,
1006 agent="conversational-react-description",
1007 verbose=True,
1008 memory=self.memory,
1009 return_intermediate_steps=True,
1010 agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
1011 'suffix': VISUAL_CHATGPT_SUFFIX}, )
1012
これで準備が完了しました。 続いてはWeb画面で実際に操作が行われた際の処理になります。
- テキストエリアに動作指示したとき
1013行目のrun_text()
メソッドで、Web画面のテキストエリアに入力して指示された情報をAgentに渡してLLMに処理をさせます。このとき、どのクラスを使ってどのように処理をするのかは先ほど生成したAgentに任せます。これはReActモデルと呼ばれますが、興味がある人は以下の記事が詳しく書かれていますので参考にしてください(参考にさせていただきました)。
【Prompt Engineering】LLMを効率的に動かす「ReAct」論文徹底分解!
1013 def run_text(self, text, state):
1014 self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
1015 res = self.agent({"input": text})
1016 res['output'] = res['output'].replace("\\", "/")
1017 response = re.sub('(image/\S*png)', lambda m: f'})*{m.group(0)}*', res['output'])
1018 state = state + [(text, response)]
1019 print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
1020 f"Current Memory: {self.agent.memory.buffer}")
1021 return state, state
1022
- 画像をアップロードしたとき
1023行目のrun_image()
メソッドで、画像をアップロードした際の処理を実施します。アップロードしたイメージを「ImageCaptioning」を使って画像の解釈を加えてrun_text()
時のプロンプトに追加して会話メモリに保存します。
1023 def run_image(self, image, state, txt):
1024 image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
1025 print("======>Auto Resize Image...")
1026 img = Image.open(image.name)
1027 width, height = img.size
1028 ratio = min(512 / width, 512 / height)
1029 width_new, height_new = (round(width * ratio), round(height * ratio))
1030 width_new = int(np.round(width_new / 64.0)) * 64
1031 height_new = int(np.round(height_new / 64.0)) * 64
1032 img = img.resize((width_new, height_new))
1033 img = img.convert('RGB')
1034 img.save(image_filename, "PNG")
1035 print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
1036 description = self.models['ImageCaptioning'].inference(image_filename)
1037 Human_prompt = f'\nHuman: provide a figure named {image_filename}. The description is: {description}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
1038 AI_prompt = "Received. "
1039 self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
1040 state = state + [(f"*{image_filename}*", AI_prompt)]
1041 print(f"\nProcessed run_image, Input image: {image_filename}\nCurrent state: {state}\n"
1042 f"Current Memory: {self.agent.memory.buffer}")
1043 return state, state, f'{txt} {image_filename} '
3. Googleコラボでの実行について
2023年4月1日時点のコードでGoogle Colab Proで試行した手順です。
公式HPで記載されているQuick Startをベースにしています。
(1) Pythonのバージョンを3.8
に変更する
現在のGoogle Colabのデフォルトが3.9
ですので3.8
にバージョンを下げます。
# 下記のコマンドで使用できるPythonのバージョンの一覧が出るので選択する
!sudo update-alternatives --config python3
# Pythonのバージョンを変えた場合、pipをインストールし直す
!sudo apt install python3-pip
!wget https://bootstrap.pypa.io/pip/3.8/get-pip.py
!python3 get-pip.py
(2) ライブラリィのインストール
# 必要なライブラリィのインストール
!pip install -r xxxx/requirements.txt
(3) OpenAIのキー設定
OpenAIから取得したAPIキーを設定します。
%env OPENAI_API_KEY=XXXXXXXXXXXXXXXXXXXXXXXXXX
(4) 実行
!python XXXX/visual_chatgpt.py --load "ImageCaptioning_cuda:0,Text2Image_cuda:0,InstructPix2Pix_cuda:0"
4. 終わりに
Visual ChatGPTコード自体は上記の通り比較的シンプルなものです。今後、組み込まれるモデルが増えてくればますます便利になりそうですね。
今回は、Google Colab Proでしか動かしておりませんが、時間があれば高性能なGPUサーバでも他のモデルも組み込んで動かしてみたいと思います。また、LangChain/ReACTをさらに掘り下げて勉強して記事の記載していきたいと思います。