LoginSignup
14
1

Livebook + Nx で AtCoder に再入門 -ABC042A-

Last updated at Posted at 2023-11-17

はじめに

約半年前、2023年06月30日に「Livebook で AtCoder 始めてみた」という記事を投稿しました

それから約1ヶ月後、2023年8月6日から2023年8月12日にかけて、 AtCoder で使用可能な言語のバージョンが更新されました

この更新により Elixir は 1.10.2 から 1.15.2 にアップグレードされ、 Nx も使えるようになりました

Nx が使える = 行列演算ができるということなので、同じ問題でも全く違う解法が使えます

その他、使えるようになった関数は以下の記事にまとめています

その後、結局 AtCoder に触れていなかったのですが、改めて再入門しようと思います

実際に AtCoder 新バージョンで Nx を使ってみました

回答してみるのは ABC042A です

セットアップ

AtCoder の新環境では Nx に加え EXLA もインストールされており、デフォルトバックエンドの設定もされています

そのため、 Livebook でも同じ設定をしておきます

Mix.install(
  [
    {:nx, "~> 0.6"},
    {:exla, "~> 0.6"}
  ],
  config: [
    nx: [
      default_backend: EXLA.Backend,
      default_defn_options: [compiler: EXLA]
    ]
  ]
)

ここで注意点があります

EXLA バックエンドの場合、テンソル内の各値が同じであっても == で比較すると結果が false になります

EXLA バックエンドでは、値以外にも参照している情報があるようです

スクリーンショット 2023-11-15 23.49.36.png

ちなみにバイナリーバックエンドの場合は true になります

スクリーンショット 2023-11-15 23.50.28.png

単純に == で比較できないので、値が全て同じであることを判定する関数を別途用意します

Nx を使わない場合

以前のバージョンで実装したモジュールです

defmodule Main do
  def main do
    :stdio
    |> IO.read(:all)
    |> solve()
    |> IO.puts()
  end

  defp split_words(words) do
    String.split(words, " ")
  end

  def solve(input) do
    input
    |> String.trim()
    |> split_words()
    |> Enum.sort()
    |> case do
      ["5", "5", "7"] -> "YES"
      _ -> "NO"
    end
  end
end

実行時間は 784 msec でした

ノートブックはこちら

EXLA バックエンドを使う場合

defn で高速化可能な処理は defn を使い、テンソルの値を比較します

defmodule Main do
  import Nx.Defn

  def main do
    :stdio
    |> IO.read(:all)
    |> solve()
    |> IO.puts()
  end

  defp split_words(words) do
    String.split(words, " ")
  end

  defn equal_tensor(left, right) do
    Nx.equal(left, right)
    |> Nx.all()
  end

  def equal(left, right) do
    equal_tensor(left, right)
    |> Nx.to_number()
    |> then(& &1 == 1)
  end

  def solve(input) do
    input
    |> String.trim()
    |> split_words()
    |> Enum.map(&String.to_integer(&1))
    |> Nx.tensor()
    |> Nx.sort()
    |> equal(Nx.tensor([5, 5, 7]))
    |> if(do: "YES", else: "NO")
  end
end

実行時間は 1959 msec でした

ノートブックはこちら

バイナリーバックエンドを使う場合

もちろん、明示的にバイナリーバックエンドを使えばもっと単純化できます

defmodule Main do
  @target Nx.tensor([5, 5, 7], backend: Nx.BinaryBackend)

  def main do
    :stdio
    |> IO.read(:all)
    |> solve()
    |> IO.puts()
  end

  defp split_words(words) do
    String.split(words, " ")
  end

  def solve(input) do
    input
    |> String.trim()
    |> split_words()
    |> Enum.map(&String.to_integer(&1))
    |> Nx.tensor(backend: Nx.BinaryBackend)
    |> Nx.sort()
    |> case do
      @target -> "YES"
      _ -> "NO"
    end
  end
end

実行時間は 1679 msec でした

ノートブックはこちら

まとめ

今回の問題は全く行列演算することにメリットがないものだったため、余計に時間がかかってしまいました

とりあえず Nx を使うことはできたので、他の問題も行列演算で解いてみたいと思います

14
1
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
1