概要
シンボリック数式を入力や出力とする機械学習タスクにおいて、数式をランダムに生成したくなることがあります。今回は Lample and Charton, 2019 の Appendix C を参考に、数式構文木をランダム生成するコードを実装してみました。さらに得られた構文木をパースすることで、最終的には下図のようにたくさんの数式を生成することができます。
実装
全体のアルゴリズムを記した擬似コードは次のようになります (元論文 Appendix C より引用)。最終的に現れてほしい演算子の個数を$n$とし、1ループごとに1個ずつ演算子をサンプリングしながら木を育てていきます。$e$は各時点で値が確定していないノードの個数を表しています。
Start with an empty node, set e = 1;
while n > 0 do
Sample a position k and arity a from L(e, n) (if a = 1 the next internal node is unary);
Sample the k next empty nodes as leaves;
if a = 1 then
Sample a unary operator;
Create one empty child;
Set e = e − k;
end
else
Sample a binary operator;
Create two empty children;
Set e = e − k + 1;
end
Set n = n − 1;
end
1. 必要なライブラリのインポート
まず必要なライブラリをインポートしておきます。シンボリック数式の扱いには SymPy を、構文木の扱いには NetworkX を使います。
import sympy as sp
import numpy as np
import networkx as nx
2. (k, a)
のサンプリング
いくつか記号の説明をします。
与えられた$e, n$に対し、$D(e, n)$は可能な部分木の個数を表します。たとえば$e = 0$ならば、既に全てのノードたちの値が確定しているため、可能な部分木の個数は$0$となります。また$e \neq 0, n = 0$ならば、未確定のノードたちに演算子を入れることはできず、全てに被演算子を入れて葉とするしかないため、可能な部分木の個数は$1$となります。
また、$k \in \{0, \dots, e-1\}$は次に被演算子を入れるノードの個数、$a \in \{1, 2\}$は次に入れる演算子の引数の個数を表します。
ここでは$D(e, n)$を用いて$(k, a)$の従う分布を与え、サンプリングの手続きを定義しています。
# define D(e, n)
def D(e, n):
if e == 0:
return 0
if n == 0:
return 1
return D(e - 1, n) + D(e, n - 1) + D(e + 1, n - 1)
# define a probability table of (k, a)
# for k = 0, ..., e - 1 and a = 0, 1
def P(e, n):
p = np.zeros((e, 2))
for k in range(e):
p[k, 0] = D(e - k, n - 1) / D(e, n)
p[k, 1] = D(e - k + 1, n - 1) / D(e, n)
return p
# return possible choices of (k, a)
def L(e):
return [{'k': k, 'a': a} for a in range(1, 3) for k in range(e)]
# sample a pair (k, a) of position and arity
def sample_position_and_arity(e, n):
res = np.random.choice(L(e), p=P(e, n).flatten())
return (res['k'], res['a'])
3. 演算子・被演算子のサンプリング
ここでは演算子・被演算子のサンプリングの手続きを定義しています。ここでは簡単のため、各演算子・被演算子の出方は一様分布に従うこととしていますが、別の分布を採用することもできます。
# possible symbols and operators
x = sp.symbols('x')
symbols = {
"x": x,
"pi": sp.pi
}
unary_ops = {
"neg": lambda x: -x,
"exp": lambda x: sp.exp(x, evaluate=False),
"log": lambda x: sp.log(x, evaluate=False),
"sqrt": lambda x: sp.sqrt(x, evaluate=False),
"sin": lambda x: sp.sin(x, evaluate=False),
"cos": lambda x: sp.cos(x, evaluate=False),
}
binary_ops = {
"+": lambda x, y: x + y,
"-": lambda x, y: x - y,
"*": lambda x, y: x * y,
"/": lambda x, y: sp.Rational(x, y) if x.is_Number and y.is_Number else x / y,
"pow": lambda x, y: sp.Pow(x, y, evaluate=False)
}
# sample an operator
def sample_operator(arity):
if arity == 1:
return np.random.choice(list(unary_ops.keys()))
else:
return np.random.choice(list(binary_ops.keys()))
# sample a symbol
def sample_symbol():
return np.random.choice(
np.concatenate((
np.arange(1, 6),
list(symbols.keys()),
))
)
4. 構文木のサンプリング
ここでは1個の構文木をサンプリングする手続きを定義しています。このコードが上の擬似コードにほぼ対応します。
def sample_graph(n):
# Initialize an empty graph
G = nx.DiGraph()
def insert_node(value):
name = len(G.nodes)
G.add_node(name, value=value)
return (name, G.nodes[name])
e = 1
empty_nodes = [insert_node(None)]
while n > 0:
k, a = sample_position_and_arity(e, n)
# update empty nodes with leaf values
for _ in range(k):
empty_nodes.pop(-1)[1]["value"] = sample_symbol()
if a == 1:
op = sample_operator(a)
op_node = empty_nodes.pop(-1)
op_node[1]["value"] = op
n1_node = insert_node(None)
empty_nodes.append(n1_node)
G.add_edge(op_node[0], n1_node[0])
e = e - k
else:
op = sample_operator(a)
op_node = empty_nodes.pop(-1)
op_node[1]["value"] = op
n1_node = insert_node(None)
n2_node = insert_node(None)
empty_nodes.append(n1_node)
empty_nodes.append(n2_node)
G.add_edge(op_node[0], n1_node[0])
G.add_edge(op_node[0], n2_node[0])
e = e - k + 1
n -= 1
# update empty nodes with leaf values
for _ in range(len(empty_nodes)):
empty_nodes.pop(-1)[1]["value"] = sample_symbol()
return G
5. 構文木から数式への変換
ここでは構文木を数式に変換する関数を定義しています。
# Convert a graph to a SymPy expression
def convert_graph_to_expr(graph):
# parse the graph to a prefix-style list of tokens
tokens = [graph.nodes[node]["value"] for node in nx.dfs_tree(G, source=0).nodes]
def parse(q):
token = q.pop(0)
if token in symbols:
return symbols[token]
elif token in unary_ops:
return unary_ops[token](parse(q))
elif token in binary_ops:
return binary_ops[token](parse(q), parse(q))
else:
return sp.Float(token)
return parse(tokens)
6. 数式を生成してみる
以上で実装は完了です。演算子の個数$n = 6$のもとでいくつか数式を生成してみると次のようになります。
n = 6
for _ in range(10):
G = sample_graph(n)
expr = convert_graph_to_expr(G)
display(expr)
おまけ
SymPy の lambdify
関数を使うと数式を関数化することができます。
n = 6
G = sample_graph(n)
expr = convert_graph_to_expr(G)
f = np.vectorize(sp.lambdify(x, expr))
X = np.linspace(-1, 1, 100)
plt.title("$" + sp.latex(expr) + "$")
plt.plot(X, f(X))
plt.show()