はじめに
約半年前、2023年06月30日に「Livebook で AtCoder 始めてみた」という記事を投稿しました
それから約1ヶ月後、2023年8月6日から2023年8月12日にかけて、 AtCoder で使用可能な言語のバージョンが更新されました
この更新により Elixir は 1.10.2 から 1.15.2 にアップグレードされ、 Nx も使えるようになりました
Nx が使える = 行列演算ができるということなので、同じ問題でも全く違う解法が使えます
その他、使えるようになった関数は以下の記事にまとめています
その後、結局 AtCoder に触れていなかったのですが、改めて再入門しようと思います
実際に AtCoder 新バージョンで Nx を使ってみました
回答してみるのは ABC042A です
セットアップ
AtCoder の新環境では Nx に加え EXLA もインストールされており、デフォルトバックエンドの設定もされています
そのため、 Livebook でも同じ設定をしておきます
Mix.install(
[
{:nx, "~> 0.6"},
{:exla, "~> 0.6"}
],
config: [
nx: [
default_backend: EXLA.Backend,
default_defn_options: [compiler: EXLA]
]
]
)
ここで注意点があります
EXLA バックエンドの場合、テンソル内の各値が同じであっても ==
で比較すると結果が false になります
EXLA バックエンドでは、値以外にも参照している情報があるようです
ちなみにバイナリーバックエンドの場合は true になります
単純に ==
で比較できないので、値が全て同じであることを判定する関数を別途用意します
Nx を使わない場合
以前のバージョンで実装したモジュールです
defmodule Main do
def main do
:stdio
|> IO.read(:all)
|> solve()
|> IO.puts()
end
defp split_words(words) do
String.split(words, " ")
end
def solve(input) do
input
|> String.trim()
|> split_words()
|> Enum.sort()
|> case do
["5", "5", "7"] -> "YES"
_ -> "NO"
end
end
end
実行時間は 784 msec でした
ノートブックはこちら
EXLA バックエンドを使う場合
defn
で高速化可能な処理は defn
を使い、テンソルの値を比較します
defmodule Main do
import Nx.Defn
def main do
:stdio
|> IO.read(:all)
|> solve()
|> IO.puts()
end
defp split_words(words) do
String.split(words, " ")
end
defn equal_tensor(left, right) do
Nx.equal(left, right)
|> Nx.all()
end
def equal(left, right) do
equal_tensor(left, right)
|> Nx.to_number()
|> then(& &1 == 1)
end
def solve(input) do
input
|> String.trim()
|> split_words()
|> Enum.map(&String.to_integer(&1))
|> Nx.tensor()
|> Nx.sort()
|> equal(Nx.tensor([5, 5, 7]))
|> if(do: "YES", else: "NO")
end
end
実行時間は 1959 msec でした
ノートブックはこちら
バイナリーバックエンドを使う場合
もちろん、明示的にバイナリーバックエンドを使えばもっと単純化できます
defmodule Main do
@target Nx.tensor([5, 5, 7], backend: Nx.BinaryBackend)
def main do
:stdio
|> IO.read(:all)
|> solve()
|> IO.puts()
end
defp split_words(words) do
String.split(words, " ")
end
def solve(input) do
input
|> String.trim()
|> split_words()
|> Enum.map(&String.to_integer(&1))
|> Nx.tensor(backend: Nx.BinaryBackend)
|> Nx.sort()
|> case do
@target -> "YES"
_ -> "NO"
end
end
end
実行時間は 1679 msec でした
ノートブックはこちら
まとめ
今回の問題は全く行列演算することにメリットがないものだったため、余計に時間がかかってしまいました
とりあえず Nx を使うことはできたので、他の問題も行列演算で解いてみたいと思います