はじめに
ラズパイカメラで撮った画像を画像分類して、結果の信頼度が微妙だったときにS3バケットに自動でアップロードして再学習用の画像を簡単に増やす。今回はそんな仕組みの雛形をGreenGrassとLambdaで作ります。
なおこの記事は以下の記事の続きです。
今回作成するLambdaの全体像
Lambdaの構成図
camera.pyとinferene.pyにコードを足していく
camera.py
はtest.jpg
としてカメラの画像を保存する行camera.capture('/home/pi/test.jpg')
だけ足します。
#
# Copyright 2010-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# This class is a camera that uses picamera to take a photo and DLC compiled
# Resnet-50 model to perform image classification, identifying the objects
# shown in the photo.
#
from io import BytesIO
import picamera
import time
import datetime
import boto3
class Camera(object):
r"""
Camera that captures an image for performing inference
with DLC compiled model.
"""
def capture_image(self):
r"""
Capture image with PiCamera.
"""
camera = picamera.PiCamera()
imageData = BytesIO()
try:
camera.resolution = (224, 224)
print("Taking a photo from your camera...")
camera.start_preview()
time.sleep(2)
camera.capture(imageData, format = "jpeg", resize = (224, 224))
camera.stop_preview()
imageData.seek(0)
# とりあえず/home/piに'test.jpg'という形で画像を保存
camera.capture('/home/pi/test.jpg')
return imageData
finally:
camera.close()
raise RuntimeError("There is a problem with your camera.")
inference.pyは複数のモジュールのインポートの行と条件付け、S3アップロードの行を足してあります。
#
# Copyright 2010-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Greengrass lambda function to perform Image Classification with example model
# Resnet-50 that was compiled by DLC.
#
#
import logging
import os
from dlr import DLRModel
from PIL import Image
import numpy as np
import greengrasssdk
import camera
import utils
# 追加したコードに必要なモジュール
import time
import datetime
import boto3
# Create MQTT client
mqtt_client = greengrasssdk.client('iot-data')
# Initialize logger
customer_logger = logging.getLogger(__name__)
# LambdaをGGコンテナ内で動かすとき
# model_resource_path = os.environ.get('MODEL_PATH', '/trained_models')
# LambdaをGGコンテナ無しで動かすとき
model_resource_path = os.getenv("AWS_GG_RESOURCE_PREFIX") + "/trained_models"
dlr_model = DLRModel(model_resource_path, 'cpu')
# Read synset file
synset_path = os.path.join(model_resource_path, 'synset.txt')
with open(synset_path, 'r') as f:
synset = eval(f.read())
def predict(image_data):
r"""
Predict image with DLR. The result will be published
to MQTT topic '/resnet-50/predictions'.
:param image: numpy array of the Image inference with.
"""
flattened_data = image_data.astype(np.float32).flatten()
prediction_scores = dlr_model.run({'data' : flattened_data})
max_score_id = np.argmax(prediction_scores)
max_score = np.max(prediction_scores)
# Prepare result
predicted_class = synset[max_score_id]
result = 'Inference result: "{}" with probability {}.'.format(predicted_class, max_score)
# Send result
send_mqtt_message(
'Prediction Result: {}'.format(result))
# ここで推論結果に対して条件付けをして、条件に引っ掛かったらS3バケットにアップロードする。ここでは信頼度が80%以下の時アップロード
if max_score < 0.8:
s3 = boto3.resource('s3')
# あげる対象となるS3バケット
bucketName = 'minagawabucket'
send_mqtt_message("The probability isn't high enough, sending data to {}".format(bucketName))
# アップロードする時のタイムスタンプをファイルの末尾につけたい
date = datetime.datetime.now() # date to be like datetime.datetime(2021, 6, 11, 17, 7, 8, 805672)
timestamp = date.strftime("%Y-%m-%d %H:%M:%S") # timestamp to be like '2021-06-11 17:07:08'
# S3バケットのPicturesForRetraining配下に推論結果のクラス(例えば猫や車)のフォルダを作ってそこに格納する
s3Key = 'PicturesForRetraining/' + predicted_class + '/' + timestamp + '.jpg'
# camera.pyで保存したtest.jpgをアップロードする
data = open('/home/pi/test.jpg', mode='rb')
s3.Bucket(bucketName).put_object(Key = s3Key, Body = data)
def predict_from_cam():
r"""
Predict with the photo taken from your pi camera.
"""
send_mqtt_message("Taking a photo...")
my_camera = camera.Camera()
image = Image.open(my_camera.capture_image())
image_data = utils.transform_image(image)
send_mqtt_message("Start predicting...")
predict(image_data)
def send_mqtt_message(message):
r"""
Publish message to the MQTT topic:
'/resnet-50/predictions'.
:param message: message to publish
"""
mqtt_client.publish(topic='/resnet-50/predictions',
payload=message)
# The lambda to be invoked in Greengrass
def handler(event, context):
try:
predict_from_cam()
except Exception as e:
customer_logger.exception(e)
send_mqtt_message(
'Exception occurred during prediction. Please check logs for troubleshooting: /greengrass/ggc/var/log.')
テスト
テストしてみます。#
にサブスクライブしてtest
になんでも良いのでメッセージをパブリッシュ。
7%
の確率でマッチ棒
。自動的に指定したバケットの再学習用の画像フォルダに送ります。
無事送れていますね。
これでSageMaker GroundTruth等でのアノテーション作業や再学習もやりやすくなりました。
あとがき
・一度SDカードにファイルを書く処理を挟んでいるのはあまりスマートではない気がします。カメラで撮影した画像データをそのまま推論を実行する関数に渡してそのままS3にアップロードするようにしたい。SDカードの寿命のためになるのとそれだけ無駄な処理が減るので。
・会社から自宅のラズパイにMQTTでLambdaを起動してカメラ撮影をしたのですが、部屋が暗くて画像が真っ黒でした。現在使っているMLモデルが真っ黒だとマッチ棒と推論するらしく、結果は全てマッチ棒というちょっと微妙な結果になりました。あ、そうだSwitchbotで電気つければいいじゃんと思ったのですが、物理スイッチの方がOFFになっていたらしく反応なし。。Switchbotあるあるですね。