LoginSignup
9
7

More than 1 year has passed since last update.

Nxのdefnで書けること書けないこと

Last updated at Posted at 2021-09-26

順次書いていきます。

defn の中に記述できないNx関数

次のようなプログラムは Nx.to_batched_list/2 is not allowed inside defn
というエラーになりました。

defmodule DefnNest do
  import Nx.Defn

  defn to_list_and_concat(t) do
    t
    |> Nx.to_batched_list(1)
    |> Nx.concatenate()
  end
end

defn の中に書けるかどうかを決めているのは,Nx.Defn.Compiler で定義されている @forbidden_ops です。

Nx.Defn.Compiler
...
  # These operations do not have valid meaning for Nx.Defn.Expr
  @forbidden_ops [:backend_copy, :backend_deallocate, :backend_transfer] ++
                   [:to_binary, :to_scalar, :to_flat_list, :to_heatmap, :to_batched_list]
...

ここに記述されているNx関数は,defnの中で使用することができません。

これらの関数は,Nx.Defn.Kernel.transform/2 の中でも使用できません。たとえば次のように書いた時,

  defn to_list_and_concat(t) do
    t
    |> transform(fn t -> t |> Nx.to_batched_list(1) |> Nx.concatenate() end)
  end

実行するとエラーになります。このエラーは厄介なことにコンパイル時にはエラーにならず,実行してはじめてエラーになります。

iex> Nx.tensor([1,2]) |> to_list_and_concat()
** (ArgumentError) cannot invoke to_batched_list/3 on Nx.Defn.Expr.

This typically means you are invoking an unsupported Nx function
by code inside `defn` or JIT/AOT compiled code

    (nx 0.1.0-dev) lib/nx/defn/expr.ex:754: Nx.Defn.Expr.to_batched_list/3
    (defn_test 0.1.0) lib/defn_test.ex:33: anonymous fn/1 in DefnTest."__defn:to_list_and_concat__"/1
    (defn_test 0.1.0) lib/defn_test.ex:31: anonymous fn/3 in DefnTest.to_list_and_concat/1
    (nx 0.1.0-dev) lib/nx/defn/evaluator.ex:27: Nx.Defn.Evaluator.__jit__/4
iex(1)> 

Nx.to_batched_list(t, 1)と等価な処理をdefnに書きたい

Nx.to_batched_list(t, 1) は次のように,与えられたテンソルを,1ランク下のテンソルのリストに分解する処理なので,とても重宝しますが,残念ながら前述のようにdefnの中では使えません。

iex> t = Nx.iota({2,2})         
#Nx.Tensor<
  s64[2][2]
  [
    [0, 1],
    [2, 3]
  ]
>
iex> Nx.to_batched_list(t, 1)
[#Nx.Tensor<
    s64[1][2]
    [
      [0, 1]
    ]
>, #Nx.Tensor<
    s64[1][2]
    [
      [2, 3]
    ]
>]

幸い,Nx.to_batched_list(t, 1)と等価な処理を次のように書けます。

iex> t = Nx.iota({2,2})         
#Nx.Tensor<
  s64[2][2]
  [
    [0, 1],
    [2, 3]
  ]
>
iex> Enum.map(0..((Nx.shape(t) |> elem(0)) - 1), & Nx.new_axis(t[&1], 0))
[#Nx.Tensor<
    s64[1][2]
    [
      [0, 1]
    ]
>, #Nx.Tensor<
    s64[1][2]
    [
      [2, 3]
    ]
>]

この記法だと,Nx.Defn.Kernel.transform/2を使うことでdefnの中に書けます。

  defn to_list_and_concat(t) do
    transform(t, fn t -> Enum.map(0..((Nx.shape(t) |> elem(0)) - 1), & Nx.new_axis(t[&1], 0)) |> Nx.concatenate() end)
  end

使ってみると次のような感じです。

iex> t = Nx.iota({2,2})         
#Nx.Tensor<
  s64[2][2]
  [
    [0, 1],
    [2, 3]
  ]
>
iex(7)> to_list_and_concat(t)
#Nx.Tensor<
  s64[2][2]
  [
    [0, 1],
    [2, 3]
  ]
>

引数をスカラー値に限定するようなガード条件を書きたい

現状,defnにguardを書くことはできません。例えば引数 n にはスカラー値であってほしいということで,次のように書きたくなることがあるのですが,エラーになります。

defn add_tensor_and_scalar(t, s) when is_number(s) or Nx.size(s) == 1 do
  Nx.add(t, s)
end

このような時に便利な記述を定義した次のPRが取り込まれました。(既存のプロジェクトに取り込む時には,mix deps.clean --all rm -rf mix.lock deps _build としてから mix deps.get とします。)

このような時には,Defn.Kernel.assert_shape/2 を使って次のように書きます。

  defn add_tensor_and_scalar(t, s) do
    assert_shape(s, {})
    Nx.add(t, s)
  end

実行してみると次のような感じです。

iex> DefnTest.add_tensor_and_scalar(Nx.tensor([1,2]), 1)
#Nx.Tensor<
  s64[2] 
  [2, 3]
>
iex> DefnTest.add_tensor_and_scalar(Nx.tensor([1,2]), Nx.tensor([1, 2]))
** (ArgumentError) expected tensor to be a scalar, got tensor with shape {2}
    (nx 0.1.0-dev) lib/nx/defn/kernel.ex:1050: Nx.Defn.Kernel.assert_shape/2
    (defn_test 0.1.0) lib/defn_test.ex:21: DefnTest."__defn:add_tensor_and_scalar__"/2
    (defn_test 0.1.0) lib/defn_test.ex:20: anonymous fn/3 in DefnTest.add_tensor_and_scalar/2
    (nx 0.1.0-dev) lib/nx/defn/evaluator.ex:27: Nx.Defn.Evaluator.__jit__/4
iex> 

なお,defn に与える引数は,数値であったとしてもスカラー値を表す Nx.Tensor に変換されるとのことです。

iex> DefnTest.add_tensor_and_scalar(Nx.tensor([1,2]), Nx.tensor(1))  
#Nx.Tensor<
  s64[2] 
  [2, 3]
>

要素数1のベクトルや1x1の行列は,受け付けません。

iex(7)> DefnTest.add_tensor_and_scalar(Nx.tensor([1,2]), Nx.tensor([1]))
** (ArgumentError) expected tensor to be a scalar, got tensor with shape {1}
    (nx 0.1.0-dev) lib/nx/defn/kernel.ex:1050: Nx.Defn.Kernel.assert_shape/2
    (defn_test 0.1.0) lib/defn_test.ex:21: DefnTest."__defn:add_tensor_and_scalar__"/2
    (defn_test 0.1.0) lib/defn_test.ex:20: anonymous fn/3 in DefnTest.add_tensor_and_scalar/2
    (nx 0.1.0-dev) lib/nx/defn/evaluator.ex:27: Nx.Defn.Evaluator.__jit__/4

Nx.Defn.Kernel.assert_shape_pattern/2はうまく動きませんでしたので,PRにコメントしました。 Fix されました。

  defn add_and_matrix(m1, m2) do
    assert_shape_pattern(m1, {_, _})
    assert_shape_pattern(m2, {_, _})
    Nx.add(m1, m2)
  end

このようにすると,第1引数と第2引数が共に行列の時のみ実行され,それ以外の場合はエラーになります。

iex> DefnTest.add_and_matrix(Nx.tensor([[1,2], [3,4]]), Nx.tensor([[2,3], [3,4]]))
#Nx.Tensor<
  s64[2][2]
  [
    [3, 5],
    [6, 8]
  ]
>
iex> DefnTest.add_and_matrix(Nx.tensor([1,2]), Nx.tensor([2,3]))
** (ArgumentError) expected tensor to match shape {_, _}, got tensor with shape {2}
    (nx 0.1.0-dev) lib/nx/defn/kernel.ex:1133: Nx.Defn.Kernel.__assert_shape_pattern__!/2
    (defn_test 0.1.0) lib/defn_test.ex:26: DefnTest."__defn:add_and_matrix__"/2
    (defn_test 0.1.0) lib/defn_test.ex:25: anonymous fn/3 in DefnTest.add_and_matrix/2
    (nx 0.1.0-dev) lib/nx/defn/evaluator.ex:27: Nx.Defn.Evaluator.__jit__/4
iex>

スカラーやベクターの時(階数が1以下のテンソルの時)にエラーにしたい

次のように書きます。

defn assert_rank_greater_than_1 do
  transform(t, fn t -> Nx.rank(t) <= 1 && raise ArgumentError, "expected the rank of input tensor #{inspect(t)} to be grater than 1" end)
end

使ってみると次のような感じです。

iex> t = Nx.tensor(1)
#Nx.Tensor<
  s64
  1
>
iex> assert_rank_greater_than_1(t)
** (ArgumentError) expected the rank of input tensor #Nx.Tensor<
  s64

  Nx.Defn.Expr
  parameter a  s64
> to be grater than 1
(snip)
iex> t = Nx.tensor([1, 2])
#Nx.Tensor<
  s64[2] 
  [1, 2]
>
iex> assert_rank_greater_than_1(t)
** (ArgumentError) expected the rank of input tensor #Nx.Tensor<
  s64

  Nx.Defn.Expr
  parameter a  s64
> to be grater than 1
(snip)
iex> t = Nx.tensor([[1, 2], [3, 4]])
#Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>
iex> assert_rank_greater_than_1(t)
#Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>

defn で定数との比較をする

次のようなコードを定義したとします。

  defn pattern_match_fn(t) do
    transform(t, fn t -> pattern_match_sub(t) end)
  end

  def pattern_match_f(t) do
    pattern_match_sub(t)
  end

  defp pattern_match_sub(t) do
    if t == Nx.tensor(1) do
      IO.puts("matched")
      t
    else
      IO.puts("unmatched")
      t
    end
  end

def を使った場合には Nx.tensor(1)と等しい場合に matchedと表示し,そうでない場合には unmatchedと表示します。

iex> match_f(Nx.tensor(1))
matched
#Nx.Tensor<
  s64
  1
>
iex> match_f(Nx.tensor(2))
unmatched
#Nx.Tensor<
  s64
  2
>

しかし,defn を使った場合には,なぜかどちらも unmatched になります。

iex> match_fn(Nx.tensor(1))
unmatched
#Nx.Tensor<
  s64
  1
>
iex> match_fn(Nx.tensor(2))
unmatched
#Nx.Tensor<
  s64
  2
>

Nx 開発者の Paulo Valente に相談したのですが,解決せず,Issue を書くことにしました。

José Valim 曰く,

要は defn はコンパイル時に評価される式なのに対し,transform/2 で呼び出した先の式は実行時に評価される式なので,同一性を評価できないということです。

というわけで次のようにする必要があります。

  defn match_fn(t) do
    if t == 1 do
      t * 2
    else
      t / 2
    end
  end

しかし,次のようにしてもうまく動きません。

  defn match_fn(t) do
    if t == 1 do
      transform(
        t,
        fn t ->
          IO.puts("matched")
          t
        end
      )
    else
      transform(
        t,
        fn t ->
          IO.puts("unmatched")
          t
        end
      )
    end
  end

実行すると次のようになります。

iex> match_fn(1)
matched
unmatched
#Nx.Tensor<
  s64
  1
>

defn の中の条件分岐の分岐先は,実行時には両方評価されるということなんですかね。

現在,defn の中で IO.inspect などでデバッグする手段を構想中ということでした。

9
7
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
7