LoginSignup
5
4

30days of Streamlit の29日目(Hugging Face API)を、Streamlit in Snowflakeで実装してみた

Last updated at Posted at 2023-10-20

はじめに

Snowflakeで最近Public PreviewになったStreamlit in Snowflake(以下SiS)とExternal Network Access(外部ネットワークアクセス)を組み合わせたアプリを作れないかなーという所で、30days of Streamlitの魔の29日目ことHugging Face APIによる機械学習アプリをSiSで動作させてみました。

今回は主に、動作させるために必要な操作や重要な点を解説していきます。

動作画面

SiSでの動作画面です。本家様(https://share.streamlit.io/charlywargnier/zero-shot-classifier/main )とそれほど変わらないアプリを構築できています。
Animation.gif
※再生速度は約2倍速です。また、API呼び出し待ちはカット編集を入れています。

アプリでは、入力した文章を自由に指定したカテゴリの内どのカテゴリに分類されるかを、Hugging Face APIを使用して推定しています。

アーキテクチャ図

今回構成したアーキテクチャのイメージ図です。詳細は後述していきます。
image.png

概念説明と準備

本章では、Snowflakeで最近PuPrとなった注目機能、外部ネットワークアクセス機能について簡単な解説と、SiSから使用するためのコマンドとその解説を行います。

外部ネットワークアクセスとは?

Snowflakeの外部にあるネットワークへのアクセスを、UDFやストアドプロシージャといった自作関数から行えるようになる機能です。そのために、下記のSnowflakeオブジェクトを作成します。

  • ネットワークルール(NETWORK RULE):外部のURLに対する通信の許可・拒否設定をします。
  • シークレット(SECRET):API認証キーを強固に保存し、統合やUDFから読み込みます。
  • 外部アクセス統合(EXTERNAL ACCESS INTEGRATION):ネットワークルールとシークレットを指定してこの統合を作成することで、自作関数から外部ネットワークにアクセスできるようになります。

外部ネットワークアクセス機能をSiSから使うには?

まず、現時点ではSiSから直接外部ネットワークアクセス機能を使用することはできません。そのため、ストアドプロシージャかUDFを介して実行します。今回のユースケースの場合、並列性の観点からUDFで実行するのが最適だと考えられます。なお、本家様のコードは、for文で各入力行を逐次処理されています。

SiSから外部ネットワークアクセスを使うための準備

まずは外部ネットワークアクセス機能の設定コード(SQL)を示します。適宜、データベースやスキーマなどのオブジェクト名は変更してください。また、Hugging Face APIキーも各自で取得したものを使用してください。

-- 事前準備
use role sysadmin;
use warehouse compute_wh;

create database if not exists sandbox;
use database sandbox;
create schema if not exists hugging_face;
use schema hugging_face;

use role accountadmin;

-- NETWORK_RULEの作成:https://docs.snowflake.com/en/sql-reference/sql/create-network-rule
CREATE OR REPLACE NETWORK RULE hugging_face_network_rule
  MODE = EGRESS
  TYPE = HOST_PORT
  VALUE_LIST = ('api-inference.huggingface.co');

-- シークレットの作成:https://docs.snowflake.com/en/developer-guide/external-network-access/creating-using-external-network-access#creating-a-secret-to-represent-credentials
CREATE OR REPLACE SECRET hugging_token
  TYPE = GENERIC_STRING
  SECRET_STRING = '<Hugging Face のAPIキー>';

-- External Network Accessの作成:https://docs.snowflake.com/en/developer-guide/external-network-access/creating-using-external-network-access#label-creating-using-external-access-integration-access-integration
CREATE OR REPLACE EXTERNAL ACCESS INTEGRATION hugging_face_access_integration
  ALLOWED_NETWORK_RULES = (hugging_face_network_rule)
  ALLOWED_AUTHENTICATION_SECRETS = (hugging_token)
  ENABLED = true;

上記の設定では先述の通り、Hugging Face APIに関するネットワークルールとシークレットオブジェクトを作成し、それらを外部アクセス統合オブジェクトのパラメータとして指定しています。

続いて、下記に、Hugging Face APIを実行するUDFの作成コードを示します。

use role accountadmin;

-- 外部ネットワークアクセスをするUDFの作成:https://docs.snowflake.com/en/developer-guide/external-network-access/creating-using-external-network-access#using-the-external-access-integration-in-a-function-or-procedure
CREATE OR REPLACE FUNCTION hugging_face_python(sentence STRING, parameters ARRAY)
RETURNS VARIANT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'get_hugging_classification'
EXTERNAL_ACCESS_INTEGRATIONS = (hugging_face_access_integration)
SECRETS = ('cred' = hugging_token)
PACKAGES = ('snowflake-snowpark-python','requests')
AS
$$
import _snowflake
import requests
import json
from typing import List

session = requests.Session()
def get_hugging_classification(sentence: str, parameters: List[str]=["positive", "negative"]):
  url = "https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3"
  token = _snowflake.get_generic_secret_string('cred')  # APIに関するドキュメント:https://docs.snowflake.com/en/developer-guide/external-network-access/secret-api-reference#python-api-for-secret-access
  header = {"Authorization": "Bearer " + token}
  data = {
    "inputs": sentence,
    "parameters": {"candidate_labels": parameters},
    "options": {"wait_for_model": True},
  }
  response = session.post(url, headers = header, json = data)
  return response.json()
$$;

-- 作成したUDFの使用権限をSYSADMINに付与
grant usage on function hugging_face_python(STRING, ARRAY) to role sysadmin;

-- Hugging Face APIにアクセスするUDFの動作確認
use role sysadmin;
select * from (select hugging_face_python('you are so geneous!', ['positive', 'negative']) as result);
select * from (select hugging_face_python('you are so rude!', ['positive', 'negative']) as result);

UDFの作成時には、EXTERNAL_ACCESS_INTEGRATIONSをリストで指定しています。また、SECRETSについても、UDF内で参照される環境変数名と共に指定しています。
UDF内のPythonコードでは、_snowflake.get_generic_secret_string()メソッドを使用して、シークレットを読み取って使用しています。UDF内のソースコードとしては、それ以外は特に変更点はありません。
作成したUDFは、ACCOUNTADMINで作成しているので、このUSAGE権限をSYSADMINに付与して完了となります。

最後に、Streamlitアプリを作成するコマンドを示します。こちらは、SnowsightのStreamlit GUIからプログラムをコピペする形でも大丈夫です。

-- Streamlitアプリの作成(ステージにファイルを格納している前提)
use role sysadmin;
create or replace streamlit HUGGING_FACE_29DAYS
root_location=@SANDBOX.HUGGING_FACE.HUGGING_FACE_STG
main_file='streamlit_main.py'
query_warehouse=COMPUTE_WH;

ソースコード解説

本章では、本家様のものから改変した部分について解説します。外部APIを使うために必須の部分と、UIに関して自分好みに変えた部分やちょっとした小ネタに分けています。なおソースコード全文は、当記事の最後に掲載しています。

必須の改変点(Hugging Face API 呼び出しUDF)

本節では、29日目のStreamlitアプリをSiSで動作させるために変更した点の内、必須・重要な点を解説します。

本記事の実装では、Hugging Face APIの呼び出しをUDFで実装しています。これにより、API呼び出しをSnowflakeの特徴である同時実行性を担保した実装とすることができます※。SnowflakeのスケーリングをPythonで活用できるようになるので、SnowflakeでPythonを使う最大のメリットの一つにこのUDFがあると感じています。
※Hugging Face APIでは同時接続数が制限されている可能性もあるため、一概に並列で実行されるとは限りません。

UDF用のインプットテーブル作成

Hugging Face APIを呼び出すUDFを実行するため、インプットテーブルを作成します。作成するテーブルの内容は、下記の様になります。

input
Today is September 25th.
The weather today was fantastic.
...

このテーブルを作成するために、本家様のコードで作成したlinesListリストを引数に、create_dataframe()メソッドを使い、input_dfデータフレーム(テーブル)を作成します。

input_df = snow_session.create_dataframe(
    linesList, StructType([StructField("input", StringType())])
)

UDFの実行

作成したinput_dfデータフレームを、事前に作成したUDFhugging_face_pythonに入力し、output_dfデータフレームにresult列として格納します。この処理は、with_columnメソッドにより行われ、第一引数に新しい列名を、第二引数に値を指定します。F.call_udfメソッドは、第一引数にUDFを取り、第二引数以降にそのUDFの引数を指定することで、UDFを呼び出すことができます。すなわち、このUDFの呼び出し結果がresult列に格納されることになります。

output_df = input_df.with_column(
    "result",
    F.call_udf(
        "hugging_face_python",
        F.col("input"),
        F.lit(st.session_state["multiselected"]),
    ),
)
output_df.collect()

今回UDFで実装したため若干複雑な処理が必要でしたが、ストアドプロシージャで実行すればリストをそのまま入力させることも可能です。(ただし、並列性はなくなるため性能は下がります。)

最終的に、UDFに入力されるテーブルは次の通りです。このテーブルの各行がUDFとして、ウェアハウスのプロセスで並列処理されます。

sentence parameters
Today is September 25th. ["Positive", "Negative", "Neutral"]
The weather today was fantastic. ["Positive", "Negative", "Neutral"]
... ...

UDFの出力結果となるoutput_dfデータフレームは次の通りです。

INPUT RESULT
Today is September 25th. {"labels": ["Positive", "Negative", "Neutral"], "scores": [0.499, 0.386, 0.115], "sequence": "Today is September 25th."}
The weather today was fantastic. {"labels": ["Positive", "Negative", "Neutral"], "scores": [0.956, 0.030, 0.014], "sequence": "The weather today was fantastic."}
... ...

若干Neutralに入る割合が低いですが、適切に入力したテキストが、指定したカテゴリに分類されていることが分かります。

この時点で外部ネットワークアクセス機能を使用してHugging FaceのAPIを実行するという目標は達成できました🎉 UDFの実行結果を整形する処理はAppendixとして次節にて解説します。

UIの改変点や小ネタなど

ここでは、小ネタの紹介をします。

画像の表示

st.image() は記事編集時点(2024/01/22)でサポートされているため、下記のような対応は不要となりました。ステージファイルのファイルパスや署名付きURLなどを使用して画像ファイルを使用することができます。

現時点でSiSでは、st.imageによる画像表示をサポートしていないようなので、 (記事編集時点(2024/01/22)で、st.image() がサポートされるようになっているようです。詳細はこちらから。)
plotlyを使用して画像を表示します。デフォルトでは軸やグリッドなどが表示されてしまうため、すべてオフにして表示する関数を用意しました。なお、画像は、CREATE STREAMLITコマンドのROOT_LOCATIONオプションで指定したパスの配下およびそのサブディレクトリにあるもののみ読み込むことができます。

import plotly.express as px
from skimage import io

def show_image(filepath: str, width: int, height: int) -> None:
    img = io.imread(filepath)
    fig = px.imshow(img)
    fig.update_layout(width=width, height=height, margin=dict(l=0, r=0, b=0, t=0))
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    fig.update_xaxes(showgrid=False, zeroline=False)
    fig.update_yaxes(showgrid=False, zeroline=False)
    st.plotly_chart(fig)

show_image("30days_logo.png", width=300, height=300)

st_tagsをst.text_input & st.multiselectに変更

現時点ではSiSでStreamlitのカスタムコンポーネントを使用できないため、Hugging Faceで使用する分類ラベルを別の手段で入力するUIが必要となります。そこで、下記の様にst.text_input & st.multiselectを使用する構成へと変更しています。

# st.session_state の初期化
if "label_options" not in st.session_state:
    st.session_state["label_options"] = [
        "Informational",
        "Transactional",
        "Navigational",
        "Positive",
        "Negative",
        "Neutral",
    ]

if "multiselected" not in st.session_state:
    st.session_state["multiselected"] = [
        "Positive",
        "Negative",
        "Neutral",
    ]

# 分類ラベルの追加
adding_label = st.text_input("Add below select content")
if adding_label != "":
    if adding_label not in st.session_state["label_options"]:
        st.session_state["label_options"].append(adding_label)

# 分類ラベルの選択
st.multiselect(
    label="Add labels - 3 max",
    options=st.session_state["label_options"],
    default=st.session_state["multiselected"],
    max_selections=3,
    key="multiselected",
)

若干冗長になりますが、機能としては再現できています。st.text_inputで取得した文字列を、st.multiselectのoptions引数に指定しているst.session_state["label_options"]に追加しています。なお、st.multiselectで選択された値は、key引数で指定された値multiselectedよりst.session_state["multiselected"]に格納されます。

APIの実行制限に達した場合のエラーハンドリング

Hugging Face APIの実行制限に達した場合、出力結果が下記の様にすべてNoneで出力されるため、それを利用したエラーハンドリング処理です。

input text labels scores
None None None
None None None
... ... ...
def is_all_none(df, column_name: str) -> bool:
    not_none_count = df.filter(df[column_name].isNotNull()).count()
    return not_none_count == 0
    
if is_all_none(output_df, "scores"):
    st.error(
        "Perhaps you have reached the API usage limit. "
        "Please try again in a few minutes. "
    )
    st.stop()

エラー出力時の様子です。
image.png

UDFの実行結果表示

前節のJSONの出力結果をそのまま出力する形でも結果を知ることはできますが、それではつまらないので出力結果を整形します。コードが複雑なので、細切れに紹介していきます。(ここの処理は実はもっと簡略化できるんじゃないかと思いつつも取り組めていないです。)

まず、UDFの出力がJSON(VARIANT型)なので、パースします。

output_df = output_df.select(
    output_df["result"]["sequence"].alias("input text"),
    F.as_array(output_df["result"]["labels"]).alias("labels"),
    F.as_array(output_df["result"]["scores"]).alias("scores"),
)
output_df.collect()

上記のコードの出力結果は次の通りです。元々JSON形式で表示されていたのが、2つの列としてリストを表示できるようになりました。

input text labels scores
"Today is September 25th." ["Positive", "Negative", "Neutral"] [0.499, 0.386, 0.115]
"The weather today was fantastic." ["Positive", "Negative", "Neutral"] [0.956, 0.030, 0.014]
... ... ...

次に、分類ラベル列とスコア列のリストを辞書型のオブジェクトに変換したのち、F.explodeというAPIで個々の行に分解します。

@F.udf
def convert_to_dict(a: List[str], b: List[float]) -> Dict[str, float]:
    return dict(zip(a, b))

output_df = output_df.with_column(
    "label_score", convert_to_dict(F.col("labels"), F.col("scores"))
)
output_df.collect()

output_df = output_df.select(
    F.col("input text"),
    F.explode(F.col("label_score")).alias("score_key", "score_value"),
)
output_df.collect()

上記のコードの出力結果は次の通りです。レコードにリストが格納されていたのが複数の行に分割されています。

input text score_key score_value
"Today is September 25th." "Positive" 0.499
"Today is September 25th." "Negative" 0.386
"Today is September 25th." "Neutral" 0.115
"The weather today was fantastic." "Positive" 0.956
"The weather today was fantastic." "Negative" 0.030
"The weather today was fantastic." "Neutral" 0.014
... ... ...

ここまでくれば、あとはテーブルをピボットして完成です。

distinctKeys = [
    row["SCORE_KEY"] for row in output_df.select("score_key").distinct().collect()
]
output_df = output_df.pivot("score_key", distinctKeys).sum("score_value")
output_df.collect()

上記のコードの出力結果は次の通りであり、最終的に出力されるテーブルとなります。縦持ちを横持ちに変換し、表示として分かりやすい形式になりました!

input text Positive Negative Neutral
"Today is September 25th." 0.499 0.386 0.115
"The weather today was fantastic." 0.956 0.030 0.014
... ... ... ...

なお、解説のため細かくdf.collect()を呼び出していますが、なるべく一括でcollect()処理を行った方が、効率的なクエリを実行できるはずです。

おわりに

SiSからでも簡単に外部APIを使用できることが分かりました!今回は、入力された文章を任意のカテゴリに入力するアプリを構築しましたが、テーブルに格納されたレビューやコメントからカテゴリを分類するようなアプリも同様の方法で構築できそうです。

外部ネットワークアクセス機能は、概念さえつかんでしまえば非常に使いやすい機能だと感じました。権限については触れていませんが、外部ネットワークアクセスというセンシティブな機能であるため、ACCOUNTADMINか、必要な権限を付与されたロール以外はできない点についても想定と合います。外部アクセス統合を使用する際も、UDFの定義時に2つのパラメータ(EXTERNAL_ACCESS_INTEGRATIONS、SECRETS)を指定し、UDFのコードから環境変数として読み込むだけで良いため簡単でした。

今回は、少量のデータしか使用していませんが、これが大容量のデータであったとしても非常に高速に処理できるはずです(API制限を除いて)。今回紹介したように、既にSnowflakeのデータを活用して、機械学習や可視化、アプリケーション化と、様々なことができるようになっているので、是非こうした最新の機能で活用の幅を広げて、この領域を盛り上げていければ嬉しいです!

参考文献

仲間募集

NTTデータ テクノロジーコンサルティング事業本部 では、以下の職種を募集しています。

1. クラウド技術を活用したデータ分析プラットフォームの開発・構築(ITアーキテクト/クラウドエンジニア)

クラウド/プラットフォーム技術の知見に基づき、DWH、BI、ETL領域におけるソリューション開発を推進します。
https://enterprise-aiiot.nttdata.com/recruitment/career_sp/cloud_engineer

2. データサイエンス領域(データサイエンティスト/データアナリスト)

データ活用/情報処理/AI/BI/統計学などの情報科学を活用し、よりデータサイエンスの観点から、データ分析プロジェクトのリーダーとしてお客様のDX/デジタルサクセスを推進します。
https://enterprise-aiiot.nttdata.com/recruitment/career_sp/datascientist

3.お客様のAI活用の成功を推進するAIサクセスマネージャー

DataRobotをはじめとしたAIソリューションやサービスを使って、
お客様のAIプロジェクトを成功させ、ビジネス価値を創出するための活動を実施し、
お客様内でのAI活用を拡大、NTTデータが提供するAIソリューションの利用継続を推進していただく人材を募集しています。
https://nttdata.jposting.net/u/job.phtml?job_code=804

4.DX/デジタルサクセスを推進するデータサイエンティスト《管理職/管理職候補》 データ分析プロジェクトのリーダとして、正確な課題の把握、適切な評価指標の設定、分析計画策定や適切な分析手法や技術の評価・選定といったデータ活用の具現化、高度化を行い分析結果の見える化・お客様の納得感醸成を行うことで、ビジネス成果・価値を出すアクションへとつなげることができるデータサイエンティスト人材を募集しています。

https://nttdata.jposting.net/u/job.phtml?job_code=898

ソリューション紹介

Trusted Data Foundationについて

~データ資産を分析活用するための環境をオールインワンで提供するソリューション~
https://enterprise-aiiot.nttdata.com/tdf/
最新のクラウド技術を採用して弊社が独自に設計したリファレンスアーキテクチャ(Datalake+DWH+AI/BI)を顧客要件に合わせてカスタマイズして提供します。
可視化、機械学習、DeepLearningなどデータ資産を分析活用するための環境がオールインワンで用意されており、これまでとは別次元の量と質のデータを用いてアジリティ高くDX推進を実現できます。

TDFⓇ-AM(Trusted Data Foundation - Analytics Managed Service)について

~データ活用基盤の段階的な拡張支援(Quick Start) と保守運用のマネジメント(Analytics Managed)をご提供することでお客様のDXを成功に導く、データ活用プラットフォームサービス~
https://enterprise-aiiot.nttdata.com/service/tdf/tdf_am
TDFⓇ-AMは、データ活用をQuickに始めることができ、データ活用の成熟度に応じて段階的に環境を拡張します。プラットフォームの保守運用はNTTデータが一括で実施し、お客様は成果創出に専念することが可能です。また、日々最新のテクノロジーをキャッチアップし、常に活用しやすい環境を提供します。なお、ご要望に応じて上流のコンサルティングフェーズからAI/BIなどのデータ活用支援に至るまで、End to Endで課題解決に向けて伴走することも可能です。

NTTデータとTableauについて

ビジュアル分析プラットフォームのTableauと2014年にパートナー契約を締結し、自社の経営ダッシュボード基盤への採用や独自のコンピテンシーセンターの設置などの取り組みを進めてきました。さらに2019年度にはSalesforceとワンストップでのサービスを提供開始するなど、積極的にビジネスを展開しています。

これまでPartner of the Year, Japanを4年連続で受賞しており、2021年にはアジア太平洋地域で最もビジネスに貢献したパートナーとして表彰されました。
また、2020年度からは、Tableauを活用したデータ活用促進のコンサルティングや導入サービスの他、AI活用やデータマネジメント整備など、お客さまの企業全体のデータ活用民主化を成功させるためのノウハウ・方法論を体系化した「デジタルサクセス」プログラムを提供開始しています。
https://enterprise-aiiot.nttdata.com/service/tableau

NTTデータとAlteryxについて
Alteryxは、業務ユーザーからIT部門まで誰でも使えるセルフサービス分析プラットフォームです。

Alteryx導入の豊富な実績を持つNTTデータは、最高位にあたるAlteryx Premiumパートナーとしてお客さまをご支援します。

導入時のプロフェッショナル支援など独自メニューを整備し、特定の業種によらない多くのお客さまに、Alteryxを活用したサービスの強化・拡充を提供します。

https://enterprise-aiiot.nttdata.com/service/alteryx

NTTデータとDataRobotについて
DataRobotは、包括的なAIライフサイクルプラットフォームです。

NTTデータはDataRobot社と戦略的資本業務提携を行い、経験豊富なデータサイエンティストがAI・データ活用を起点にお客様のビジネスにおける価値創出をご支援します。

https://enterprise-aiiot.nttdata.com/service/datarobot

NTTデータとInformaticaについて

データ連携や処理方式を専門領域として10年以上取り組んできたプロ集団であるNTTデータは、データマネジメント領域でグローバルでの高い評価を得ているInformatica社とパートナーシップを結び、サービス強化を推進しています。
https://enterprise-aiiot.nttdata.com/service/informatica

NTTデータとSnowflakeについて
NTTデータでは、Snowflake Inc.とソリューションパートナー契約を締結し、クラウド・データプラットフォーム「Snowflake」の導入・構築、および活用支援を開始しています。

NTTデータではこれまでも、独自ノウハウに基づき、ビッグデータ・AIなど領域に係る市場競争力のあるさまざまなソリューションパートナーとともにエコシステムを形成し、お客さまのビジネス変革を導いてきました。
Snowflakeは、これら先端テクノロジーとのエコシステムの形成に強みがあり、NTTデータはこれらを組み合わせることでお客さまに最適なインテグレーションをご提供いたします。

https://enterprise-aiiot.nttdata.com/service/snowflake

(参考)ソースコード全文とディレクトリ構成

Streamlitアプリのためのディレクトリおよびソースコードは下記の通りです。下記の資材をステージにPUTしたうえで、CREATE STREAMLITコマンドを実行します。

ステージ配下のディレクトリ構成
@SANDBOX.HUGGING_FACE.HUGGING_FACE_STG/
 ├ environment.yml
 ├ streamlit_main.py
 └ logo.png
environment.yml
name: app_environment
channels:
  - snowflake
dependencies:
  - plotly=5.9.0
  - python=3.8.*
  - scikit-image=0.20.0
  - snowflake-snowpark-python=
  - streamlit=1.22.0
streamlit_main.py
from typing import Dict, List

import streamlit as st
import plotly.express as px
import snowflake.snowpark.functions as F
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.types import StringType, StructType, StructField
from skimage import io


snow_session = get_active_session()


def show_image(filepath: str, width: int, height: int) -> None:
    img = io.imread(filepath)
    fig = px.imshow(img)
    fig.update_layout(width=width, height=height, margin=dict(l=0, r=0, b=0, t=0))
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    fig.update_xaxes(showgrid=False, zeroline=False)
    fig.update_yaxes(showgrid=False, zeroline=False)
    st.plotly_chart(fig)


c1, c2 = st.columns([0.4, 2])
with c1:
    # show_image("logo.png", width=100, height=100)
    st.image("logo.png") # st.image()のサポートにつき修正

with c2:
    st.caption("")
    st.title("Zero-Shot Text Classifier")

st.markdown(
    "このアプリは、文章をカテゴリに高速で分類します。機械学習モデルの学習は必要ありません!\n\n"
    "分類させたいカテゴリ(e.g. `Positive`, `Negative` and `Neutral`)と文章を入力し、実行するだけです。"
)


def main():
    st.subheader("分類させたいカテゴリの入力")
    if "label_options" not in st.session_state:
        st.session_state["label_options"] = [
            "Informational",
            "Transactional",
            "Navigational",
            "Positive",
            "Negative",
            "Neutral",
        ]

    if "multiselected" not in st.session_state:
        st.session_state["multiselected"] = [
            "Positive",
            "Negative",
            "Neutral",
        ]

    adding_label = st.text_input("Add below select content")
    if adding_label != "":
        if adding_label not in st.session_state["label_options"]:
            st.session_state["label_options"].append(adding_label)

    st.multiselect(
        label="Add labels - 3 max",
        options=st.session_state["label_options"],
        default=st.session_state["multiselected"],
        max_selections=3,
        key="multiselected",
    )

    st.subheader("分類させたい文章の入力")
    new_line = "\n"
    nums = [
        "Today is September 25th.",
        "The weather today was fantastic.",
        "The service at this restaurant was very slow.",
        "That comment hurt his feelings.",
        "This dish is the best I've ever tasted.",
    ]

    sample = f"{new_line.join(map(str, nums))}"

    MAX_LINES = 5
    text = st.text_area(
        "Enter keyphrases to classify",
        sample,
        height=200,
        key="2",
        help="At least two keyphrases for the classifier to work, one per line, "
        + str(MAX_LINES)
        + " keyphrases max as part of the demo",
    )

    lines = text.split("\n")  # A list of lines
    linesList = []
    for x in lines:
        linesList.append(x)
    linesList = list(dict.fromkeys(linesList))  # Remove dupes
    linesList = list(filter(None, linesList))  # Remove empty

    linesList = linesList[:MAX_LINES]

    submit_button = st.button("Submit")
    if submit_button:
        st.info("Calling hugging face API!")

        input_df = snow_session.create_dataframe(
            linesList, StructType([StructField("input", StringType())])
        )
        output_df = input_df.with_column(
            "result",
            F.call_udf(
                "hugging_face_python",
                F.col("input"),
                F.lit(st.session_state["multiselected"]),
            ),
        )
        output_df.collect()

        output_df = output_df.select(
            output_df["result"]["sequence"].alias("input text"),
            F.as_array(output_df["result"]["labels"]).alias("labels"),
            F.as_array(output_df["result"]["scores"]).alias("scores"),
        )
        output_df.collect()

        def is_all_none(df, column_name: str) -> bool:
            not_none_count = df.filter(df[column_name].isNotNull()).count()
            return not_none_count == 0

        if is_all_none(output_df, "scores"):
            st.error(
                "Perhaps you have reached the API usage limit. "
                "Please try again in a few minutes. "
            )
            st.stop()

        else:

            @F.udf
            def convert_to_dict(a: List[str], b: List[float]) -> Dict[str, float]:
                return dict(zip(a, b))

            output_df = output_df.with_column(
                "label_score", convert_to_dict(F.col("labels"), F.col("scores"))
            )
            output_df.collect()

            output_df = output_df.select(
                F.col("input text"),
                F.explode(F.col("label_score")).alias("score_key", "score_value"),
            )
            output_df.collect()

            distinctKeys = [
                row["SCORE_KEY"]
                for row in output_df.select("score_key").distinct().collect()
            ]
            output_df = output_df.pivot("score_key", distinctKeys).sum("score_value")
            output_df.collect()

            st.subheader("分類結果の出力")

            st.dataframe(output_df)

            st.success("✅ Done!")


if __name__ == "__main__":
    main()
5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4