LoginSignup
1
0

wslでnerves その27

Last updated at Posted at 2023-11-21

概要

wsl(wsl2じゃない)でnervesやってみる。
qemu(x86_64エミュレータ、ラズパイじゃない)でやってみた。
生成したnerves_livebook.imgを、quemで起動してテストしてみた。
練習問題、やってみた。

練習問題

nxを使え。

写真

image.png

セットアップ

Mix.install([
  {:nx, "~> 0.6"}
])

サンプルコード

xorを学習して、推定する。

defmodule Xor do
  import Nx.Defn

  defn init_random_params do
    key = Nx.Random.key(42)
    {w1, new_key} = Nx.Random.normal(key, 0.0, 0.1, shape: {2, 8}, names: [:input, :layer])
    {b1, new_key} = Nx.Random.normal(new_key, 0.0, 0.1, shape: {8}, names: [:layer])
    {w2, new_key} = Nx.Random.normal(new_key, 0.0, 0.1, shape: {8, 1}, names: [:layer, :output])
    {b2, _new_key} = Nx.Random.normal(new_key, 0.0, 0.1, shape: {1}, names: [:output])
    {w1, b1, w2, b2}
  end

  defn predict({w1, b1, w2, b2}, x) do
    x
    |> Nx.dot(w1)
    |> Nx.add(b1)
    |> Nx.tanh()
    |> Nx.dot(w2)
    |> Nx.add(b2)
    |> Nx.sigmoid()
  end

  defn loss({w1, b1, w2, b2}, x, y) do
    preds = predict({w1, b1, w2, b2}, x)
    Nx.mean(Nx.power(y - preds, 2))
  end

  defn update({w1, b1, w2, b2} = params, x, y, step) do
    {grad_w1, grad_b1, grad_w2, grad_b2} = grad(params, &loss(&1, x, y))
    {w1 - grad_w1 * step, b1 - grad_b1 * step, w2 - grad_w2 * step, b2 - grad_b2 * step}
  end

  def train(params, x, y) do
    for i <- 0..31, reduce: params do
      acc ->
        update(acc, x[i], y[i], 0.1)
    end
  end
end

z0 = for _ <- 0..31, do: Enum.random(0..1)
z1 = for _ <- 0..31, do: Enum.random(0..1)
x0 = Nx.tensor(z0)
x1 = Nx.tensor(z1)
y = Nx.logical_xor(x0, x1)
x = Nx.concatenate([Nx.reshape(x0, {32, 1}), Nx.reshape(x1, {32, 1})], axis: 1)

params = Xor.init_random_params()

Xor.loss(params, x[0], y[0])
|> IO.inspect()

params =
  for i <- 0..310, reduce: params do
    acc ->
      for i <- 0..31, reduce: acc do
        bcc ->
          Xor.update(bcc, x[i], y[i], 0.1)
      end
  end

Xor.loss(params, x[0], y[0])
|> IO.inspect()

IO.puts("0 1")

Xor.predict(params, Nx.tensor([0, 1]))
|> IO.inspect()

IO.puts("0 0")

Xor.predict(params, Nx.tensor([0, 0]))
|> IO.inspect()

IO.puts("1 1")

Xor.predict(params, Nx.tensor([1, 1]))
|> IO.inspect()

IO.puts("1 0")

Xor.predict(params, Nx.tensor([1, 0]))
|> IO.inspect()

実行結果

warning: variable "i" is unused (if the variable is not meant to be used, prefix it with an underscore)
  /data/livebook/xor.livemd#cell:zc3vabsfd6ftw2ineimecdebjxd3whxt:54

#Nx.Tensor<
  f32
  0.2594543695449829
>
#Nx.Tensor<
  f32
  9.000560385175049e-4
>
0 1
#Nx.Tensor<
  f32[output: 1]
  [0.9725738763809204]
>
0 0
#Nx.Tensor<
  f32[output: 1]
  [0.03000093437731266]
>
1 1
#Nx.Tensor<
  f32[output: 1]
  [0.02718157321214676]
>
1 0
#Nx.Tensor<
  f32[output: 1]
  [0.9590940475463867]
>
warning: Nx.power/2 is deprecated. Use pow/2 instead
  /data/livebook/xor.livemd#cell:zc3vabsfd6ftw2ineimecdebjxd3whxt:25: Xor."__defn:loss__"/3


以上。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0