はじめに
機械学習入門の定番であるMNISTを題材にn行プログラミング(ショートコーディング)をやってみました。
いろいろと頑張った結果、MNIST精度95%のモデル+アプリを含めて10行(改行なしなら780バイト)のHTMLアプリができました。
成果
↓の埋め込みアプリで遊べます(スマートフォンは未対応です)。文字を書き直すには右下の「rerun」をクリックしてください。
See the Pen handwritten digit recognizer with 10 lines (MNIST95%) by Satoshi Tanaka (@stnk20) on CodePen.
コードも掲載しておきます:
<body onload='C.width=C.height=S=84;T=C.getContext("2d");C.onmousemove=e=>{X=e.
offsetX/3;Y=e.offsetY/3;e.buttons&&T.lineTo(X,Y)+T.stroke();T.moveTo(X,Y)};W=
o=>"W~$w =~*1jBc @@kKnts[sxAO5vN c~t`eNXlk LsJZR|>I~a~~-~~^`NV^Y!Em~ [7_VHaXe@\
nlTy&Ve\\8`c/RWScZcJRp0PTrkYPU0E[P^^BGHS\\#IOHN\\PKb3AXYG[mXNCW<YZLFPFIOiwEM]Z\
X]`EJcONO=B^ZX \\@G\\O^<SN~,Q[QVW[UaUyXLDSYLPGpo\\RYS+".charCodeAt(o);N=(z,n)=>
2**((W(h+n)/4-36.4+z)*(W(h+n+20)/366-1/57))**2;T.scale(3,3);M=Math.max;'onclick
='y=Array(10).fill(0);for(i=26;i--;)for(j=26;j--;)for(h=10;h--;y=y.map((t,d)=>t
+M(0,s+W(h%5)-119)/N(j,15)/N(i,25)*(W(d*10+h+100)-81.4)+W(d+5)/20))for(k=9,s=0;
k--;)s+=T.getImageData((j+k%3)*3,i*3+k,1,1).data[3]*(W(h%5*9+k+55)/116-.73);P.
innerHTML=y.indexOf(M(...y))'><p id=P>*</p><canvas id=C style=border:solid>
モデル
とにかくパラメータ数を小さくすることが重要です。
例えば線形分類は7850、画像を1/2リサイズしたとしても1970のパラメータが必要になるのですが、これでは多すぎます。
なんとかして数百程度まで落とす必要があります。
hand-craftedな特徴量を使うことも考えられますが、その場合パラメータ数を減らすのと引き換えにコード量が増えてしまいます。精度を意識しつつこのバランスを取るの大変そうなので、今回はモデル構造を工夫する方針としました。
とはいえ、通常のニューラルネットワークなどを使うとすぐにパラメータが増えてしまうので、HoughVotingを参考とした特殊なGlobalPooling層を用いることで対応します。このレイヤは学習可能なガウス分布で空間の重み付けを行ったうえで、チャンネルごとに空間方向全体での平均をとるものです。Tensorflow実装は以下のようになります:
class GlobalHoughPooling2D(Layer):
def __init__(self, depth_multiplier=1, **kwargs):
super().__init__(**kwargs)
self.depth_multiplier = depth_multiplier
def build(self, input_shape):
self.input_height, self.input_width, input_dim = [int(s) for s in input_shape[1:]]
self.offset = self.add_weight(
shape=(2, 1, 1, input_dim, self.depth_multiplier),
initializer=keras.initializers.Constant(0),
constraint=keras.constraints.MaxNorm(1, axis=1),
name='offset')
self.lambda_ = self.add_weight(
shape=(2, 1, 1, input_dim, self.depth_multiplier),
initializer=keras.initializers.Constant(2),
constraint=keras.constraints.MinMaxNorm(1, 10, axis=1),
name='lambda_')
# Set input spec.
self.input_spec = InputSpec(ndim=4, axes={3: input_dim})
self.built = True
def call(self, inputs, training=None):
h, w = self.input_height, self.input_width
x = K.reshape(K.arange(w, dtype="float32")+0.5, (1, -1, 1, 1))/w-0.5+self.offset[0]
y = K.reshape(K.arange(h, dtype="float32")+0.5, (-1, 1, 1, 1))/h-0.5+self.offset[1]
u = x*self.lambda_[0]
v = y*self.lambda_[1]
kernel = K.exp(-0.5*(u**2+v**2)) # (h, w, input_dim, depth_multiplier)
outputs = K.depthwise_conv2d(inputs, kernel, padding="valid")
return outputs[:, 0, 0, :]
def compute_output_shape(self, input_shape):
return [input_shape[0], input_shape[-1]]
def get_config(self):
config = super().get_config()
config["depth_multiplier"] = self.depth_multiplier
return config
やはり畳み込みは性能向上に寄与しますので、パラメータ数が増えにくい1層目だけ利用します。モデル全体は「3x3畳み込み(filters=5)→上記のHoughプーリング(depth_multiplier=2)→全結合層」の構成としました。これで実行時に必要となるパラメータ数は200まで削減できます。
なお、後述する量子化の影響を抑えるため、regularizerとconstraintを適当に設定しています。
若干初期値ガチャ要素が強いですが、validationセットで精度96%前後まで到達します。
パラメータの記述方法
パラメータ数は小さくることができましたが、普通のJSON等で記述すると非常に冗長ですので、ここも対策が必要です。
今回は、パラメータの数値をASCIIコードの範囲に移動して整数化することで、数字←→文字の変換をする方法をとりました。JavaScriptの文字列として有効なASCIIコードの範囲は32-126ですので、95段階の量子化をすれば1パラメータをほぼ1文字で記述することができます(エスケープシーケンスの分だけ長くなります)。
量子化は最小最大値をもとに丸めるだけの簡易な方法を取っています。量子化の段階数が少ないことから精度が悪化することがあります(そこそこ上手く行くこともある)。本当はファインチューニングなどしたほうが精度が安定すると思います。
ちなみにbase64化も可能性としてはあると思いますが、手頃な展開方法がわからなかったため見送りました。
JavaScript実装
画像認識部分
先人の記事(「JavaScript ショートコードテクニック集(ES6含む)」「14行(元15行)のhtmlでマインスイーパーを作った」を参考にして一般的なコード短縮を行っています。
その他、計算順序をいじってループの個数を減らす・事前計算できる部分をまとめる・数値を丸めるなど様々な手法を使っていますが、説明が面倒なので割愛します(記事の反響が大きければ書くかもしれません)。
この時点で精度を確認したところ、95.12%となりました。
下記が短縮後のコードです。xが画像の配列、yが各数字のスコアの配列となっています。
S=28;
M=Math.max;
W=o=>'W~$w =~*1jBc @@kKnts[sxAO5vN c~t`eNXlk LsJZR|>I~a~~-~~^`NV^Y!Em~ [7_VHaXe@nlTy&Ve\\8`c/RWScZcJRp0PTrkYPU0E[P^^BGHS\\#IOHN\\PKb3AXYG[mXNCW<YZLFPFIOiwEM]ZX]`EJcONO=B^ZX \\@G\\O^<SN~,Q[QVW[UaUyXLDSYLPGpo\\RYS+'.charCodeAt(o);
N=(z,n)=>2**((ma*W(h+n)+mb+z)*(la*W(h+n+20)+lb))**2;
y=Array(10).fill(0);
for(i=26;i--;)for(j=26;j--;)
for(h=10;h--;y=y.map((t,d)=>t+M(0,s+W(h%5)-119)/N(j,15)/N(i,25)*(W(d*10+h+100)-81.4)+W(d+5)/20))
for(k=9,s=0;k--;)s+=x[(i+k/3|0)*S+j+k%3]*(W(h%5*9+k+55)/116-.73);
ちなみに、実は計算速度を犠牲にして31バイト減らせる方法がわかっているのですが、流石に遅すぎたので非採用としました。
キャンバス部分
キャンバス部分のつくりはこんな感じです。
キャンバスサイズについては、原理的には28ピクセルで動作できるのですが、操作しやすいよう3倍の84ピクセルとしました。
<body onload='
C.width=C.height=84;
T=C.getContext("2d");
T.scale(3,3);
C.onmousemove=e=>{
X=e.offsetX/3,Y=e.offsetY/3;
e.buttons&&T.lineTo(X,Y)+T.stroke();
T.moveTo(X,Y)
};
'>
<canvas id=C onclick='
//ここで画像認識を実装
//T.getImageData(x,y,1,1).data[3]で各画素値が得られる
'style=border:solid>
仕上げ
上記の部分を画像アクセス部分を調整しつつ組み合わせます。出力部分は短いpタグを使います。
最後に、n行プログラミングでは79文字で1行としますので、そこに単語の区切りが来ないようにコードの順序を調整して完成です。
ちなみに改行なしで一行にまとめると、780バイトになります。
感想
なかなか面白い結果が出せたと思っています。
元々はパラメータ数や層数をどうやって減らせるかを考えていたのですが、HoughVotingを試したところ上手く行ったので、勢いにまかせてショートコーディングまで頑張ってみました。
機械学習要素を含んだショートコーディングは、目標精度によって取れるアプローチが変わってくるところに面白さがあると思います。
個人的にはもう少し精度を落としたとして、どこまで短くできるかという方向に興味があります(精度は90%程度あれば見栄えがするでしょうか?)。精度を上げる方向については、97%か98%ぐらいが手頃な難易度になると思います(99%はかなりしんどそうなので私はやりません)。
おまけ1:GlobalHoughPoolingについて
今回考案したGlobalHoughPoolingですが、チャンネル数を増やしたモデルではMNIST99%、FashionMNIST90%を達成できており、案外実力があると思います。ちょっとコードをいじれば畳み込みの形になるので、層を減らしつつも受容野を拡げることができます。DepthwiseConvolutionやAttentionよりも計算量を削減できるので、使いどころがあるかもしれません。