Getting intermediate-layer output tensors of a neural network model in C#

A memo for my own future reference.
The intermediate layers of a neural network are sometimes used as a feature extractor.

Rewriting the forward function in PyTorch

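Getting an intermediate layer's output while still running the model's normal inference is (probably) not straightforward, so I rewrite the forward function and force it to return the intermediate layer's output instead.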

```python
import types

import torch

model: torch.nn.Module = torch.hub.load('pytorch/vision:v0.10.0', 'wide_resnet50_2', pretrained=True)
model.eval()


def _forward_impl(self, x):
    # Truncated copy of torchvision's ResNet._forward_impl
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)  # <- return at the intermediate layer you want
    # x = self.layer2(x)
    # x = self.layer3(x)
    # x = self.layer4(x)

    # x = self.avgpool(x)
    # x = torch.flatten(x, 1)
    # x = self.fc(x)

    return x


# %% Replace the model's forward implementation (bound to this instance only)
model._forward_impl = types.MethodType(_forward_impl, model)
```
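Assigning to `model._forward_impl` works because torchvision's `ResNet.forward` simply calls `self._forward_impl(x)`, and an attribute stored on the instance is found before the function defined on the class. Binding with `types.MethodType` (which is what `type(model._forward_impl)` evaluates to) makes `self` refer to `model`. The stdlib-only sketch below reproduces the mechanism on a hypothetical toy class (`ToyNet` and its stage methods are made up for illustration, with integers standing in for tensors), so it runs without PyTorch.

```python
import types


class ToyNet:
    """Hypothetical stand-in for a ResNet-style model; ints replace tensors."""

    def stem(self, x):
        return x + 1

    def layer1(self, x):
        return x * 10

    def layer2(self, x):
        return x - 3

    def _forward_impl(self, x):
        # Full pipeline: stem -> layer1 -> layer2
        return self.layer2(self.layer1(self.stem(x)))

    def forward(self, x):
        # Like torchvision's ResNet, forward only delegates to _forward_impl
        return self._forward_impl(x)


def truncated_forward_impl(self, x):
    # Stop after layer1 and return its output
    x = self.stem(x)
    x = self.layer1(x)
    return x


model = ToyNet()
other = ToyNet()
print(type(model._forward_impl) is types.MethodType)  # True
print(model.forward(2))  # (2 + 1) * 10 - 3 = 27

# Bind the function to this instance only: the instance attribute
# shadows ToyNet._forward_impl for `model` but not for `other`.
model._forward_impl = types.MethodType(truncated_forward_impl, model)
print(model.forward(2))  # (2 + 1) * 10 = 30
print(other.forward(2))  # still 27
```

Only the patched instance changes, so any other model object in the same process keeps its full forward pass.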

Exporting the model to ONNX

```python
# Dummy input tensor (batch 1, RGB, 224x224) used to trace the model
x = torch.randn(1, 3, 224, 224, requires_grad=True)

# Export the model to ONNX
torch.onnx.export(
    model, x, "WideResNet50Layer1.onnx", export_params=True,
    opset_version=10, do_constant_folding=True,
    input_names=['input'], output_names=['output']
)
```
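With the 1×3×224×224 dummy input, the shape of the exported `output` can be estimated with the usual output-size formula for convolution and pooling, floor((n + 2p - k) / s) + 1. The sketch below assumes torchvision's ResNet stem (7×7 conv with stride 2 and padding 3, then 3×3 max pooling with stride 2 and padding 1) and that `layer1` of `wide_resnet50_2` keeps the spatial size and outputs 256 channels (64 × Bottleneck expansion 4; Wide ResNet widens only the bottleneck's inner layers). Under those assumptions the output is 1×256×56×56, i.e. 802,816 floats, which is the count the C# program should print. `unflatten_index` is a hypothetical helper showing how a position in the flattened list maps back to (n, c, y, x), assuming row-major (C-order) layout.

```python
import math


def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a conv/pool layer along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


def layer1_output_shape(batch: int, height: int, width: int) -> tuple:
    # Assumed ResNet stem: conv1 (7x7, stride 2, pad 3) then maxpool (3x3, stride 2, pad 1)
    h = conv_out(conv_out(height, 7, 2, 3), 3, 2, 1)
    w = conv_out(conv_out(width, 7, 2, 3), 3, 2, 1)
    # Assumed: layer1 keeps the spatial size and outputs 64 * 4 = 256 channels
    return (batch, 256, h, w)


def unflatten_index(flat: int, shape: tuple) -> tuple:
    """Map a row-major (C-order) flat index back to per-axis coordinates."""
    coords = []
    for dim in reversed(shape):
        coords.append(flat % dim)
        flat //= dim
    return tuple(reversed(coords))


shape = layer1_output_shape(1, 224, 224)
count = math.prod(shape)
print(shape)  # (1, 256, 56, 56)
print(count)  # 802816
print(unflatten_index(56 * 56 + 57, shape))  # (0, 1, 1, 1)
```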

Running inference in C# and getting the output as a List

Install ONNX Runtime (Microsoft.ML.OnnxRuntime) and ImageSharp (SixLabors.ImageSharp) from NuGet.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Added:
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.PixelFormats;
using SixLabors.ImageSharp.Processing;

namespace ResNetIntermediateLayerWithONNX
{
    class Program
    {
        static void Main(string[] args)
        {
            string imageFilePath = @"path\to\input_image";
            using Image<Rgb24> image = Image.Load<Rgb24>(imageFilePath);

            // Resize to cover 224x224 and crop the overflow (Mutate modifies the image in place)
            image.Mutate(x =>
            {
                x.Resize(new ResizeOptions
                {
                    Size = new Size(224, 224),
                    Mode = ResizeMode.Crop
                });
            });

            // Normalize with the ImageNet mean/std and write into an NCHW tensor
            Tensor<float> input = new DenseTensor<float>(new[] { 1, 3, 224, 224 });
            var mean = new[] { 0.485f, 0.456f, 0.406f };
            var stddev = new[] { 0.229f, 0.224f, 0.225f };
            image.ProcessPixelRows(accessor =>
            {
                for (int y = 0; y < accessor.Height; y++)
                {
                    Span<Rgb24> pixelSpan = accessor.GetRowSpan(y);
                    for (int x = 0; x < accessor.Width; x++)
                    {
                        input[0, 0, y, x] = ((pixelSpan[x].R / 255f) - mean[0]) / stddev[0];
                        input[0, 1, y, x] = ((pixelSpan[x].G / 255f) - mean[1]) / stddev[1];
                        input[0, 2, y, x] = ((pixelSpan[x].B / 255f) - mean[2]) / stddev[2];
                    }
                }
            });

            var inputs = new List<NamedOnnxValue>
            {
                // "input" must match the name given in input_names when the model was exported to ONNX
                NamedOnnxValue.CreateFromTensor("input", input)
            };

            // Must match the file name written by torch.onnx.export
            string modelFilePath = "WideResNet50Layer1.onnx";

            using var session = new InferenceSession(modelFilePath);
            using IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = session.Run(inputs);

            // The intermediate tensor, flattened into a single sequence of floats
            IEnumerable<float> output = results.First().AsEnumerable<float>();

            List<float> outputAsList = output.ToList();
            Console.WriteLine(outputAsList.Count);
        }
    }
}
```
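The C# preprocessing resizes the image so it covers 224×224 and crops the overflow (`ResizeMode.Crop`, which with ImageSharp's default center anchor behaves roughly like a center crop), then scales each channel to [0, 1] and normalizes it with the ImageNet mean and standard deviation before writing it into a 1×3×224×224 NCHW tensor. The pure-Python sketch below mirrors those two steps so the arithmetic can be checked without ImageSharp; `cover_crop_box` and `to_nchw` are hypothetical helpers, and ImageSharp's own rounding and resampling may differ.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def cover_crop_box(src_w: int, src_h: int, dst_w: int, dst_h: int):
    """Scale so the image covers dst_w x dst_h, then center-crop.

    Returns (scaled_w, scaled_h, left, top): the size after scaling and the
    top-left corner of the dst_w x dst_h window kept from the scaled image.
    """
    scale = max(dst_w / src_w, dst_h / src_h)
    scaled_w = max(dst_w, round(src_w * scale))
    scaled_h = max(dst_h, round(src_h * scale))
    left = (scaled_w - dst_w) // 2
    top = (scaled_h - dst_h) // 2
    return scaled_w, scaled_h, left, top


def to_nchw(pixels, width: int, height: int):
    """Normalize row-major RGB pixels into a flat 1x3xHxW float list.

    `pixels[y * width + x]` is an (r, g, b) tuple of 0-255 ints, playing the
    role of one Rgb24 value in the C# loop.
    """
    out = [0.0] * (3 * height * width)
    for y in range(height):
        for x in range(width):
            rgb = pixels[y * width + x]
            for c in range(3):
                value = (rgb[c] / 255.0 - MEAN[c]) / STD[c]
                # Same element order as input[0, c, y, x] in the C# code
                out[(c * height + y) * width + x] = value
    return out


print(cover_crop_box(640, 480, 224, 224))  # (299, 224, 37, 0)
tensor = to_nchw([(255, 255, 255)] * 4, width=2, height=2)
print(len(tensor))  # 12
```

In `to_nchw` all red values come first, then green, then blue, which is the same plane order that the C# indexer `input[0, c, y, x]` produces.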
