はじめに
kaggle等でMLPの実装を探すと,kerasのSequentialにDense層やDropout層をaddしまくるような実装をよく見かけます.
torch.nn.Moduleを用いた実装では,Linearと活性化関数を2,3個繰り返す実装が多いように思います.
これらはMLPの概念に忠実な実装です.
しかしMLPの中間層の数を簡単に自由に変更できれば,より柔軟な実装ができるかもしれません.
ここで紹介するアイデアはシンプルで誰でも思いつきそうなものですが,レイヤ数が固定の実装ばかり見ていると考え方が固くなってしまい,筆者にとっては新鮮なものであったので記事にしました.
Keras (Sequential)
中間層のサイズを指定して,for文でn_layer-2個のDense層とDropout層を追加すれば,可変個数の層を追加できる.
from keras.models import Sequential
from keras.layers import Dense, Dropout
model = Sequential()
model.add(Dense(hidden_size, activation='relu', input_shape=input_shape))
for _ in range(n_layer-2):
model.add(Dense(hidden_size, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer=Adam(),
metrics=['accuracy'])
model.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
verbose=2,
validation_data=(x_val, y_val))
score = model.evaluate(x_test, y_test, verbose=0)
torch.nn.Module
LinearとBatchNorm1dをforで交互に追加し,ModuleList()に追加する.
ModuleListなんてあったんですね……
forwardでも同じ順番で,活性化関数とdropoutを加えてfor文で回す必要があり,ちょっとめんどくさい.
import torch
import torch.nn.functional as F
class MLP(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
super(MLP, self).__init__()
# 1st layer
self.lins = torch.nn.ModuleList()
self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
self.bns = torch.nn.ModuleList()
self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
for _ in range(num_layers - 2):
self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
# last layer
self.lins.append(torch.nn.Linear(hidden_channels, out_channels))
self.dropout = dropout
def reset_parameters(self):
for lin in self.lins:
lin.reset_parameters()
for bn in self.bns:
bn.reset_parameters()
def forward(self, x):
for i in range(num_layers - 1):
x = self.lins[i](x)
x = self.bns[i](x)
x = F.relu(x, inplace=True)
x = F.dropout(x, p=self.dropout)
x = self.lins[-1](x)
return F.log_softmax(x, dim=-1)
さいごに
参考