ObjectDetectionをアプリに実装するときに大体のYOLOモデルはGPLv3ライセンスで配布しており、アプリは配布したいけどコードの公開はしたくないケースが業務ではままあります。
そろそろ業務アプリで使う必要がありそうなので、さてどうしようかと思っていたところタイムリーにApache2ライセンスのYoloXにCoreML変換のPRが寄稿されました。
試しに動かしてみたところ、エラーになります🤔
PRにコメントを投稿してみましたが修正いただけないこともあり、気になる点が多々ありましたので必要な部分以外は書き直してみました。
最終的に下記のコード以外は大体書き直しました。
class YOLOXDetectModel(nn.Module):
"""Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
def __init__(self, model, im, num_of_class):
"""Initialize the YOLOXDetectModel class with a YOLO model and example image."""
super().__init__()
_, _, h, w = im.shape
self.model = model
self.nc = num_of_class # number of classes
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h])
def forward(self, x):
"""Normalize predictions of object detection model with input size-dependent factors."""
out_pred = self.model(x)
xywh = out_pred[:, :, :4][0]
objectness = out_pred[0][:, 4:5]
class_conf = out_pred[0][:, 5:5 + self.nc]
class_scores = objectness * class_conf
return class_scores, xywh * self.normalize
その他の、YoloXを扱う上での注意点としてはモデルの入力に正規化が必要ない点です。
ですので、CoreMLに変換する際、scaleを設定するとドハマリするのでお気をつけてください。
出力したモデル
コード