LoginSignup
0
2

More than 3 years have passed since last update.

【PyTorch】手話の数字をPyTorchで転移学習してtfliteに変換してみた ※後編

Last updated at Posted at 2020-12-19

手話の数字(1~10)をPyTorchで転移学習してtfliteに変換してみた ※後編

前回作成したモデルをpytorch →onnx →tensorflow →tflite と変換。pytorchとtfliteのモデルで画像に対する出力の動作チェックを行いました。環境や参考にさせていただいたページは前回記事に乗せております。

コードと結果サマリ

前回はpytorchで転移学習、動作チェックまで、以下続きです

#pytorchからonnxへモデルを変換

# Create random input
input_data = torch.randn(1,3,224,224, device="cuda:0")
input_data.cuda()

# Create network
model = model_ft.cuda()

# Forward Pass
output = model(input_data)

# Export model to onnx
filename_onnx = "jsl_one_to_tenl.onnx"
torch.onnx.export(model, input_data, filename_onnx)

# onnxからtensorflowへモデルを変換
filename_tf = "jsl_one_to_ten_tensorflow"
onnx_model = onnx.load(filename_onnx)
tf_rep = prepare(onnx_model)
tf_rep.export_graph(filename_tf)
#tensorflowからtfliteへモデルを変換
converter = tf.lite.TFLiteConverter.from_saved_model(filename_tf)
tflite_model = converter.convert()
open("jsl_one_to_ten.tflite", "wb").write(tflite_model)

※tfliteのモデルとしては43Mくらいのものが保存されました

pytorchモデルのロードと実行

python版のmediapipeが(未確認ですが)CPUとのことなので、Colab ノートブックのランタイムをCPUに変更してモデルのロードを実行
 

from google.colab import drive
drive.mount('/content/drive')

from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os 
import copy
from PIL import Image

plt.ion()
os.chdir('/content/drive/My Drive/任意のパスに書き換えてください')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

出力:cpu 

saved_model = 'model/jsl_one_to_ten_cpu.pth'
model_ft = models.resnet18(pretrained=False)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 10)
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
lr_scheduler = torch.optim.lr_scheduler
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_ft.load_state_dict(torch.load(saved_model, map_location=torch.device(device)))

classes = ['1', '10', '2', '3', '4', '5', '6', '7', '8', '9']

※pytorchのモデルをstate_dict(推奨)で保存した場合、学習時と同じようにモデルを定義してからstate_dictを読み込ませます


test_img = "images/val/4/20201107044432_421.jpeg"
image = Image.open(test_img)
image = image.resize((224, 224))

trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),])
#trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

input_tensor = trans(image)
input_tensor = input_tensor.unsqueeze(dim=0)
print(input_tensor)

tensor([[[[-0.8849, -1.0048, -1.1418, ..., 1.0331, 1.6324, 1.5982],
※テンソルの出力は省略、手形4を集めたフォルダから画像を1枚読み込んでいます

model_ft.eval()
outputs = model_ft(input_tensor)
preds = torch.max(outputs, 1)[1]
#preds = outputs.max(0, keepdim=True)[1] 
print(outputs)
print("jsl number is " + classes[preds.item()])

tensor([[-2.4689, -2.7054, -3.1834, 0.1467, 14.4020, -0.0529, -1.4861, -1.0645,
-2.0314, -2.6948]], grad_fn=)
jsl number is 4 ※モデルの判定が成功しました

tfliteモデルのロードと実行

mediapipeで使用されているモデルがtfliteのため、tfliteの実行方法を確認

from google.colab import drive
drive.mount('/content/drive')

import os
import sys
import numpy as np
import tensorflow as tf
from PIL import Image
os.chdir('/content/drive/My Drive/任意のパスに書き換えてください')
# Load model
interpreter = tf.lite.Interpreter(model_path="jsl_one_to_ten.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# 入出力フォーマットを確認
print('入出力フォーマットを確認')
print(input_details)
print(output_details)

# 入力のshape取得
input_shape = input_details[0]['shape']
print('shape確認')
print(input_shape)

#ラベルの宣言
classes = ['1', '10', '2', '3', '4', '5', '6', '7', '8', '9']

入出力フォーマットの確認と入力shapeの確認

入出力フォーマットを確認
[{'name': 'serving_default_input.1:0', 'index': 0, 'shape': array([ 1, 3, 224, 224], dtype=int32), 'shape_signature': array([ 1, 3, 224, 224], dtype=int32), 'dtype': , 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
[{'name': 'PartitionedCall:0', 'index': 156, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([ 1, 10], dtype=int32), 'dtype': , 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
shape確認
[ 1 3 224 224]

# テスト画像
import torch
import torchvision
test_img = "images/val/4/20201107044432_421.jpeg"

image = Image.open(test_img)
image = image.resize((224, 224))

trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),])

input_tensor = trans(image)
input_tensor = input_tensor[np.newaxis,:,:,:]
print(input_tensor[0][0])
print(input_tensor.shape)

tensor([[-0.8849, -1.0048, -1.1418, ..., 1.0331, 1.6324, 1.5982],
torch.Size([1, 3, 224, 224])
※テンソルの値と次元を確認しております

#input_shape = input_details[0]['shape']
#reshaped_img = img_data.reshape(input_shape)

print('入力データ')
print(input_tensor)
interpreter.set_tensor(input_details[0]['index'], input_tensor)

# 実行
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

print('出力データ')
output_data = [f'{output_data:.2f}' for output_data in output_data[0]]
print(output_data)
print("jsl number is " + classes[np.argmax(output_data)])

入力データ
tensor([[[[-0.8849, -1.0048, -1.1418, ..., 1.0331, 1.6324, 1.5982],
※テンソルの出力は省略、手形4の画像のテンソルです

出力データ
['-2.47', '-2.71', '-3.18', '0.15', '14.40', '-0.05', '-1.49', '-1.06', '-2.03', '-2.69']
jsl number is 4 ※モデルでの判定成功

所感

自分でモデルを作れるようになりましたのでmediapipeに組み込んで動きを見てみたいと思います。pythonで動くようになったようですのでmediapipe+自作モデル+PILで日本語表示などを実験してみたいと思います。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2