LoginSignup
1
1

More than 5 years have passed since last update.

今更ながらニューラルネットワーク(MNIST)

Posted at

 pythonなどに素晴らしいNNのフレームワークがあるのですが、イマイチ理解できないので、随分と前からC++ templateの勉強がてらNNをフルスクラッチしてます。CUDAの学習なども併せて行うと頭がパンクするのでC++AMPでGPUを利用しています。github

 C++ templateを使うことにしたので、NNモデルは以下の様に型として定義しています(C++ templateに苦戦)。各ノードの入出力に間違いがあればコンパイル時に分かるって寸法です。しかし、パラメータを色々試しながら最適なモデルを探すことが多いと思いますので実用性ないと思います。

TestMnist.cpp

    using InputShape = IOShapeDef<BATCHSIZE, MnistDataLoader::HEIGHT, MnistDataLoader::WIDTH, 1>;

    using MnistSeq = Sequence <
        InputShape,

        Convolution<32, Size2d<5>>,
        ConvolutionBias<>,
        Activation<Relu>,

        MaxPooling<Size2d<2>>,

        Convolution<64, Size2d<5>>,
        ConvolutionBias<>,
        Activation<Relu>,

        MaxPooling<Size2d<2>>,
        Flatten<>,

        Dropout<Float<5, 10>>,

        FullConnection<128>,
        FullConnectionBias<>,
        Activation<Relu>,

        FullConnection<10>,
        FullConnectionBias<>
    >;

    using ModelType = NeuralNetworkModel<BATCHSIZE, MnistSeq, SoftmaxCrossEntropyMulticlass, Adam>;
    auto spModel = ModelType::create(gpu);
    spModel->optimizer->learningRate = float_type(0.001);
    spModel->ready(Random, float_type(-0.1), float_type(0.2));

 FullConnectionやConvolutionalなどの各ノードの実装は以下の様にdetailsネームスペースで記述され、グローバル空間で利用し易い形(入力型(前のノードの出力型)が省略できる)にしたクラスがあります。

Neu.h
namespace details {
    template<typename InputShapeT, int TOutputChannels, typename KernelSizeT, typename 
    KernelStrideT = Stride2d<1>>
    struct Convolution : base::WeightedNode
    {
      //(実装略)
    };
}

template<int TOutputChannels, typename KernelSizeT, typename KernelStrideT = Stride2d<1>, int TTag = -1>
struct Convolution
{
    constexpr static const int Tag = TTag;

    template<typename InputShapeT>
    using Impl = details::Convolution<InputShapeT, TOutputChannels, KernelSizeT, KernelStrideT>;
};
template<int TOutputChannels, typename KernelSizeT, typename KernelStrideT = Stride2d<1>, int TTag = -1>
using Conv = Convolution<TOutputChannels, KernelSizeT, KernelStrideT, TTag>;

 一部のノード(BatchNormalize)などはちゃんと動いているのか怪しいので、書いたもののコメントアウトされています。

 仮想通貨マイニングにせよDNNにせよ良いGPUが欲しい今日この頃です。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1