自分用の備忘録です.
NN(ニューラルネットワーク)の中間層を特徴量抽出器として使いたいことがあります.
PyTorchでforward関数を書き換える
モデル本来の推論をしつつ中間層の出力を取得する方法(forward hookなど)もありますが,ONNXへ出力する際は中間層の出力をモデルの出力としてreturnする必要があるため,ここではforward関数を書き換えて中間層の出力を直接returnします.
python
import torch
# Pretrained Wide ResNet-50-2 from the torchvision v0.10.0 hub repo.
# Module.eval() returns the module itself, so the chained call leaves `model`
# bound to the same network, now in evaluation mode (BatchNorm uses running stats).
model: torch.nn.Module = torch.hub.load(
    'pytorch/vision:v0.10.0', 'wide_resnet50_2', pretrained=True
).eval()
def _forward_impl(self, x):
# See note [TorchScript super()]
# xs = []
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x) # <- 欲しい中間層でreturn
# x = self.layer2(x)
# x = self.layer3(x)
# x = self.layer4(x)
# x = self.avgpool(x)
# x = torch.flatten(x, 1)
# x = self.fc(x)
return x
# %% Replace the model's forward computation with the truncated version.
# function.__get__(instance, owner) produces a method bound to `model`, so
# `self` inside _forward_impl refers to this network. Storing it as an
# instance attribute shadows the class-level _forward_impl.
model._forward_impl = _forward_impl.__get__(model, type(model))
ONNX形式でモデルを出力
python
# Dummy input used only to trace the graph: one 3x224x224 image,
# the same size the C# side resizes its input to.
x = torch.randn(1, 3, 224, 224, requires_grad=True)

# Export settings; 'input'/'output' become the graph's tensor names, and the
# inference code must use the same input name.
onnx_export_options = dict(
    export_params=True,
    opset_version=10,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
)
torch.onnx.export(model, x, "WideResNet50Layer1.onnx", **onnx_export_options)
C#で推論し,出力をList型で取得
NuGetでONNX Runtime(Microsoft.ML.OnnxRuntime)とImageSharp(SixLabors.ImageSharp)をインストールします.
C#
using System.IO;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
//追加:
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.Processing;
using SixLabors.ImageSharp.PixelFormats;
using SixLabors.ImageSharp.Formats;
namespace ResNetIntermediateLayerWithONNX
{
    class Program
    {
        // Must match the file name passed to torch.onnx.export on the Python side.
        const string ModelFilePath = "WideResNet50Layer1.onnx";

        // Side length (px) of the square image fed to the network; matches the
        // 1x3x224x224 dummy input used when exporting the ONNX model.
        const int InputSize = 224;

        // Per-channel (R, G, B) normalization constants applied to pixel values scaled to [0, 1].
        static readonly float[] Mean = { 0.485f, 0.456f, 0.406f };
        static readonly float[] StdDev = { 0.229f, 0.224f, 0.225f };

        /// <summary>
        /// Loads an image, resizes/crops it to InputSize x InputSize, runs the exported
        /// intermediate-layer model on it and prints the number of output values.
        /// </summary>
        static void Main(string[] args)
        {
            string imageFilePath = @"入力画像パス";

            Tensor<float> input;
            // Release the decoded image as soon as its pixels are copied into the tensor.
            using (Image<Rgb24> image = Image.Load<Rgb24>(imageFilePath))
            {
                // ResizeMode.Crop scales the image to cover the target size and crops the
                // overflow, so the result is exactly InputSize x InputSize.
                image.Mutate(ctx => ctx.Resize(new ResizeOptions
                {
                    Size = new Size(InputSize, InputSize),
                    Mode = ResizeMode.Crop
                }));
                input = ToNormalizedTensor(image);
            }

            var inputs = new List<NamedOnnxValue>
            {
                // "input" is the name given via input_names when exporting the ONNX model.
                NamedOnnxValue.CreateFromTensor("input", input)
            };

            // The session and the result collection are both IDisposable; dispose them deterministically.
            using (var session = new InferenceSession(ModelFilePath))
            using (IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = session.Run(inputs))
            {
                // Flattened values of the first graph output. For wide_resnet50_2 layer1 this is
                // expected to be 1*256*56*56 values (NOTE(review): confirm against the exported model).
                List<float> outputAsList = results.First().AsEnumerable<float>().ToList();
                Console.WriteLine(outputAsList.Count);
            }
        }

        /// <summary>
        /// Copies an RGB image into a 1 x 3 x H x W float tensor (channel order R, G, B).
        /// Each channel is scaled to [0, 1] and then normalized with Mean and StdDev.
        /// </summary>
        static Tensor<float> ToNormalizedTensor(Image<Rgb24> image)
        {
            var tensor = new DenseTensor<float>(new[] { 1, 3, image.Height, image.Width });
            image.ProcessPixelRows(accessor =>
            {
                for (int y = 0; y < accessor.Height; y++)
                {
                    Span<Rgb24> row = accessor.GetRowSpan(y);
                    for (int x = 0; x < accessor.Width; x++)
                    {
                        Rgb24 px = row[x];
                        tensor[0, 0, y, x] = (px.R / 255f - Mean[0]) / StdDev[0];
                        tensor[0, 1, y, x] = (px.G / 255f - Mean[1]) / StdDev[1];
                        tensor[0, 2, y, x] = (px.B / 255f - Mean[2]) / StdDev[2];
                    }
                }
            });
            return tensor;
        }
    }
}
参考