はじめに
本記事はオライリーのscikit-learn、Keras、TensorFlowによる実践機械学習 第2版
の10章人工ニューラルネットワークとKerasの初歩のPyhtonコードをLiveBookとAxonで書き換えてみた記事です
LiveBook
Elixir製のJupyterNotebookでElixirとMarkdownが記述できます
MarkdonwはMermaidにも対応しています
インストール方法は以下を参考にしてください
- 「JupyterNotebook + NumPyでサクッと画像加工するノリ」をElixir Livebookでやってみた
- Livebookの基礎知識 ーー LiveView JP#4:Livebook始めよう+Fly.ioでモブプロのLT資料
- livebook Github
Axon
Elixir製の行列演算ライブラリNxで実装されたKerasライクな深層学習フレームワークです
XLAやTorchxを使用して高速化を図っています
SciData
学習データセット各種をダウンロード,cacheしてくれるライブラリです
各モジュールに学習データ取得は donwload(), テストデータ取得はdonwload_test()が実装されています
NoteBook
それでは開始していきましょう
LiveBookにアクセスして新規のNoteBookを作成してください
コードはこちらを参考にしています
https://github.com/elixir-nx/scidata/blob/master/README.md
https://github.com/elixir-nx/axon/blob/main/examples/vision/mnist.exs
ライブラリのインストール
ライブラリのインストールの注意点ですが基本的に1 Notebookに1つで、2つめを実行するとエラーになります
新しく追加したりする場合はMix.installに追加して ReEvaluteを実行し、Reconnectを要求してくるのでReconnect runtime
ボタンを押してインストールを開始します
Mix.install([
{:axon, github: "elixir-nx/axon"},
{:exla, github: "elixir-nx/nx", sparse: "exla"},
{:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
{:scidata, "~> 0.1.3"}
])
インストール中のログが流れて最終的に以下のようにokが完了です
:ok
EXLAで使うデバイスを設定
行列演算を高速化するのにExlaというライブラリを使うのですが、どのデバイスを使って高速化するかを設定します
warning: EXLA.set_preferred_defn_options/1 is deprecated. Use set_as_nx_default/2 instead
set_preferred_defn_options
はdeprecated.らしいのでset_as_nx_default
を使います
EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])
前から順に試していき、デバイスが違う場合は次のオプションと試していきます
今回はM1 MacなのでCPUを使うhostになります
:host
準備が整ったので次からは実際のpythonコードとの対比させながら書いていきます
実際にLiveBookで実行する場合はElixirのコードのみ貼り付けてください
10.2.2.1 Kerasを使ったデータセットのロード
Python
fashion_mnist = keras.datasets.fashion_mnist
(X_train_full, y_train_full), (X_test, Y_test) = fashion_mnist.load_data()
Elixir
SciDataはdownloadだけで学習に適したデータにはなってないので加工していきます
Download
{x_train_full, y_train_full} = Scidata.FashionMNIST.download()
{x_test, y_test} = Scidata.FashionMNIST.download_test()
train data
画像データ
{{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>, {:u, 8}, {60000, 1, 28, 28}}
from_binaryでバイナリデータから行列を作成し
reshapeで 28x28x1の画像データが60000になるように整形します
最後に色はグレースケールで0~255となっているので255で割って正規化します
ラベルデータ
{<<9, 0, 0, 3, 0, 2, 7, 2, 5, 5, 0, 9, 5, 5, 7, 9, 1, 0, 6, 4, 3, 1, 4, 8, 4, 3, 0, 2, 4, 4, 5, 3,
6, 6, 0, 8, 5, 2, 1, 6, 6, 7, 9, 5, 9, 2, 7, ...>>, {:u, 8}, {60000}}}
from_binaryでバイナリデータから行列を作成し
new_axisで{60000},{60000,1}に変換し
最後にone_hotベクトルを追加して {60000,1,10}にします
{binary, unit, shape} = x_train_full
x_train_full =
binary
|> Nx.from_binary(unit)
|> Nx.reshape(shape)
|> Nx.divide(255)
{label, unit_l, _shape} = y_train_full
y_train_full =
label
|> Nx.from_binary(unit)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
test data
test dataもtrain dataと同じ用に加工します
{binary, unit, shape} = x_test
x_test =
binary
|> Nx.from_binary(unit)
|> Nx.reshape(shape)
|> Nx.divide(255)
{label, unit_l, _shape} = y_test
y_test =
label
|> Nx.from_binary(unit_l)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
書籍のほうでは以下のようにvaidとtrainを分けていますがAxonはいい感じにしてくれてるのでそのままで問題ありません
(引数にvalidation_dataを入れる箇所はないが、軽く見ても分割している箇所は見当たらない?)
X_valid, X_train = X_train_full[:5000] / 255.0, X_train_full[5000:] / 255.0
y_vaild, y_train = y_train_full[:5000], y_train_full[5000:]
Elixirで分割したい場合はEnum.split
を使いましょう
{x_valid, x_train} = Enum.split(x_train_full, 5000)
{y_valid, y_train} = Enum.split(y_train_full, 5000)
クラス名のリスト
pythonのコードをそのままで問題ありません
class_names = [
"T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle boot"
]
["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag",
"Ankle boot"]
リストの特定のindexの値を取得するときはEnum.atを使用します
y_train_fullはone-hotベクトルになっているの1になっている箇所をNx.argmax()
で取得し、Tensor構造体なのでNx.to_number()
で数字にしています
Enum.at(class_names, y_train_full[5000] |> Nx.argmax() |> Nx.to_number())
書籍と同じCoatが取得できました
"Coat"
10.2.2.2 シーケンシャルAPIを使ったモデルの作成
ネットワークの構築
ネットワークを構築していきます
Kerasはinput_shapeを1つ目の層で指定しますが
Axonはinputを作成して各層にpipeで繋いでいきます
{nil, 1, 28, 28}
最初はnil、未知のバッチサイズの代わりとして使用されています
model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28, 28]))
model.add(keras.layers.Dense(300, activation="relu")
model.add(keras.layers.Dense(100, activation="relu")
model.add(keras.layers.Dense(10, activation="softmax")
model =
Axon.input({nil, 1, 28, 28})
|> Axon.flatten()
|> Axon.dense(300, activation: :relu)
|> Axon.dense(100, activation: :relu)
|> Axon.dense(10, activation: :softmax)
モデルの構造表示
python
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) (None, 784) 0
dense (Dense) (None, 300) 235500
dense_1 (Dense) (None, 100) 30100
dense_2 (Dense) (None, 10) 1010
=================================================================
Total params: 266,610
Trainable params: 266,610
Non-trainable params: 0
_________________________________________________________________
Elixir
modelが返り値だと以下のように表示されます
途中で表示したい場合はIO.inspect(model)
でも表示できます
IO.inspect(model)
----------------------------------------------------------------------------------------------------------
Model
==========================================================================================================
Layer Shape Policy Parameters Parameters Memory
==========================================================================================================
input_0 ( input ) {nil, 1, 28, 28} p=f32 c=f32 o=f32 0 0 bytes
flatten_0 ( flatten["input_0"] ) {nil, 784} p=f32 c=f32 o=f32 0 0 bytes
dense_0 ( dense["flatten_0"] ) {nil, 300} p=f32 c=f32 o=f32 235500 942000 bytes
relu_0 ( relu["dense_0"] ) {nil, 300} p=f32 c=f32 o=f32 0 0 bytes
dense_1 ( dense["relu_0"] ) {nil, 100} p=f32 c=f32 o=f32 30100 120400 bytes
relu_1 ( relu["dense_1"] ) {nil, 100} p=f32 c=f32 o=f32 0 0 bytes
dense_2 ( dense["relu_1"] ) {nil, 10} p=f32 c=f32 o=f32 1010 4040 bytes
softmax_0 ( softmax["dense_2"] ) {nil, 10} p=f32 c=f32 o=f32 0 0 bytes
----------------------------------------------------------------------------------------------------------
get_weights(), set_weigths()
Axonに該当するメソッドはありませんが、学習パラメーターとネットワークは切り離されているのと
学習パラメーターはNx.Tensorなので、変更処理等を行う場合はNxのメソッドで行いましょう
10.2.2.3 モデルのコンパイル
損失関数と最適化関数等を指定します
elixirはそのまま訓練を行うのでここは特になし
python
model.compile(loss="sparse_categorical_crossentropy",
optimizer="sgd",
metrics=["accuracy"])
次の節で全体像がありますが、該当する箇所はここになります
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.sgd(0.01))
|> Axon.Loop.metric(:accuracy)
10.2.2.4 モデルの訓練と評価
python
history = model.fit(X_train, y_train, epochs=30,
validation_data=(X_valid, y_valid))
Train on 55000 samples, validate on 5000 samples
Epoch 1/30
55000/55000 [======] - 3s 49us/sample - loss: 0.7218 - accuracy: 0.7660
- val_loss: 0.4973 - val_accuracy: 0.8366
Epoch 2/30
55000/55000 [======] - 2s 45us/sample - loss: 0.4840 - accuracy: 0.8327
- val_loss: 0.4456 - val_accuracy: 0.8480
[...]
Epoch 30/30
55000/55000 [======] - 3s 53us/sample - loss: 0.2252 - accuracy: 0.9192
- val_loss: 0.2999 - val_accuracy: 0.8926
Elixir
学習データをバッチサイズに応じたリストにしてtrainerに突っ込みます
完了すると最終的な学習パラメーターのMapが返ってきます
inputs = x_train_full |> Nx.to_batched_list(32)
targets = y_train_full |> Nx.to_batched_list(32)
model_state =
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.sgd(0.01))
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(Stream.zip(inputs, targets), epochs: 30)
Epoch: 0, Batch: 1850, accuracy: 0.9542719 loss: 0.6901455
Epoch: 1, Batch: 1850, accuracy: 0.9677529 loss: 0.5816337
Epoch: 2, Batch: 1850, accuracy: 0.9705983 loss: 0.5312806
Epoch: 3, Batch: 1850, accuracy: 0.9722843 loss: 0.4994828
Epoch: 4, Batch: 1850, accuracy: 0.9734979 loss: 0.4764614
Epoch: 5, Batch: 1850, accuracy: 0.9745985 loss: 0.4584843
Epoch: 6, Batch: 1850, accuracy: 0.9754986 loss: 0.4437402
Epoch: 7, Batch: 1850, accuracy: 0.9762914 loss: 0.4312303
Epoch: 8, Batch: 1850, accuracy: 0.9769894 loss: 0.4203333
Epoch: 9, Batch: 1850, accuracy: 0.9775428 loss: 0.4106765
Epoch: 10, Batch: 1850, accuracy: 0.9780806 loss: 0.4019816
Epoch: 11, Batch: 1850, accuracy: 0.9786006 loss: 0.3940594
Epoch: 12, Batch: 1850, accuracy: 0.9790562 loss: 0.3867979
Epoch: 13, Batch: 1850, accuracy: 0.9794934 loss: 0.3800708
Epoch: 14, Batch: 1850, accuracy: 0.9799964 loss: 0.3738033
Epoch: 15, Batch: 1850, accuracy: 0.9803599 loss: 0.3679344
Epoch: 16, Batch: 1850, accuracy: 0.9807381 loss: 0.3624099
Epoch: 17, Batch: 1850, accuracy: 0.9812099 loss: 0.3571859
Epoch: 18, Batch: 1850, accuracy: 0.9815733 loss: 0.3522262
Epoch: 19, Batch: 1850, accuracy: 0.9818426 loss: 0.3475078
Epoch: 20, Batch: 1850, accuracy: 0.9821869 loss: 0.3429988
Epoch: 21, Batch: 1850, accuracy: 0.9825315 loss: 0.3386921
Epoch: 22, Batch: 1850, accuracy: 0.9828553 loss: 0.3345581
Epoch: 23, Batch: 1850, accuracy: 0.9832116 loss: 0.3305819
Epoch: 24, Batch: 1850, accuracy: 0.9835541 loss: 0.3267476
Epoch: 25, Batch: 1850, accuracy: 0.9838923 loss: 0.3230499
Epoch: 26, Batch: 1850, accuracy: 0.9841914 loss: 0.3194728
Epoch: 27, Batch: 1850, accuracy: 0.9844572 loss: 0.3160088
Epoch: 28, Batch: 1850, accuracy: 0.9847838 loss: 0.3126440
Epoch: 29, Batch: 1850, accuracy: 0.9850463 loss: 0.3093799
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[300]
[0.13571378588676453, 0.03549302741885185, 0.08724245429039001, -0.0061872657388448715, 2.2010751126799732e-4, 0.037931330502033234, 0.021099954843521118, 0.024455856531858444, 0.02081124298274517, 0.023745786398649216, -2.9873353196308017e-4, 0.05602224916219711, 0.012563621625304222, 0.0575307235121727, 0.0011330501874908805, 0.1084568053483963, 0.06705764681100845, -0.06733907014131546, 0.05662992596626282, 0.02480822615325451, 0.22499433159828186, -0.06035986915230751, 0.20579959452152252, 0.03935956954956055, 0.009435584768652916, 0.010274690575897694, -0.01043284684419632, 0.11485638469457626, -0.06310120224952698, 0.027927584946155548, 0.1819688379764557, -0.04840763658285141, 0.08819303661584854, 0.14863930642604828, -0.08883821219205856, 0.005014372523874044, -0.014123741537332535, 0.0704503059387207, 0.18479794263839722, 0.06806794553995132, 0.04929867386817932, -0.011439140886068344, -0.012110976502299309, 0.06211903691291809, -0.07319530844688416, 0.025473466143012047, -0.12689639627933502, 0.018913164734840393, ...]
>,
"kernel" => #Nx.Tensor<
f32[784][300]
[
[0.013576650992035866, -0.014004906639456749, 0.05055975168943405, 0.06146484613418579, -0.0734097957611084, 0.05356033891439438, -0.035959236323833466, -0.044105228036642075, -0.047992460429668427, 0.058227330446243286, 0.029727330431342125, -0.0384567454457283, -0.039290882647037506, 0.07082822173833847, 0.018462810665369034, 0.018885670229792595, 0.029172591865062714, 0.05742579326033592, -0.0596354678273201, 0.04225178062915802, -0.07297362387180328, 0.028878074139356613, -0.056826017796993256, -0.03331564739346504, 0.0063286819495260715, 0.027812685817480087, -0.04472225531935692, 0.05244554206728935, 0.06330916285514832, -0.008841929025948048, -0.02647828683257103, 0.01984805054962635, -0.030981721356511116, -0.007359237410128117, -0.04982060194015503, 0.03356190770864487, -0.03447606414556503, 0.006800201255828142, 0.04560250788927078, -0.06059754639863968, -0.0571490079164505, 0.00783382449299097, -0.010420969687402248, 0.02866566926240921, 0.056918710470199585, 0.027809524908661842, 0.06493263691663742, ...],
...
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[100]
[-0.033903174102306366, 0.0031299397815018892, 0.042292580008506775, 0.03492650389671326, 0.10128480195999146, 0.028303274884819984, -0.05309673026204109, 0.0676303580403328, 0.04182380065321922, 0.05242175981402397, -0.020688258111476898, 0.14884760975837708, -0.011936621740460396, 0.041801001876592636, 0.018748031929135323, 0.06619488447904587, 0.11927098035812378, 0.03428078070282936, -0.01305422093719244, -0.01333137508481741, 0.029846984893083572, 0.11659515649080276, 0.05557134747505188, -0.014377977699041367, 0.04988781735301018, 0.0533335842192173, 0.006171948276460171, -0.060610681772232056, -0.014924745075404644, -0.020609134808182716, 0.04662292078137398, 0.1432056427001953, 0.03395121917128563, 0.12744517624378204, 0.02471437118947506, -0.018728500232100487, 0.038396526128053665, 0.16800907254219055, 0.07645541429519653, 0.016965867951512337, -0.031188689172267914, -0.01604599319398403, 0.011045312508940697, 0.010093159042298794, 0.1250711977481842, 0.2722034156322479, 0.053765349090099335, ...]
>,
"kernel" => #Nx.Tensor<
f32[300][100]
[
[-0.13397899270057678, 0.13074228167533875, -0.12614311277866364, 0.06232032924890518, -0.04008568450808525, -0.0919671431183815, 0.09870906919240952, -0.05578409507870674, 0.026536675170063972, 0.05412725731730461, 0.15791819989681244, 0.11373333632946014, -0.04124131053686142, -0.12274332344532013, -0.04453011229634285, -0.036030229181051254, -0.029974598437547684, 0.019989633932709694, 0.011036165058612823, -0.025078900158405304, -0.021867990493774414, 0.06484955549240112, 0.014939786866307259, -0.04385797679424286, -0.09211859107017517, -0.023844990879297256, 0.11535773426294327, 0.10587604343891144, -0.024618981406092644, 0.10546504706144333, -0.048852160573005676, 0.1869676113128662, -7.510890136472881e-4, 0.10661966353654861, -0.07498490810394287, 0.022715985774993896, -0.1567826271057129, -0.08192359656095505, -0.11506399512290955, 0.010153728537261486, -0.07428283989429474, 0.018338680267333984, 0.056297894567251205, 0.025882074609398842, 0.07265932112932205, 0.0853208601474762, ...],
...
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[10]
[-0.006278525572270155, -0.11440706253051758, 0.10179226845502853, 0.09025454521179199, -0.2640182375907898, 0.2954282760620117, 0.11500847339630127, 0.10747787356376648, -0.06677260994911194, -0.2584887146949768]
>,
"kernel" => #Nx.Tensor<
f32[100][10]
[
[0.1874592900276184, -0.08722195774316788, -0.26049819588661194, 0.20030535757541656, 0.327762633562088, -0.15166854858398438, 0.3747589886188507, -0.32310038805007935, -0.30735719203948975, 0.08006537705659866],
[-0.3092937171459198, 0.13376988470554352, 0.2121509164571762, -0.13248851895332336, 0.15469135344028473, 0.034889236092567444, 0.2814216911792755, -0.012385893613100052, -0.023418258875608444, -0.2099449187517166],
[0.31322064995765686, -0.06482202559709549, -0.22563078999519348, 0.0010128654539585114, -0.18612903356552124, -0.114998959004879, -0.32411989569664, 0.3558351397514343, 0.27040356397628784, 0.28674325346946716],
[-0.019515056163072586, -0.34636208415031433, 0.18379873037338257, 0.38537105917930603, 0.28832852840423584, -0.3494378626346588, -0.06889533251523972, 0.2788335382938385, -0.20729109644889832, 0.30687129497528076],
[0.019088683649897575, -0.3099006712436676, -0.14253011345863342, -0.5100916624069214, 0.15435951948165894, ...],
...
]
>
}
}
historyからのグラフ描写
history相当が見当たらないのでスキップ
evaluate()
検証データセットを使って評価します
python
>>> model.evaluate(X_test, y_test)
10000/10000 [==========] - 0s 29us/sample - loss: 0.3340 - accuracy: 0.8851
[0.3339798209667206, 0.8851]
Elixir
先程のmodel_stateを evaluatorに突っ込んで検証データセットでrunを実行します
model
|> Axon.Loop.evaluator(model_state)
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.run(Stream.zip(x_test, y_test))
Batch: 312, Accuracy: 0.9776765
最終的に以下のMapが返ってきます
%{
0 => %{
"Accuracy" => #Nx.Tensor<
f32
0.9776764512062073
>
}
}
10.2.2.5 モデルを使った予測
python
>>> X_new = X_test[:3]
>>> y_proba = model.predict(X_new)
>>> y_proba.round(2)
array([[0. ,0. ,0. ,0. ,0. ,0.03,0. ,0.01,0. ,0.96],
[0. ,0. ,0.98,0. ,0.02,0. ,0. ,0. ,0. ,0. ],
[0. ,1. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ]], dtype=float32)
>>> y_pred = model.predict_classes(X_new)
>>> y_pred
array([9, 2, 1])
>>> np.array(class_names)[y_pred]
array(['Ankle boot', 'Pullover', 'Trouser'], dtype='<U11')
>>> y_new = y_test[:3]
>>> y_new
array([9, 2, 1])
predictを実行する場合はrequire Axonをする必要があります
slice_along_axis
で0~3番目の画像のデータを取得
predict
を実行すると、3つの画像の予測結果が返ってきているので
argmax(axis: 1)
で縦軸のそれぞれの最大値のindexを返します
elixir
require Axon
inputs =
x_train_full
|> Nx.slice_along_axis(0, 3, axis: 0)
ans =
Axon.predict(model, model_state, inputs)
|> IO.inspect()
|> Nx.argmax(axis: 1)
|> IO.inspect()
#Nx.Tensor<
f32[3][10]
[
[1.920465919624803e-8, 2.8275014329892656e-8, 1.7683338127127968e-10, 4.753846760685576e-10, 2.8052097103548057e-11, 1.4351893332786858e-4, 8.373903237490765e-10, 0.006102824583649635, 3.937300618872541e-7, 0.9937532544136047],
[0.9988293051719666, 2.654123865131197e-10, 6.825270247645676e-6, 1.2458690434868913e-7, 1.249891995458441e-10, 1.2247265647382516e-13, 0.0011638033902272582, 1.6492996204875432e-11, 5.709294370603857e-8, 5.31898067257508e-11],
[0.6422454714775085, 0.013506175950169563, 0.00808799359947443, 0.2108849436044693, 0.018253562971949577, 7.227252353914082e-6, 0.10664960741996765, 9.29184679989703e-5, 7.107374403858557e-5, 2.0104010764043778e-4]
]
>
#Nx.Tensor<
s64[3]
[9, 0, 0]
>
正解データと照らし合わせてみましょう
targets =
y_train_full
|> Nx.slice_along_axis(0, 3, axis: 0)
|> Nx.to_batched_list(1)
Enum.map(
targets,
fn label -> Enum.at(class_names, Nx.argmax(label) |> Nx.to_number())
end)
|> IO.inspect()
Enum.map(
ans |> Nx.to_flat_list(),
fn index -> Enum.at(class_names, index)
end) |> IO.inspect()
["Ankle boot", "T-shirt/top", "T-shirt/top"]
["Ankle boot", "T-shirt/top", "T-shirt/top"]
to_heatmapで該当の画像も見てみましょう
Enum.map(
inputs |> Nx.to_batched_list(1),
fn input -> Nx.to_heatmap(input)
end)
最後に
ある程度書籍をなぞることができました!
最終的にGANとかが組めるとよさそうですね
コード
※ 2022/04/26現在
MIX_ENV=prod mix phx.server
だとMix.installができないので
livebook server
で動かしてください
実践機械学習 axon convert
Section
Mix.install([
{:axon, github: "elixir-nx/axon"},
{:exla, github: "elixir-nx/nx", sparse: "exla"},
{:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
{:scidata, "~> 0.1.3"}
])
EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])
{x_train_full, y_train_full} = Scidata.FashionMNIST.download()
{binary, unit, shape} = x_train_full
x_train_full =
binary
|> Nx.from_binary(unit)
|> Nx.reshape(shape)
|> Nx.divide(255)
{label, unit_l, _shape} = y_train_full
y_train_full =
label
|> Nx.from_binary(unit)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
class_names = [
"T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle boot"
]
Enum.at(class_names, y_train_full[5000] |> Nx.argmax() |> Nx.to_number())
model =
Axon.input({nil, 1, 28, 28})
|> Axon.flatten()
|> Axon.dense(300, activation: :relu)
|> Axon.dense(100, activation: :relu)
|> Axon.dense(10, activation: :softmax)
inputs = x_train_full |> Nx.to_batched_list(32)
targets = y_train_full |> Nx.to_batched_list(32)
model_state =
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.sgd(0.01))
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(Stream.zip(inputs, targets), epochs: 30)
{x_test, y_test} = Scidata.FashionMNIST.download_test()
{binary, unit, shape} = x_test
x_test =
binary
|> Nx.from_binary(unit)
|> Nx.reshape(shape)
|> Nx.divide(255)
|> Nx.to_batched_list(32)
{label, unit_l, _shape} = y_test
y_test =
label
|> Nx.from_binary(unit_l)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched_list(32)
model
|> Axon.Loop.evaluator(model_state)
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.run(Stream.zip(x_test, y_test))
require Axon
inputs =
x_train_full
|> Nx.slice_along_axis(0, 3, axis: 0)
targets =
y_train_full
|> Nx.slice_along_axis(0, 3, axis: 0)
|> Nx.to_batched_list(1)
ans =
Axon.predict(model, model_state, inputs)
|> IO.inspect()
|> Nx.argmax(axis: 1)
|> IO.inspect()
Enum.map(targets, fn label -> Enum.at(class_names, Nx.argmax(label) |> Nx.to_number()) end)
|> IO.inspect()
Enum.map(ans |> Nx.to_flat_list(), fn index -> Enum.at(class_names, index) end) |> IO.inspect()
Enum.map(inputs |> Nx.to_batched_list(1), fn input -> Nx.to_heatmap(input, [:ansi_enabled]) end)