順次書いていきます。
defn
の中に記述できないNx関数
次のようなプログラムは Nx.to_batched_list/2 is not allowed inside defn
というエラーになりました。
defmodule DefnNest do
import Nx.Defn
defn to_list_and_concat(t) do
t
|> Nx.to_batched_list(1)
|> Nx.concatenate()
end
end
defn
の中に書けるかどうかを決めているのは,Nx.Defn.Compiler
で定義されている @forbidden_ops
です。
...
# These operations do not have valid meaning for Nx.Defn.Expr
@forbidden_ops [:backend_copy, :backend_deallocate, :backend_transfer] ++
[:to_binary, :to_scalar, :to_flat_list, :to_heatmap, :to_batched_list]
...
ここに記述されているNx関数は,defn
の中で使用することができません。
これらの関数は,Nx.Defn.Kernel.transform/2
の中でも使用できません。たとえば次のように書いた時,
defn to_list_and_concat(t) do
t
|> transform(fn t -> t |> Nx.to_batched_list(1) |> Nx.concatenate() end)
end
実行するとエラーになります。このエラーは厄介なことにコンパイル時にはエラーにならず,実行してはじめてエラーになります。
iex> Nx.tensor([1,2]) |> to_list_and_concat()
** (ArgumentError) cannot invoke to_batched_list/3 on Nx.Defn.Expr.
This typically means you are invoking an unsupported Nx function
by code inside `defn` or JIT/AOT compiled code
(nx 0.1.0-dev) lib/nx/defn/expr.ex:754: Nx.Defn.Expr.to_batched_list/3
(defn_test 0.1.0) lib/defn_test.ex:33: anonymous fn/1 in DefnTest."__defn:to_list_and_concat__"/1
(defn_test 0.1.0) lib/defn_test.ex:31: anonymous fn/3 in DefnTest.to_list_and_concat/1
(nx 0.1.0-dev) lib/nx/defn/evaluator.ex:27: Nx.Defn.Evaluator.__jit__/4
iex(1)>
Nx.to_batched_list(t, 1)
と等価な処理をdefn
に書きたい
Nx.to_batched_list(t, 1)
は次のように,与えられたテンソルを,1ランク下のテンソルのリストに分解する処理なので,とても重宝しますが,残念ながら前述のようにdefn
の中では使えません。
iex> t = Nx.iota({2,2})
#Nx.Tensor<
s64[2][2]
[
[0, 1],
[2, 3]
]
>
iex> Nx.to_batched_list(t, 1)
[#Nx.Tensor<
s64[1][2]
[
[0, 1]
]
>, #Nx.Tensor<
s64[1][2]
[
[2, 3]
]
>]
幸い,Nx.to_batched_list(t, 1)
と等価な処理を次のように書けます。
iex> t = Nx.iota({2,2})
#Nx.Tensor<
s64[2][2]
[
[0, 1],
[2, 3]
]
>
iex> Enum.map(0..((Nx.shape(t) |> elem(0)) - 1), & Nx.new_axis(t[&1], 0))
[#Nx.Tensor<
s64[1][2]
[
[0, 1]
]
>, #Nx.Tensor<
s64[1][2]
[
[2, 3]
]
>]
この記法だと,Nx.Defn.Kernel.transform/2
を使うことでdefn
の中に書けます。
defn to_list_and_concat(t) do
transform(t, fn t -> Enum.map(0..((Nx.shape(t) |> elem(0)) - 1), & Nx.new_axis(t[&1], 0)) |> Nx.concatenate() end)
end
使ってみると次のような感じです。
iex> t = Nx.iota({2,2})
#Nx.Tensor<
s64[2][2]
[
[0, 1],
[2, 3]
]
>
iex(7)> to_list_and_concat(t)
#Nx.Tensor<
s64[2][2]
[
[0, 1],
[2, 3]
]
>
引数をスカラー値に限定するようなガード条件を書きたい
現状,defn
にguardを書くことはできません。例えば引数 n
にはスカラー値であってほしいということで,次のように書きたくなることがあるのですが,エラーになります。
defn add_tensor_and_scalar(t, s) when is_number(s) or Nx.size(s) == 1 do
Nx.add(t, s)
end
このような時に便利な記述を定義した次のPRが取り込まれました。(既存のプロジェクトに取り込む時には,mix deps.clean --all
rm -rf mix.lock deps _build
としてから mix deps.get
とします。)
このような時には,Defn.Kernel.assert_shape/2
を使って次のように書きます。
defn add_tensor_and_scalar(t, s) do
assert_shape(s, {})
Nx.add(t, s)
end
実行してみると次のような感じです。
iex> DefnTest.add_tensor_and_scalar(Nx.tensor([1,2]), 1)
#Nx.Tensor<
s64[2]
[2, 3]
>
iex> DefnTest.add_tensor_and_scalar(Nx.tensor([1,2]), Nx.tensor([1, 2]))
** (ArgumentError) expected tensor to be a scalar, got tensor with shape {2}
(nx 0.1.0-dev) lib/nx/defn/kernel.ex:1050: Nx.Defn.Kernel.assert_shape/2
(defn_test 0.1.0) lib/defn_test.ex:21: DefnTest."__defn:add_tensor_and_scalar__"/2
(defn_test 0.1.0) lib/defn_test.ex:20: anonymous fn/3 in DefnTest.add_tensor_and_scalar/2
(nx 0.1.0-dev) lib/nx/defn/evaluator.ex:27: Nx.Defn.Evaluator.__jit__/4
iex>
なお,defn
に与える引数は,数値であったとしてもスカラー値を表す Nx.Tensor
に変換されるとのことです。
iex> DefnTest.add_tensor_and_scalar(Nx.tensor([1,2]), Nx.tensor(1))
#Nx.Tensor<
s64[2]
[2, 3]
>
要素数1のベクトルや1x1の行列は,受け付けません。
iex(7)> DefnTest.add_tensor_and_scalar(Nx.tensor([1,2]), Nx.tensor([1]))
** (ArgumentError) expected tensor to be a scalar, got tensor with shape {1}
(nx 0.1.0-dev) lib/nx/defn/kernel.ex:1050: Nx.Defn.Kernel.assert_shape/2
(defn_test 0.1.0) lib/defn_test.ex:21: DefnTest."__defn:add_tensor_and_scalar__"/2
(defn_test 0.1.0) lib/defn_test.ex:20: anonymous fn/3 in DefnTest.add_tensor_and_scalar/2
(nx 0.1.0-dev) lib/nx/defn/evaluator.ex:27: Nx.Defn.Evaluator.__jit__/4
Fix されました。Nx.Defn.Kernel.assert_shape_pattern/2
はうまく動きませんでしたので,PRにコメントしました。
defn add_and_matrix(m1, m2) do
assert_shape_pattern(m1, {_, _})
assert_shape_pattern(m2, {_, _})
Nx.add(m1, m2)
end
このようにすると,第1引数と第2引数が共に行列の時のみ実行され,それ以外の場合はエラーになります。
iex> DefnTest.add_and_matrix(Nx.tensor([[1,2], [3,4]]), Nx.tensor([[2,3], [3,4]]))
#Nx.Tensor<
s64[2][2]
[
[3, 5],
[6, 8]
]
>
iex> DefnTest.add_and_matrix(Nx.tensor([1,2]), Nx.tensor([2,3]))
** (ArgumentError) expected tensor to match shape {_, _}, got tensor with shape {2}
(nx 0.1.0-dev) lib/nx/defn/kernel.ex:1133: Nx.Defn.Kernel.__assert_shape_pattern__!/2
(defn_test 0.1.0) lib/defn_test.ex:26: DefnTest."__defn:add_and_matrix__"/2
(defn_test 0.1.0) lib/defn_test.ex:25: anonymous fn/3 in DefnTest.add_and_matrix/2
(nx 0.1.0-dev) lib/nx/defn/evaluator.ex:27: Nx.Defn.Evaluator.__jit__/4
iex>
スカラーやベクターの時(階数が1以下のテンソルの時)にエラーにしたい
次のように書きます。
defn assert_rank_greater_than_1 do
transform(t, fn t -> Nx.rank(t) <= 1 && raise ArgumentError, "expected the rank of input tensor #{inspect(t)} to be grater than 1" end)
end
使ってみると次のような感じです。
iex> t = Nx.tensor(1)
#Nx.Tensor<
s64
1
>
iex> assert_rank_greater_than_1(t)
** (ArgumentError) expected the rank of input tensor #Nx.Tensor<
s64
Nx.Defn.Expr
parameter a s64
> to be grater than 1
(snip)
iex> t = Nx.tensor([1, 2])
#Nx.Tensor<
s64[2]
[1, 2]
>
iex> assert_rank_greater_than_1(t)
** (ArgumentError) expected the rank of input tensor #Nx.Tensor<
s64
Nx.Defn.Expr
parameter a s64
> to be grater than 1
(snip)
iex> t = Nx.tensor([[1, 2], [3, 4]])
#Nx.Tensor<
s64[2][2]
[
[1, 2],
[3, 4]
]
>
iex> assert_rank_greater_than_1(t)
#Nx.Tensor<
s64[2][2]
[
[1, 2],
[3, 4]
]
>
defn
で定数との比較をする
次のようなコードを定義したとします。
defn pattern_match_fn(t) do
transform(t, fn t -> pattern_match_sub(t) end)
end
def pattern_match_f(t) do
pattern_match_sub(t)
end
defp pattern_match_sub(t) do
if t == Nx.tensor(1) do
IO.puts("matched")
t
else
IO.puts("unmatched")
t
end
end
def
を使った場合には Nx.tensor(1)
と等しい場合に matched
と表示し,そうでない場合には unmatched
と表示します。
iex> match_f(Nx.tensor(1))
matched
#Nx.Tensor<
s64
1
>
iex> match_f(Nx.tensor(2))
unmatched
#Nx.Tensor<
s64
2
>
しかし,defn
を使った場合には,なぜかどちらも unmatched
になります。
iex> match_fn(Nx.tensor(1))
unmatched
#Nx.Tensor<
s64
1
>
iex> match_fn(Nx.tensor(2))
unmatched
#Nx.Tensor<
s64
2
>
Nx 開発者の Paulo Valente に相談したのですが,解決せず,Issue を書くことにしました。
José Valim 曰く,
要は defn
はコンパイル時に評価される式なのに対し,transform/2
で呼び出した先の式は実行時に評価される式なので,同一性を評価できないということです。
というわけで次のようにする必要があります。
defn match_fn(t) do
if t == 1 do
t * 2
else
t / 2
end
end
しかし,次のようにしてもうまく動きません。
defn match_fn(t) do
if t == 1 do
transform(
t,
fn t ->
IO.puts("matched")
t
end
)
else
transform(
t,
fn t ->
IO.puts("unmatched")
t
end
)
end
end
実行すると次のようになります。
iex> match_fn(1)
matched
unmatched
#Nx.Tensor<
s64
1
>
defn
の中の条件分岐の分岐先は,実行時には両方評価されるということなんですかね。
現在,defn
の中で IO.inspect
などでデバッグする手段を構想中ということでした。