24
18

More than 1 year has passed since last update.

gradioでAIモデルのデモアプリを作成しながら基本機能を確認していく

Last updated at Posted at 2022-12-15

この記事はBrainPad Advent Calender 15日目の記事です

はじめに

今回は機械学習やデータサイエンスのデモに用いることができるwebインターフェイスライブラリーであるgradioを紹介させていただきます

huggingface spacesで公開されている様々な種類の深層学習を使ったアプリはgradioもしくはstreamlitで構築されておりhuggingfaceユーザーには馴染み深いものとなってきています

Gradioとは

Gradioは、PythonベースのWebインターフェースを構築できるオープンソースライブラリです

image.png

PythonベースのWebインターフェイスライブラリーはgradio以外にもstreamlitやdashなどのライブラリもありますが、gradioの場合には機械学習や深層学習の推論を簡単にデモアプリで利用できるようなコンポーネントが多数用意されています

Gradioの特徴

gradioは以下のような特徴を持ったwebインターフェイスライブラリーとなります

Good

  • Google Colaboratory上で簡単なデモアプリを作れる
  • アプリは72時間は無料でホスティングされる
  • 深層学習の推論結果をインタラクティブに確認できる

Bad

  • 現状では恒久的なデプロイ先がhuggingface spacesしか用意されていない
  • 機能としてはデモアプリに過ぎないので本番サービスには向かない

コンポーネント

gradioにとって最も基本的なコンポーネントInterfaceもしkはBlocksとなります

どちらも似たような機能ですが、Blocksの方がカスタマイズ性が高いので、多くの場合はBlocksを用いてインターフェイスを構築することになると思います

複雑な処理(タブごとに画面を分割したり複数の機能を組み合わせたり)が必要ない場合にはInterfaceでの実装を選択すると良いでしょう

Interfaceで構築
def hello(name):
    return "Hello " + name + "!"

gr.Interface(hello, "textbox", "text").launch()

たったこれだけのコードで簡易的なデモアプリをデプロイできます

出力イメージ
hello_interface.gif

Blocksの場合はBlocks内で複雑な処理を記載することが多いのでwith句を使って表現することが大半です

Blocksで構築
def hello(name):
    return "Hello " + name + "!"

with gr.Blocks() as demo:
    name = gr.Textbox(label="name")
    output = gr.Textbox(label="output")
    greet_btn = gr.Button("Hello")
    greet_btn.click(fn=hello, inputs=name, outputs=output)

demo.launch()

出力イメージ
hello_interface.gif

具体例を見ていく

gradioが用意しているコンポーネントを逐一紹介していくよりも、活用イメージが湧きやすい具体的なデモを実装するかhuggingface spacesで公開されているデモアプリを引用しながら機能を紹介していきます

事前準備

本記事ではGoogleColaboratory上で処理を行なっていくことを前提に記載していきます

また、Colaboratoryにはgradioは事前にインストールされていないので、事前にpip installしておくこととします

!pip install gradio

自然言語

Whisperを用いた文字起こしデモアプリ 【hugginface spaces】

マイクで音声を収録して、その音声をWhisperを用いて文字起こししてくれるデモアプリを見つつ、音声入力に関するgradioの機能を確認します

ポイント
:heavy_check_mark: Audioを用いてマイクによる音声入力を受け付ける

音声入力部分の機能を抜粋
audio = gr.Audio(
    label="Input Audio",
    show_label=False,
    source="microphone",
    type="filepath"
)

出力イメージ
whisper.gif

画像

画像に関する自作のデモアプリと公開されているデモアプリを参考にgradioの機能を確認していきます

画像のセマンティックセグメンテーションのデモアプリを作成する

facebookのMaskFormerを用いてセマンティックセグメンテーションを行います

MaskFormerによるセマンティックセグメンテーション
!pip install -q gradio git+https://github.com/huggingface/transformers torch

import torch
import random
import gradio as gr
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

device = torch.device("cpu")

model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-tiny-ade").to(device)
model.eval()
preprocessor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-tiny-ade")


def input_img(img):
    target_size = (img.shape[0], img.shape[1])
    inputs = preprocessor(images=img, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    outputs.class_queries_logits = outputs.class_queries_logits.cpu()
    outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()
    results = preprocessor.post_process_segmentation(outputs=outputs, target_size=target_size)[0].cpu().detach()
    results = torch.argmax(results, dim=0).numpy()
    results = visualize_segmentated_mask(results)
    return results

def visualize_segmentated_mask(mask_image):
    image = np.zeros((mask_image.shape[0], mask_image.shape[1], 3))
    labels = np.unique(mask_image)
    label_color = {label: (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}

    for i in range(image.shape[0]):
      for j in range(image.shape[1]):
        image[i, j, :] = label_color[mask_image[i, j]]

    image = image / 255
    return image

with gr.Blocks() as app:
    input = gr.Image(label="Image")
    output = gr.Image(label="Output")
    submit_btn = gr.Button(label="Submit")
    submit_btn.click(fn=input_img, inputs=input, outputs=output)

app.launch()

ポイント

:heavy_check_mark: Imageを用いて画像の入力を受け付ける

Image部分を抜粋
input = gr.Image(label="Image")
output = gr.Image(label="Output")

出力イメージ
semantic.gif

YOLOV7による物体検出デモアプリ 【hugginface spaces】

YOLOV7を用いた物体検出アプリの公開デモアプリを見つつ、画像の入力機能と物体検出モデルの選択する機能の実装を確認します

ポイント
:heavy_check_mark: Imageを用いて画像の入力を受け付ける
:heavy_check_mark: Dropdownを用いて利用するモデルを切り替える

Image機能とDropdown機能部分を抜粋
gr.Interface(
detect,[gr.Image(type="pil"),gr.Dropdown(choices=model_names)],
gr.Image(type="pil"),title="Yolov7",examples=[["horses.jpeg", "yolov7"]],
description="demo for <a href='https://github.com/WongKinYiu/yolov7' style='text-decoration: underline' target='_blank'>WongKinYiu/yolov7</a> Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors"
).launch()

出力イメージ

yolo7.gif

予測

最後に予測モデルの推論結果を表示するアプリを作成してきます

時系列予測比較デモアプリを作成

Prophetを用いて時系列予測を行い、予測期間を変更しながら時系列予測結果を描画するデモアプリを作成します

Prophetによる時系列予測モデル
import gradio as gr
import pandas as pd


from prophet import Prophet


def plot_forecast(example_name, period):
    df = pd.read_csv(f'https://raw.githubusercontent.com/facebook/prophet/main/examples/example_{example_name}.csv')
    df.columns = ['ds','y']

    m = Prophet()
    m.fit(df)
    future = m.make_future_dataframe(periods=period)
    forecast = m.predict(future)
    fig = m.plot(forecast)
    return fig

with gr.Blocks() as demo:
    gr.Markdown(
    """
    時系列予測モデルの結果
    """)
    with gr.Row():
        example = gr.Dropdown(["air_passengers", "pedestrians_covid", "retail_sales"], label="データソース", value="air_passengers")
        period = gr.Slider(25, 250, 25, step=25, label="予測期間")
 
    plt = gr.Plot()

    example.change(plot_forecast, [example,period], plt, queue=False)
    period.change(plot_forecast, [example,period], plt, queue=False)
    demo.load(plot_forecast, [example,period], plt, queue=False)    

demo.launch()

ポイント
:heavy_check_mark: Markdownを用いて概要説明を加える
:heavy_check_mark: Dropdownを用いて利用するデータを切り替える
:heavy_check_mark: Sliderを用いて予測範囲をインタラクティブに変更する

Markdown,Dropdown,Slider機能を抜粋
gr.Markdown("""時系列予測モデルの結果""")
with gr.Row():
    example = gr.Dropdown(["air_passengers", "pedestrians_covid", "retail_sales"], label="データソース", value="air_passengers")
    period = gr.Slider(25, 250, 25, step=25, label="予測期間")

出力イメージ
timesiries.gif

開発周りの疑問

デプロイはどうするのか?

現状gradioを永続的にホスティングしたければhuggingface spacesにデプロイするしか方法がありません
ただし、72時間以内であればdemo.launch(share=True)でデプロイすればパブリックなURLが発行され、どこからでもアプリにアクセスできます

デプロイ時にAuthは設定できるのか?

パブリックなURLが発行される場合にはAuthの設定が欲しいところですが、以下の処理を加えることで簡易的な認証画面を設定できます

# 任意のユーザーIDとパスワードを設定
demo.launch(auth=("user", "password"), share=True)

認証画面
image.png

プライベートな環境でデモアプリを共有できるか?

現状gradioを永続的にデプロイできる環境がhuggingface spacesに限られているため、huggingfaceのPrivate Hubをリクエストしない限りはプライベートな環境でgradioのデモアプリを共有することはできない状況です


おしまい

参考文献

24
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
24
18