LoginSignup
15
9

More than 1 year has passed since last update.

ARの中で物体認識してワードクラウドを表示する

Last updated at Posted at 2021-05-08

作ったもの

こういうものを作りました。
カメラに写っているものを物体認識して、その名称や関連語を空間内に散りばめて表示するものです。

ソースコードはこちらになります。
https://github.com/shibuiwilliam/ARWithWord

前置き

ARアプリではスマホやグラスを通して現実空間に視覚的な演出を表示、操作しますが、そのほとんどが画像やエフェクトに限られているように思います。たとえばスマホでGoogle検索した動物の一部をARのオブジェクトとして仮想現実内に表示してみることができます。

goole検索

もちろんこうした視覚的な効果は面白いのですが、ARの使い途はモノに限られないと感がています。モノそのもの以外にも言葉や文字を空間に表示して意味づけする用途があっても良い気がします。というわけで、ARにワードクラウドを表示するシステムを作ってみました。

ワードクラウドとは

ワードクラウドは自然言語処理で使われるデータの可視化手法の一つです。テキストデータを解析して頻出単語やキーワードを2次元画像に表示します。以下の画像がその例ですが、テキストデータを分析して頻出語を表示することで、そのテキストの内容や重要な単語を可視化します。

wordcloud

簡単なワードクラウドは以下のサイトで作ることができます。

ワードクラウドは基本的には2次元で画像として作りますが、ARと組み合わせて3次元にしても良いと思いたって、今回アプリを作ってみました。

どう作るか

AR

ARアプリはUnityで作ります。AR Foundationでカメラを通した仮想現実のプレーンを認識し、文字を表示します。
今回はアプリはAndroidで動かしますが、Unity+AR FoundationなのでiOSでも動くと思います(未検証)。

物体認識

物体認識はUnity Barracudatiny YOLO v3を載せて使っています。本当はUnity Perceptionを試してみたかったのですが、ライブラリインストールが安定せず、Barracudaを使いました。
モデルはONNXによるtiny YOLO v3の学習済みモデルで、学習データはCOCOデータセットです。ですので認識できるものはここにある80カテゴリのみです。

自然言語処理

認識した物体に関連する単語はfastTextで検索しています。fastTextはFacebookが開発した単語をベクトル化する仕組みで、特定の単語を入力すると、その単語に近い関連語を得ることができます。今回はWikipediaでクローリングして作られた学習済みモデルを使っています。
単語分散表現のモデルは容量が大きい(数GB・・・)になるため、アプリには組み込まずクラウド側にREST APIサーバを作って関連語をリクエストできるようにします。
サイズは仕方ないにしても、word2vecやfastTextのような単語分散表現をONNX変換できるようになってほしいです。

全体像

全体像は以下になります。

02.png

UnityでARアプリのベースを開発し、その中でONNXベースのtiny YOLO v3の物体認識を実行します。認識した物体の関連語をKubernetesに作ったfastTextのREST APIにリクエストし、結果をAR Foundationでアンカーを差して表示するというものです。というわけで、システムとしてはUnityによるAndroidアプリとKubernetesに乗せるREST APIを開発しています。
アプリサイドとサーバサイド両方に、AR、Edge AI、サーバサイドのAIを組み合わせた仕組みになります。いろいろと組み合わさっていますが、基本的には既存のモデルとライブラリを組み合わせれば作れます。

もうちょい詳しく

Unity

Unityによるアプリ開発を説明します。
Unityアプリのコードは以下にあります。
https://github.com/shibuiwilliam/ARWithWord/tree/main/ARWithWord

Unityでは最初に3Dプロジェクトを作ります。

03.png

作成したプロジェクトにはAR FoundationもBarracudaも入っていないので、Package Managerで必要なパッケージを追加します。

  • AR Foundation:Package Managerで検索してインストール。
  • Barracuda:ここ参照。Package Managerで見つからない場合は Add package from git URLcom.unity.barracuda を入力すればインストールできます。
  • Android Logcat:UnityでAndroidアプリのログを表示するツールです。

04.png

パッケージをインストールしたら必要なリソースを追加します。
AR Session Originを作成し、そのコンポーネントとして AR Raycast Manager, AR Anchor Manager, AR Plane ManagerAdd Component します。加えて物体認識のコンポーネント(ObjectDetector)と関連語検索のクライアント(SimilarWordClient)を作っておきます。

05.png

AR Session Originにセッション中の実行コード(Spawn Manager Script)を実装します。コード全文は長いので省略しますが、以下のような内容になります。

// SpawnManager.cs

using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;

using UnityEngine;
using UnityEngine.UI;
using UnityEngine.XR.ARFoundation;
using UnityEngine.XR.ARSubsystems;
using Unity.Collections;
using Unity.Collections.LowLevel.Unsafe;
using Unity.Barracuda;

public class DetectionTarget
{
    public Texture2D CurrentTexture2D;
    public Pose HitPose;
}

public class Detected
{
    public Pose HitPose;
    public IList<ItemDetected> ItemsDetected;
}

public class DetectedSimilarWords
{
    public Pose HitPose;
    public SimilarWords SimilarWords;
}

public class SpawnManager : MonoBehaviour
{
    [SerializeField] GameObject goText;

    [SerializeField]
    ARCameraManager m_arCameraManager;
    public ARCameraManager arCameraManager
    {
        get { return m_arCameraManager; }
        set { m_arCameraManager = value; }
    }

    [SerializeField]
    GameObject m_goObjectDetector;
    public GameObject goObjectDetector
    {
        get { return m_goObjectDetector; }
        set { m_goObjectDetector = value; }
    }
    private ObjectDetector objectDetector;


    [SerializeField]
    GameObject m_goSimilarWordClient;
    public GameObject goSimilarWordClient
    {
        get { return m_goSimilarWordClient; }
        set { m_goSimilarWordClient = value; }
    }
    private SimilarWordClient similarWordClient;

    public float shiftX = 0f;
    public float shiftY = 0f;
    public float scaleFactor = 1;

    private ARRaycastManager arRaycastManager;
    private List<ARRaycastHit> hits = new List<ARRaycastHit>();

    private static Texture2D _texture;

    private bool isDetecting = false;

    Texture2D m_Texture;

    Queue<DetectionTarget> detectionTargetQueue = new Queue<DetectionTarget>();
    Queue<Detected> detectedQueue = new Queue<Detected>();
    Queue<DetectedSimilarWords> detectedSimilarWordsQueue = new Queue<DetectedSimilarWords>();

    private List<Color32> colors = new List<Color32>() {
        new Color32(255, 115, 200, 255),
        new Color32(241, 233, 137, 255),
        new Color32(108, 109, 101, 255),
        new Color32(128, 244, 222, 255),
        new Color32(134, 222, 249, 255),
        new Color32(228, 178, 249, 255)
    };

    private void OnEnable()
    {
        // アプリ起動時処理
        Debug.Log("initialize AR camera manager frame received");
        m_arCameraManager.frameReceived += OnCameraFrameReceived;
        _texture = new Texture2D(1, 1);
        _texture.SetPixel(0, 0, new Color(0.3843137f, 0, 0.9333333f));
        _texture.Apply();
    }

    void Start()
    {
        // AR起動時処理
        this.arRaycastManager = GetComponent<ARRaycastManager>();

        // 物体認識
        this.objectDetector = goObjectDetector.GetComponent<TinyYolo3Detector>();
        this.objectDetector.Start();

        // 関連語REST APIクライアント
        this.similarWordClient = goSimilarWordClient.GetComponent<SimilarWordClient>();
        this.similarWordClient.Start();

    }

    void Update()
    {
        // 画面タッチでカメラに写っているものを物体認識する
        if (Input.touchCount > 0)
        {
            Touch touch = Input.GetTouch(0);
            if (touch.phase == TouchPhase.Began)
            {
                if (arRaycastManager.Raycast(touch.position, hits, TrackableType.PlaneWithinPolygon))
                {
                    Pose hitPose = hits[0].pose;
                    var detectionTarget = new DetectionTarget
                    {
                        CurrentTexture2D = this.m_Texture,
                        HitPose = hitPose
                    };
                    this.detectionTargetQueue.Enqueue(detectionTarget);
                    Debug.Log($"touched ({hitPose.position.x}, {hitPose.position.y}, {hitPose.position.z})!");
                }
            }
        }

        // 関連語の表示
        if (this.detectedSimilarWordsQueue.Count() > 0)
        {
            var detectedSimilarWord = this.detectedSimilarWordsQueue.Dequeue();
            Debug.Log($"similar word for {detectedSimilarWord.SimilarWords.word} has {detectedSimilarWord.SimilarWords.predictions.Count()} predictions");

            AllocateItem(detectedSimilarWord.SimilarWords.word, detectedSimilarWord.HitPose.position, detectedSimilarWord.HitPose.rotation);

            int i = 0;
            foreach (var prediction in detectedSimilarWord.SimilarWords.predictions)
            {
                Debug.Log($"{i} process {prediction.similar_word} with {prediction.similarity}");

                var hitPoseRandom = new Vector3(detectedSimilarWord.HitPose.position.x, detectedSimilarWord.HitPose.position.y, detectedSimilarWord.HitPose.position.z);

                float rX = UnityEngine.Random.Range(-1.0f, 1.0f);
                float rY = UnityEngine.Random.Range(-1.0f, 1.0f);
                float rZ = UnityEngine.Random.Range(-1.0f, 1.0f);
                hitPoseRandom.x += rX;
                hitPoseRandom.y += rY;
                hitPoseRandom.z += rZ;

                AllocateItem(prediction.similar_word, hitPoseRandom, detectedSimilarWord.HitPose.rotation);
            }
        }

        RequestSimilarWord();
    }


    private void AllocateItem(string word, Vector3 hitPose, Quaternion hitRotation)
    {
        Debug.Log($"allocate {word}");

        this.goText.GetComponent<TextMesh>().text = word;
        int colorIndex = UnityEngine.Random.Range(0, colors.Count());
        this.goText.GetComponent<TextMesh>().color = this.colors[colorIndex];

        int characterSize = UnityEngine.Random.Range(6, 18);
        this.goText.GetComponent<TextMesh>().characterSize = characterSize;

        Instantiate(goText, hitPose, hitRotation);
        Debug.Log($"allocated {word} on {hitPose.x}, {hitPose.y}, {hitPose.z}");
    }

    // ARカメラの画像を取得
    unsafe void OnCameraFrameReceived(ARCameraFrameEventArgs eventArgs)
    {
        if (!arCameraManager.TryAcquireLatestCpuImage(out XRCpuImage image))
        {
            return;
        }

        var conversionParams = new XRCpuImage.ConversionParams
        {
            inputRect = new RectInt(0, 0, image.width, image.height),
            outputDimensions = new Vector2Int(image.width, image.height),
            outputFormat = TextureFormat.RGBA32,
            transformation = XRCpuImage.Transformation.None
        };
        int imageSize = image.GetConvertedDataSize(conversionParams);
        var buffer = new NativeArray<byte>(imageSize, Allocator.Temp);
        image.Convert(conversionParams, new IntPtr(buffer.GetUnsafePtr()), buffer.Length);
        image.Dispose();

        this.m_Texture = new Texture2D(
            conversionParams.outputDimensions.x,
            conversionParams.outputDimensions.y,
            conversionParams.outputFormat,
            false
        );
        this.m_Texture.LoadRawTextureData(buffer);
        this.m_Texture.Apply();
        buffer.Dispose();

        Detect();
    }

    // 物体認識
    private void Detect()
    {
        if (this.isDetecting)
        {
            return;
        }
        if (this.detectionTargetQueue.Count() == 0)
        {
            return;
        }

        var detectionTarget = this.detectionTargetQueue.Dequeue();
        this.isDetecting = true;
        StartCoroutine(
            ProcessImage(
                this.objectDetector.IMAGE_SIZE, detectionTarget.CurrentTexture2D, picture =>
                {
                    StartCoroutine(
                        this.objectDetector.Detect(
                            picture, itemsDetected =>
                            {
                                if (itemsDetected.Count > 0)
                                {
                                    var detected = new Detected
                                    {
                                        HitPose = detectionTarget.HitPose,
                                        ItemsDetected = itemsDetected,
                                    };
                                    this.detectedQueue.Enqueue(detected);
                                }
                                Resources.UnloadUnusedAssets();
                                this.isDetecting = false;
                            }
                        )
                    );
                }
            )
        );
    }


    private IEnumerator ProcessImage(int inputSize, Texture2D texture2D, Action<Color32[]> callback)
    {
        Coroutine croped = StartCoroutine(
            TextureTools.CropSquare(
                texture2D, TextureTools.RectOptions.Center, snap =>
                {
                    var scaled = Scale(snap, inputSize);
                    var rotated = Rotate(scaled.GetPixels32(), scaled.width, scaled.height);
                    callback(rotated);
                }
            )
        );
        yield return croped;
    }


    private Texture2D Scale(Texture2D texture, int imageSize)
    {
        Texture2D scaled = TextureTools.scaled(texture, imageSize, imageSize, FilterMode.Bilinear);
        return scaled;
    }


    private Color32[] Rotate(Color32[] pixels, int width, int height)
    {
        Color32[] rotate = TextureTools.RotateImageMatrix(pixels, width, height, 90);
        return rotate;
    }

    // 関連語をリクエスト
    private void RequestSimilarWord()
    {
        if (this.detectedQueue.Count()==0)
        {
            return;
        }

        var detected = this.detectedQueue.Dequeue();
        StartCoroutine(
            this.similarWordClient.SimilarWordAPI(
                detected.ItemsDetected[0].PredictedItem.Label, 20, results =>
                {
                    Debug.Log($"result {results}");
                    var detectedSimilarWords = new DetectedSimilarWords
                    {
                        HitPose = detected.HitPose,
                        SimilarWords=results,
                    };
                    this.detectedSimilarWordsQueue.Enqueue(detectedSimilarWords);
                }
            )
        );
    }
}

コードの中でやっているフローは以下のようになります。

06.png

各処理はQueueに入れてCoroutineで非同期に進めていきます。メモリは消費しますが、物体認識やRESTリクエストの遅延でUIに影響を与えない工夫です。
ARカメラに写っている画像は OnCameraFrameReceived メソッドで取得します。この方法は公式ドキュメントで説明されているものです。
画面タッチ時に OnCameraFrameReceived で取得した最新の画像をtiny YOLO v3で物体認識します。この時点で認識した物体の名称と位置を取得し、キューに溜めます。
物体認識のコードは以下のようなものになります。

// TinyYolo3Detector.cs

using System;
using UnityEngine;
using Unity.Barracuda;
using System.Linq;
using System.Collections;
using System.Collections.Generic;


public class Parameters
{
    public int ROW_COUNT;
    public int COL_COUNT;
    public int CELL_WIDTH;
    public int CELL_HEIGHT;
    public Parameters(int ROW_COUNT, int COL_COUNT, int CELL_WIDTH, int CELL_HEIGHT)
    {
        this.ROW_COUNT = ROW_COUNT;
        this.COL_COUNT = COL_COUNT;
        this.CELL_WIDTH = CELL_WIDTH;
        this.CELL_HEIGHT = CELL_HEIGHT;
    }
}

public class TinyYolo3Detector : MonoBehaviour, ObjectDetector
{

    public NNModel modelFile;

    // ラベルの言語。英語or日本語
    public enum LabelLanguages
    {
        EN,
        JP
    };
    public LabelLanguages labelLanguage;


    public string inputName;
    public string outputNameL;
    public string outputNameM;

    private const int IMAGE_MEAN = 0;
    private const float IMAGE_STD = 255.0F;
    private const int _IMAGE_SIZE = 416;
    public int IMAGE_SIZE
    {
        get => _IMAGE_SIZE;
    }

    public float minConfidence = 0.25f;

    private IWorker worker;
    private Model model;

    public Parameters paramsL = new Parameters(13, 13, 32, 32);
    public Parameters paramsM = new Parameters(26, 26, 16, 16);

    public const int BOXES_PER_CELL = 3;
    public const int BOX_INFO_FEATURE_COUNT = 5;

    private int classLength;
    private string[] labels;


    public void Start()
    {
        switch(this.labelLanguage)
        {
            case LabelLanguages.EN:
                this.labels = Constants.cocoLabelEN;
                break;
            case LabelLanguages.JP:
                this.labels = Constants.cocoLabelJP;
                break;
            default:
                this.labels = Constants.cocoLabelEN;
                break;
        }
        this.classLength = this.labels.Length;

        this.model = ModelLoader.Load(this.modelFile);
        this.worker = GraphicsWorker.GetWorker(this.model);
        Debug.Log($"Initialized model and labels: {this.classLength} classes");
    }


    public IEnumerator Detect(Color32[] picture, Action<IList<ItemDetected>> callback)
    {
        // Coroutineで物体認識を実行
        Debug.Log("Run detection");
        using (var tensor = TransformInput(picture, this.IMAGE_SIZE, this.IMAGE_SIZE))
        {
            var inputs = new Dictionary<string, Tensor>();
            inputs.Add(this.inputName, tensor);
            yield return StartCoroutine(this.worker.StartManualSchedule(inputs));
            var outputL = this.worker.PeekOutput(this.outputNameL);
            var outputM = this.worker.PeekOutput(this.outputNameM);
            List<ItemDetected> results = ParseOutputs(outputL, outputM, this.paramsL, this.paramsM);
            Debug.Log($"yielded {results.Count()} results");
            callback(results);
        }
    }


    public static Tensor TransformInput(Color32[] pic, int width, int height)
    {
        float[] floatValues = new float[width * height * 3];

        for (int i = 0; i < pic.Length; ++i)
        {
            Color32 color = pic[i];

            floatValues[i * 3 + 0] = (color.r - IMAGE_MEAN) / IMAGE_STD;
            floatValues[i * 3 + 1] = (color.g - IMAGE_MEAN) / IMAGE_STD;
            floatValues[i * 3 + 2] = (color.b - IMAGE_MEAN) / IMAGE_STD;
        }

        return new Tensor(1, height, width, 3, floatValues);
    }

    private List<ItemDetected> ParseOutputs(Tensor yoloModelOutputL, Tensor yoloModelOutputM, Parameters parametersL, Parameters parametersM)
    {
        var itemsInCenter = new List<ItemDetected>();

        for (var box = 0; box < BOXES_PER_CELL; box++)
        {
            for (int cy = 0; cy < parametersL.COL_COUNT; cy++)
            {
                for (var cx = 0; cx < parametersL.ROW_COUNT; cx++)
                {
                    var result = Parse(cx, cy, box, yoloModelOutputL);
                    if (result != null)
                    {
                        itemsInCenter.Add(result);
                    }
                }
            }

            for (int cy = 0; cy < parametersM.COL_COUNT; cy++)
            {
                for (var cx = 0; cx < parametersM.ROW_COUNT; cx++)
                {
                    var result = Parse(cx, cy, box, yoloModelOutputM);
                    if (result != null)
                    {
                        itemsInCenter.Add(result);
                    }
                }
            }
        }

        return itemsInCenter;
    }

    private ItemDetected Parse(int cx, int cy, int box, Tensor yoloModelOutput)
    {
        var channel = (box * (this.classLength + BOX_INFO_FEATURE_COUNT));
        float confidence = GetConfidence(yoloModelOutput, cx, cy, channel);
        if (confidence < this.minConfidence)
        {
            return null;
        }

        float[] predictedClasses = ExtractClasses(yoloModelOutput, cx, cy, channel);
        var (topResultIndex, topResultScore) = GetTopResult(predictedClasses);
        var topScore = topResultScore * confidence;
        if (topScore < this.minConfidence)
        {
            return null;
        }

        var itemInCenter = new ItemDetected
        {
            PredictedItem = new Prediction
            {
                Label = labels[topResultIndex],
                Confidence = topScore,
            }
        };

        return itemInCenter;
    }


    private float Sigmoid(float value)
    {
        var k = (float)Math.Exp(value);
        return k / (1.0f + k);
    }


    private float[] Softmax(float[] values)
    {
        var maxVal = values.Max();
        var exp = values.Select(v => Math.Exp(v - maxVal));
        var sumExp = exp.Sum();
        return exp.Select(v => (float)(v / sumExp)).ToArray();
    }


    private float GetConfidence(Tensor modelOutput, int x, int y, int channel)
    {
        return Sigmoid(modelOutput[0, x, y, channel + 4]);
    }

    public float[] ExtractClasses(Tensor modelOutput, int x, int y, int channel)
    {
        float[] predictedClasses = new float[this.classLength];
        int predictedClassOffset = channel + BOX_INFO_FEATURE_COUNT;

        for (var predictedClass = 0; predictedClass < this.classLength; predictedClass++)
        {
            predictedClasses[predictedClass] = modelOutput[0, x, y, predictedClass + predictedClassOffset];
        }

        return Softmax(predictedClasses);
    }


    private ValueTuple<int, float> GetTopResult(float[] predictedClasses)
    {
        return predictedClasses
            .Select((predictedClass, index) => (Index: index, Value: predictedClass))
            .OrderByDescending(result => result.Value)
            .First();
    }


    private List<ValueTuple<int, float>> GetOrderedResult(float[] predictedClasses)
    {
        return predictedClasses
            .Select((predictedClass, index) => (Index: index, Value: predictedClass))
            .OrderByDescending(result => result.Value)
            .ToList();
    }
}

物体認識では TransformInput で画像をテンソルに変換、リサイズします。PeekOutでtiny YOLO v3から2種類の推論結果を取得します。ONNXモデルはUnity画面内で内部のレイヤーや入出力を見ることができるのですが、以下のとおりOutputが2個(いずれも物体のバウンディングボックスとラベル)出力されます。

07.png

認識した物体名はREST APIクライアントで関連語をリクエストします。REST APIクライアントにはUnityのUnityWebRequestを使っています。実装は以下のとおりです。
やはりCoroutineで起動する仕組みになっています。リクエストもレスポンスもJSONを使います。

// SimilarWordClient.cs

using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Networking;

[Serializable]
public class SecretJson
{
    public string url;
    public string secret;
}

[Serializable]
public class PostData
{
    public string word;
    public int topn;
}

[Serializable]
public class SimilarWord
{
    public string similar_word;
    public float similarity;
}

[Serializable]
public class SimilarWords
{
    public string word;
    public SimilarWord[] predictions;
}


public class SimilarWordClient : MonoBehaviour
{
    public TextAsset secretFile;
    private string secretString;
    private SecretJson secretJson;
    private string similarWordUrl;

    public void Start()
    {
        this.secretString = Resources.Load<TextAsset>(this.secretFile.name).ToString();
        this.secretJson = SecretJson.Deserialize(this.secretString);
        this.similarWordUrl = $"{this.secretJson.url}/similar-word/";
    }

    public IEnumerator SimilarWordAPI(string word, int topn, Action<SimilarWords> callback)
    {
        using (var request = new UnityWebRequest(this.similarWordUrl, "POST"))
        {
            Debug.Log($"Request {word} for {topn}");
            PostData postData = new PostData();
            postData.word = word;
            postData.topn = topn;
            string postJson = PostData.Serialize(postData);
            byte[] byteData = System.Text.Encoding.UTF8.GetBytes(postJson);

            request.uploadHandler = (UploadHandler)new UploadHandlerRaw(byteData);
            request.downloadHandler = (DownloadHandler)new DownloadHandlerBuffer();
            request.SetRequestHeader("accept", "application/json");
            request.SetRequestHeader("Content-Type", "application/json");
            request.SetRequestHeader("X-API-KEY", this.secretJson.secret);
            yield return request.SendWebRequest();

            if (request.result == UnityWebRequest.Result.ConnectionError || request.result == UnityWebRequest.Result.ProtocolError)
            {
                Debug.Log("Error POST request");
                Debug.Log(request.error);
            }
            else
            {
                Debug.Log($"status code [{request.responseCode}]");
                if (request.responseCode==200)
                {
                    Debug.Log("POST request succeeded");
                    string json = request.downloadHandler.text;
                    SimilarWords similarWords = SimilarWords.Deserialize(json);
                    callback(similarWords);
                }
                else
                {
                    Debug.Log("POST request failed");
                }
            }
        }
    }

    public void Update()
    {

    }
}

Kubernetesと関連語REST APIサーバ

続いてサーバサイドの実装です。サーバサイドのコードは以下にあります。
https://github.com/shibuiwilliam/ARWithWord/tree/main/backend

fastTextは既存のモデルを使いますが、1ファイルが圧縮状態で4GB、解凍して7GB、メモリに展開すると15GBが必要になるというデカいモデルです。金があれば64GBメモリのVMを使って起動しても良いですが、モデルの取得やメモリロードで15分以上かかって使い勝手がとても悪いです。なので単語表現の次元数を減らします。標準の次元数は300ですが、サイズと精度を測って100次元に減らします。これでモデルサイズは2GB、メモリに展開して4GBというサイズになりました。モデルロードも5分以内で済みます。
fastTextの次元削減方法は公式ドキュメントにあります。単語分散表現の次元数は各単語を表す数値の数になります。これが多いほうが単語を詳細に表現できますが、容量が大きく計算も重くなります。減らしても実用上問題なければ減らしても良いでしょう。

REST APIはFastAPIを使います。PythonのWeb APIはいろいろ試しましたが、FastAPIが一番安定して構造的に書けて使いやすいです。
REST APIの実装は以下のようになります。

import os
from typing import List, Tuple, Dict
import shutil
import gzip
from urllib.request import urlopen

import fasttext
import fasttext.util
from gensim.models.fasttext import load_facebook_model
from google.cloud import storage

from fastapi import APIRouter, HTTPException, Security
from fastapi.security.api_key import APIKeyHeader
from starlette import status

from logging import getLogger

from src.constants import LANGUAGE_ENUM
from src.data.schema import Prediction, Predictions, PredictionRequest


logger = getLogger(__name__)

# 関連語推論クラス
class SimilarWordPredictor(object):
    def __init__(
        self,
        bucket_name: str,
        model_directory: str = "/opt/",
        language: LANGUAGE_ENUM = LANGUAGE_ENUM.ENGLISH,
        model_dimension: int = 100,
        threshold: float = 0.6,
    ):
        self.bucket_name = bucket_name
        self.client = storage.Client()
        self.bucket = self.client.get_bucket(self.bucket_name)

        self.model_directory = model_directory
        self.language = language
        self.model_dimension = model_dimension
        if self.model_dimension not in [100, 300]:
            raise ValueError("model dimension must be one of 100 or 300")

        self.file_path = self.download_model(force_download=False)
        self.fasttext_predictor = load_facebook_model(self.file_path)
        logger.info(f"loaded {self.file_path}")

        self.threshold = threshold

        # local cache
        self.cache: Dict[str, Predictions] = {}

    def predict(
        self,
        word: str,
        topn: int = 20,
    ) -> Predictions:
        logger.info(f"predict {word}")

        key = f"{word}_{topn}"
        if key in self.cache.keys():
            return self.cache[key]

        results = self.fasttext_predictor.wv.most_similar(
            word,
            topn=topn,
        )

        _predictions = []
        for r in results:
            if r[1] < self.threshold:
                continue
            if repr(r[0]).startswith("'\\u"):
                continue
            _predictions.append(
                Prediction(
                    similar_word=r[0],
                    similarity=r[1],
                )
            )
        logger.info(f"{word} prediction: {_predictions}")
        predictions = Predictions(
            word=word,
            predictions=_predictions,
        )

        self.cache[key] = predictions

        return predictions

    def download_model(
        self,
        force_download: bool = False,
    ):
        file_name = f"cc.{self.language.value}.{self.model_dimension}.bin"
        file_path = os.path.join(self.model_directory, file_name)
        logger.info(f"retrieve model {file_name}")

        if os.path.exists(file_path):
            if not force_download:
                logger.info(f"model {file_name} exists")
                return file_path

        blob = self.bucket.blob(file_name)
        blob.download_to_filename(file_path)

        logger.info(f"retrieved model {file_name}")

        return file_path


similar_word_predictor = SimilarWordPredictor(
    bucket_name=os.environ["BUCKET_NAME"],
    model_directory=os.getenv("MODEL_DIRECTORY", "/opt/"),
    language=LANGUAGE_ENUM[os.getenv("LANGUAGE", "ENGLISH").upper()],
    model_dimension=int(os.getenv("MODEL_DIMENSION", 100)),
    threshold=float(os.getenv("THRESHOLD", 0.6)),
)


router = APIRouter()

api_key_header_auth = APIKeyHeader(
    name="X-API-KEY",
    auto_error=True,
)


def get_api_key(api_key_header: str = Security(api_key_header_auth)):
    if api_key_header != os.environ["PASSPHRASE"]:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API Key",
        )


# 関連語API
@router.post(
    "/",
    response_model=Predictions,
    dependencies=[Security(get_api_key)],
)
def predict(
    prediction_request: PredictionRequest,
):
    predictions = similar_word_predictor.predict(
        word=prediction_request.word,
        topn=prediction_request.topn,
    )
    return predictions


# サンプル用API
@router.get(
    "/sample/",
    response_model=Predictions,
    dependencies=[Security(get_api_key)],
)
def predict_sample():
    predictions = similar_word_predictor.predict(
        word="ネコ",
        topn=20,
    )
    return predictions

システムはGCPのGKE(Kubernetes)で動かす想定です。モデルファイルはGCPストレージに入れてあり、必要に応じてダウンロードします。
関連語は検索するごとに辞書に登録していき、2度目以降の検索は辞書からレスポンスすることで高速化を図ります。

これでDockerビルドしてKubernetesに乗せればAPIの完成です。

作ったもの

最初に書いたとおり、カメラに写っているものを認識して、その関連語をワードクラウドのように空間内に散りばめて表示します。

おわりに

ご覧のとおり、これでなにかの課題を解決するものではありません。
しかしARで現実空間にエフェクトを出すとき、空間内に存在する物体の名称とその意味を認識し、次のアクションを共起するような使い方ができると思い、その試作品として作ってみました。
AIとARは相性が良い割りには組み合わせたプロダクトを作っている例はあまり見ないですし、さらに自然言語処理を組み込んでいるのを見たことがないので、ひとまず作ってみたシステムになります。
次はImage to TextやGPT-3を組み合わせて空間の状況に説明を表示するARを作りたいです(時間があれば)。

まあまあ楽しかったです。緊急事態宣言でGW中引きこもっていた暇潰しになりました。

15
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
9