当記事ではAttentionを用いたTransformerライクなGNN(Graph Neural Network)を用いてセールスマン巡回問題(TSP)や配送計画問題(VRP)に取り組んだ研究であるAttention, Learn to Solve Routing Problems!の論文やPyTorch
実装について確認を行いました。
Attention, Learn to Solve Routing Problems!論文の確認
論文の概要
Attention, Learn to Solve Routing Problems!はAttentionに基づくTransformerライクなGNNを用いてルート最適化について取り組んだ研究です。
Encoderの処理概要(Attention, Learn to Solve Routing Problems!論文 Figure.1)
Attention, Learn to Solve Routing Problems!では上図のようなGNNに基づくEncoderが用いられています。MHA(Multi Head Attention)とFF(Feed-Forward)の組み合わせである点ではTransformerに類似した処理であると言えます。
Decoderの処理概要(Attention, Learn to Solve Routing Problems!論文 Figure.2)
同様にDecoderでは上図のような処理が行われます。図だけでは処理の詳細まで確認できないので次項ではEncoderとDecoderの式定義について詳しく確認します。
EncoderとDecoderの式定義
Encoderの式定義
Encoderの式定義は基本的にはTransformerの式定義に類似しています。下記では出力に近い順に数式をまとめました。
\begin{align}
\bar{h}^{N} &= \sum_{i=1}^{n} h_{i}^{N} \\
\hat{h}_{i} &= \mathrm{BN}^{l} \left[ h_{i}^{l-1} + \mathrm{MHA}_{i}^{l} \left( h_{1}^{l-1}, \cdots , h_{n}^{l-1} \right) \right] \\
h_{i}^{l} &= \mathrm{BN}^{l} \left[ \hat{h}_{i} + \mathrm{FF}^{l} \left( \hat{h}_{i} \right) \right] \\
l & \in \{ 1, \cdots , N \} \\
h_{i}^{0} &= W^{x} \mathbf{x}_{i} + \mathbf{b}^{x}, \, \mathbf{x}_{i} \in \mathbb{R}^{d_x}, \, h_{i}^{l} \in \mathbb{R}^{d_h}
\end{align}
Encoderは概ね上記のように式定義されます。1行目のグラフの特徴量ベクトル$\bar{h}^{N} \in \mathbb{R}^{d_h}$を抽出するにあたって2行目や3行目のMHAやFFの演算を繰り返した結果を用います。$N$はGNNのレイヤーの数、$n$はノードの数にそれぞれ対応することに注意しておくと良いです。また、入力の次元数の$d_x$はセールスマン巡回問題(TSP; Travelling Salesperson Problem)の場合$d_x=2$、中間層の特徴量ベクトルの次元数の$d_h$ではこの研究では$d_h=128$に指定されます。
Decoderの式定義
セールスマン巡回問題などの問題設定のインスタンスの$s$が与えられた場合、Decoderの式の理解にあたっては訪れるノードの順番(tour)の$\boldsymbol{\pi}$を下記のように定義することにまず着目すると良いです。
\begin{align}
\boldsymbol{\pi} = (\pi_{1}, \pi_{2}, \cdots , \pi_{n})
\end{align}
上記の$\pi_{t}, t \in {1, \cdots , n}$は$t$番目に訪れる場所のように解釈すると良いです。Decoderでは$t$番目に訪れる場所が$i$である確率を下記のように定義します。
\begin{align}
p_{\theta}(\pi_{t}=i | s, \boldsymbol{\pi}_{1:t-1}) &= \frac{e^{u_{(c)i}}}{\sum_{j} e^{u_{(c)j}}}
\end{align}
上記の$u_{(c)j}$は下記のように定義される関数です。
\begin{align}
u_{(c)j} &= \left\{ \,
\begin{aligned}
& C \cdot \tanh{\left( \frac{q_{(c)}^{\mathrm{T}} k_{j}}{\sqrt{d_{k}}} \right)} \quad \mathrm{if} \, j \neq \pi_{t'} \,\, {}^{\forall} t' < t \\
& - \infty \qquad \mathrm{otherwise}
\end{aligned}
\right. \\
-10 \, & \leq \, C \, \leq \, 10
\end{align}
上記の$-\infty$は既に訪れたノードに二度訪れないようにマスキングする意図で設定されています。また、$u_{(c)j}$に出てくるQueryの$q_{(c)}$とKeyの$k_{j}$はそれぞれ下記のように計算されます。
\begin{align}
q_{(c)} &= W^{Q} h_{(c)} \in \mathbb{R}^{d_{h}} \\
k_{j} &= W^{K} h_{i} \in \mathbb{R}^{d_{h}} \\
v_{i} &= W^{V} h_{i} \in \mathbb{R}^{d_{h}}
\end{align}
上記に出てくる$h_{(c)} \in \mathbb{R}^{3 d_{h}}$は下記のような式で定義されます。
\begin{align}
h_{(c)} = \left\{ \,
\begin{aligned}
& \left[ \bar{h}^{N}, h_{\pi_{t-1}}^{N}, h_{\pi_{1}} \right] \quad t>1 \\
& \left[ \bar{h}^{N}, v^{l}, v^{f} \right] \qquad t=1
\end{aligned}
\right.
\end{align}
上記の$[\cdot , \cdot , \cdot]$はベクトルの連結(concat)にあたっての演算子です。
方策勾配法(Policy Gradint)に基づくネットワークの学習
前項のネットワークの学習にあたって、セールスマン巡回問題では厳密な正解を用意するのが難しいです。この際に方策勾配法(Policy Gradient)基づく強化学習でネットワークの学習を行うことは原理的に可能であり、この研究では方策勾配法の手法の1つであるREINFORCEMENTが用いられています。このようにニューラルネットワークの学習にあたって方策勾配法を用いる手法はLLMのRLHF(Reinforcement Learning with Human Feedback)でも用いられているので合わせて抑えておくと良いです。
Attention, Learn to Solve Routing Problems!の論文では下記のように目的関数や目的関数の勾配の式を定義・導出し、勾配降下法(Gradient Descent)を用いて学習を行います。
\begin{align}
L(\theta|s) &= \mathbb{E}_{(\boldsymbol{\pi}|s)} [L(\boldsymbol{\pi})] \\
\nabla L(\theta|s) &= \mathbb{E}_{(\boldsymbol{\pi}|s)} \left[ (L(\boldsymbol{\pi}) - b(s)) \nabla \log{p_{\theta} (\boldsymbol{\pi}|s)} \right]
\end{align}
上記の勾配の$\nabla L(\theta|s)$の式はlog勾配Trickなどを用いることで導出することができます。また、$b(s)$はベースライン(baseline)でこの研究では貪欲法(現在地に近いノードから選択する方法)に基づいて用意されます(greedy rollout)。
\begin{align}
b(s) &= M_{i} \\
M_{0} &= L(\boldsymbol{\pi}_{0}) \\
M_{i+1} &= \beta M_{i} + (1 - \beta) L(\boldsymbol{\pi}_{i})
\end{align}
ベースラインは上記のように計算されます。論文の式ではインデックスがありませんでしたが可読性を上げるにあたって当記事ではインデックスを追記しました。
また、Algorithmについては論文に記載があるので合わせて確認しておくと良いと思います。
REINFORCEMENTを用いた学習のアルゴリズム(Attention, Learn to Solve Routing Problems!論文 Algorithm.1)
Attention, Learn to Solve Routing Problems!のPyTorch実装の確認
PyTorch実装の入手とdemoの実行
Attention, Learn to Solve Routing Problems!の実装は上記より入手することができます。demo用のコードは用意されていないので、attention-learn-to-route/eval.py
を下記のように改変して実行すると出力がわかりやすいと思います。
if __name__ == "__main__":
...
for width in widths:
for dataset_path in opts.datasets:
costs, tours, durations = eval_dataset(dataset_path, width, opts.softmax_temperature, opts)
print("Cost0-9: {}, Tours0-9:{}, num_estimates: {}".format(costs[0:10], tours[0:10], len(costs)))
・実行コマンド
$ python eval.py -f data/tsp/tsp20_test_seed1234.pkl --model pretrained/tsp_20 --decode_strategy greedy
・実行結果
Average cost: 3.8469574451446533 +- 0.006181415915489197
Average serial duration: 0.0451378002166748 +- 0.0012258773853210062
Average parallel duration: 4.4079883024096486e-05
Calculated total duration: 0:00:00
Cost0-9: (3.66383, 3.7619956, 3.7999754, 4.118915, 3.294455, 3.8607297, 3.8752038, 3.9516256, 4.1289177, 4.0149283), Tours0-9:([16, 2, 15, 12, 4, 9, 17, 6, 1, 13, 19, 18, 3, 0, 14, 10, 7, 5, 11, 8], [6, 16, 11, 4, 10, 18, 2, 15, 9, 8, 7, 13, 17, 1, 5, 3, 19, 0, 14, 12], [12, 17, 11, 1, 19, 10, 3, 8, 6, 16, 14, 9, 13, 5, 0, 7, 4, 15, 18, 2], [8, 6, 11, 15, 18, 0, 17, 1, 19, 9, 5, 3, 10, 12, 4, 7, 16, 2, 14, 13], [15, 14, 5, 3, 10, 9, 11, 2, 1, 4, 19, 13, 8, 6, 7, 16, 0, 17, 12, 18], [7, 19, 2, 0, 11, 16, 1, 9, 18, 6, 12, 5, 15, 3, 14, 4, 10, 13, 17, 8], [13, 9, 14, 18, 12, 17, 10, 7, 15, 5, 11, 8, 4, 3, 1, 19, 2, 0, 16, 6], [4, 14, 19, 3, 6, 8, 18, 10, 5, 13, 17, 0, 16, 11, 2, 7, 12, 1, 9, 15], [2, 19, 8, 12, 9, 0, 10, 18, 15, 6, 13, 17, 5, 7, 4, 14, 16, 1, 3, 11], [15, 5, 14, 2, 17, 8, 11, 18, 3, 13, 12, 16, 6, 19, 1, 7, 10, 9, 0, 4]), num_estimates: 10000
実行結果より、attention-learn-to-route/eval.py
を実行することでdata/tsp/tsp20_test_seed1234.pkl
から10,000のセールスマン巡回問題をロードし、それぞれについて予測を行っていることが確認できます。
modelの実装の確認
以下attention-learn-to-route/eval.py
の内容を元にmodelの実装について確認します。
from utils import load_model, move_to
...
def eval_dataset(dataset_path, width, softmax_temp, opts):
# Even with multiprocessing, we load the model here since it contains the name where to write results
#print("multiprocessing: {}".format(opts.multiprocessing))
model, _ = load_model(opts.model)
...
results = _eval_dataset(model, dataset, width, softmax_temp, opts, device)
...
costs, tours, durations = zip(*results) # Not really costs since they should be negative
...
return costs, tours, durations
...
def _eval_dataset(model, dataset, width, softmax_temp, opts, device):
model.to(device)
model.eval()
...
dataloader = DataLoader(dataset, batch_size=opts.eval_batch_size)
results = []
for batch in tqdm(dataloader, disable=opts.no_progress_bar):
batch = move_to(batch, device)
start = time.time()
with torch.no_grad():
if opts.decode_strategy in ('sample', 'greedy'):
...
# This returns (batch_size, iter_rep shape)
sequences, costs = model.sample_many(batch, batch_rep=batch_rep, iter_rep=iter_rep)
...
if sequences is None:
sequences = [None] * batch_size
costs = [math.inf] * batch_size
else:
sequences, costs = get_best(
sequences.cpu().numpy(), costs.cpu().numpy(),
ids.cpu().numpy() if ids is not None else None,
batch_size
)
duration = time.time() - start
for seq, cost in zip(sequences, costs):
if model.problem.NAME == "tsp":
seq = seq.tolist() # No need to trim as all are same length
elif model.problem.NAME in ("cvrp", "sdvrp"):
seq = np.trim_zeros(seq).tolist() + [0] # Add depot
elif model.problem.NAME in ("op", "pctsp"):
seq = np.trim_zeros(seq) # We have the convention to exclude the depot
else:
assert False, "Unkown problem: {}".format(model.problem.NAME)
# Note VRP only
results.append((cost, seq, duration))
return results
上記より、modelのロードはattention-learn-to-route/utils/functions.py
に実装されているload_model
関数を用いて行われていることが確認できます。引数のopts.model
はReadME.md
の実行例ではpretrained/tsp_20
が用いられます。
def load_model(path, epoch=None):
from nets.attention_model import AttentionModel
from nets.pointer_network import PointerNetwork
...
model_class = {
'attention': AttentionModel,
'pointer': PointerNetwork
}.get(args.get('model', 'attention'), None)
assert model_class is not None, "Unknown model: {}".format(model_class)
model = model_class(
args['embedding_dim'],
args['hidden_dim'],
problem,
n_encode_layers=args['n_encode_layers'],
mask_inner=True,
mask_logits=True,
normalization=args['normalization'],
tanh_clipping=args['tanh_clipping'],
checkpoint_encoder=args.get('checkpoint_encoder', False),
shrink_size=args.get('shrink_size', None)
)
...
return model, args
上記より、AttentionModel
やPointerNetwork
がロードされていることが確認できます。当記事では以下attention-learn-to-route/nets/attention_model.py
に実装されているAttentionModel
クラスについて確認します。
from nets.graph_encoder import GraphAttentionEncoder
class AttentionModel(nn.Module):
def __init__(self,
embedding_dim,
hidden_dim,
problem,
n_encode_layers=2,
tanh_clipping=10.,
mask_inner=True,
mask_logits=True,
normalization='batch',
n_heads=8,
checkpoint_encoder=False,
shrink_size=None):
...
self.embedder = GraphAttentionEncoder(
n_heads=n_heads,
embed_dim=embedding_dim,
n_layers=self.n_encode_layers,
normalization=normalization
)
...
def forward(self, input, return_pi=False):
"""
:param input: (batch_size, graph_size, node_dim) input node features or dictionary with multiple tensors
:param return_pi: whether to return the output sequences, this is optional as it is not compatible with
using DataParallel as the results may be of different lengths on different GPUs
:return:
"""
if self.checkpoint_encoder and self.training: # Only checkpoint if we need gradients
embeddings, _ = checkpoint(self.embedder, self._init_embed(input))
else:
embeddings, _ = self.embedder(self._init_embed(input))
_log_p, pi = self._inner(input, embeddings)
cost, mask = self.problem.get_costs(input, pi)
# Log likelyhood is calculated within the model since returning it per action does not work well with
# DataParallel since sequences can be of different lengths
ll = self._calc_log_likelihood(_log_p, pi, mask)
if return_pi:
return cost, ll, pi
return cost, ll
上記より、エンコーダはattention-learn-to-route/nets/graph_encoder.py
のGraphAttentionEncoder
クラスで実装されていることが確認できます。
class MultiHeadAttentionLayer(nn.Sequential):
def __init__(
self,
n_heads,
embed_dim,
feed_forward_hidden=512,
normalization='batch',
):
super(MultiHeadAttentionLayer, self).__init__(
SkipConnection(
MultiHeadAttention(
n_heads,
input_dim=embed_dim,
embed_dim=embed_dim
)
),
Normalization(embed_dim, normalization),
SkipConnection(
nn.Sequential(
nn.Linear(embed_dim, feed_forward_hidden),
nn.ReLU(),
nn.Linear(feed_forward_hidden, embed_dim)
) if feed_forward_hidden > 0 else nn.Linear(embed_dim, embed_dim)
),
Normalization(embed_dim, normalization)
)
class GraphAttentionEncoder(nn.Module):
def __init__(
self,
n_heads,
embed_dim,
n_layers,
node_dim=None,
normalization='batch',
feed_forward_hidden=512
):
super(GraphAttentionEncoder, self).__init__()
# To map input to embedding space
self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None
self.layers = nn.Sequential(*(
MultiHeadAttentionLayer(n_heads, embed_dim, feed_forward_hidden, normalization)
for _ in range(n_layers)
))
def forward(self, x, mask=None):
assert mask is None, "TODO mask not yet supported!"
# Batch multiply to get initial embeddings of nodes
h = self.init_embed(x.view(-1, x.size(-1))).view(*x.size()[:2], -1) if self.init_embed is not None else x
h = self.layers(h)
return (
h, # (batch_size, graph_size, embed_dim)
h.mean(dim=1), # average to get embedding of graph, (batch_size, embed_dim)
)
上記よりGraphAttentionEncoder
クラスでは、n_layers
の数だけMultiHeadAttentionLayer
を用いてグラフニューラルネットワークの層を構築していることが確認できます。また、デコーダはAttentionModel
クラスの_inner
メソッド内で実装されています。
class AttentionModel(nn.Module):
...
def _inner(self, input, embeddings):
outputs = []
sequences = []
state = self.problem.make_state(input)
# Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
fixed = self._precompute(embeddings)
batch_size = state.ids.size(0)
# Perform decoding steps
i = 0
while not (self.shrink_size is None and state.all_finished()):
if self.shrink_size is not None:
unfinished = torch.nonzero(state.get_finished() == 0)
if len(unfinished) == 0:
break
unfinished = unfinished[:, 0]
# Check if we can shrink by at least shrink_size and if this leaves at least 16
# (otherwise batch norm will not work well and it is inefficient anyway)
if 16 <= len(unfinished) <= state.ids.size(0) - self.shrink_size:
# Filter states
state = state[unfinished]
fixed = fixed[unfinished]
log_p, mask = self._get_log_p(fixed, state)
# Select the indices of the next nodes in the sequences, result (batch_size) long
selected = self._select_node(log_p.exp()[:, 0, :], mask[:, 0, :]) # Squeeze out steps dimension
state = state.update(selected)
# Now make log_p, selected desired output size by 'unshrinking'
if self.shrink_size is not None and state.ids.size(0) < batch_size:
log_p_, selected_ = log_p, selected
log_p = log_p_.new_zeros(batch_size, *log_p_.size()[1:])
selected = selected_.new_zeros(batch_size)
log_p[state.ids[:, 0]] = log_p_
selected[state.ids[:, 0]] = selected_
# Collect output of step
outputs.append(log_p[:, 0, :])
sequences.append(selected)
i += 1
# Collected lists, return Tensor
return torch.stack(outputs, 1), torch.stack(sequences, 1)
def _get_log_p(self, fixed, state, normalize=True):
# Compute query = context node embedding
query = fixed.context_node_projected + \
self.project_step_context(self._get_parallel_step_context(fixed.node_embeddings, state))
# Compute keys and values for the nodes
glimpse_K, glimpse_V, logit_K = self._get_attention_node_data(fixed, state)
# Compute the mask
mask = state.get_mask()
# Compute logits (unnormalized log_p)
log_p, glimpse = self._one_to_many_logits(query, glimpse_K, glimpse_V, logit_K, mask)
if normalize:
log_p = torch.log_softmax(log_p / self.temp, dim=-1)
assert not torch.isnan(log_p).any()
return log_p, mask
学習の実装の確認
以下、当項では強化学習(方策勾配法)に基づく学習の実装について取りまとめます。まずReadME.md
で学習コマンドに用いられるattention-learn-to-route/run.py
について簡単に確認します。
from options import get_options
...
if __name__ == "__main__":
run(get_options())
まず、上記よりattention-learn-to-route/run.py
ではattention-learn-to-route/options.py
から学習時の条件を読み込んでいることが確認できます。
import argparse
...
def get_options(args=None):
parser = argparse.ArgumentParser(
description="Attention based model for solving the Travelling Salesman Problem with Reinforcement Learning")
# Data
parser.add_argument('--problem', default='tsp', help="The problem to solve, default 'tsp'")
parser.add_argument('--graph_size', type=int, default=20, help="The size of the problem graph")
parser.add_argument('--batch_size', type=int, default=512, help='Number of instances per batch during training')
parser.add_argument('--epoch_size', type=int, default=1280000, help='Number of instances per epoch during training')
parser.add_argument('--val_size', type=int, default=10000,
help='Number of instances used for reporting validation performance')
parser.add_argument('--val_dataset', type=str, default=None, help='Dataset file to use for validation')
...
opts = parser.parse_args(args)
...
return opts
上記より、基本的にはargparse
モジュールを用いて条件の指定を行っていることが確認できます。次にattention-learn-to-route/run.py
の強化学習に対応するコードについて確認します。
from train import train_epoch, validate, get_inner_model
from reinforce_baselines import NoBaseline, ExponentialBaseline, CriticBaseline, RolloutBaseline, WarmupBaseline
...
def run(opts):
...
# Initialize model
model_class = {
'attention': AttentionModel,
'pointer': PointerNetwork
}.get(opts.model, None)
assert model_class is not None, "Unknown model: {}".format(model_class)
model = model_class(
opts.embedding_dim,
opts.hidden_dim,
problem,
n_encode_layers=opts.n_encode_layers,
mask_inner=True,
mask_logits=True,
normalization=opts.normalization,
tanh_clipping=opts.tanh_clipping,
checkpoint_encoder=opts.checkpoint_encoder,
shrink_size=opts.shrink_size
).to(opts.device)
...
if opts.eval_only:
validate(model, val_dataset, opts)
else:
for epoch in range(opts.epoch_start, opts.epoch_start + opts.n_epochs):
train_epoch(
model,
optimizer,
baseline,
lr_scheduler,
epoch,
val_dataset,
problem,
tb_logger,
opts
)
ネットワークの学習にあたってはattention-learn-to-route/train.py
に実装されているtrain_epoch
関数が実行されます。
def train_epoch(model, optimizer, baseline, lr_scheduler, epoch, val_dataset, problem, tb_logger, opts):
print("Start train epoch {}, lr={} for run {}".format(epoch, optimizer.param_groups[0]['lr'], opts.run_name))
step = epoch * (opts.epoch_size // opts.batch_size)
start_time = time.time()
if not opts.no_tensorboard:
tb_logger.log_value('learnrate_pg0', optimizer.param_groups[0]['lr'], step)
# Generate new training data for each epoch
training_dataset = baseline.wrap_dataset(problem.make_dataset(
size=opts.graph_size, num_samples=opts.epoch_size, distribution=opts.data_distribution))
training_dataloader = DataLoader(training_dataset, batch_size=opts.batch_size, num_workers=1)
# Put model in train mode!
model.train()
set_decode_type(model, "sampling")
for batch_id, batch in enumerate(tqdm(training_dataloader, disable=opts.no_progress_bar)):
train_batch(
model,
optimizer,
baseline,
epoch,
batch_id,
step,
batch,
tb_logger,
opts
)
step += 1
epoch_duration = time.time() - start_time
print("Finished epoch {}, took {} s".format(epoch, time.strftime('%H:%M:%S', time.gmtime(epoch_duration))))
...
avg_reward = validate(model, val_dataset, opts)
if not opts.no_tensorboard:
tb_logger.log_value('val_avg_reward', avg_reward, step)
baseline.epoch_callback(model, epoch)
# lr_scheduler should be called at end of epoch
lr_scheduler.step()
def train_batch(
model,
optimizer,
baseline,
epoch,
batch_id,
step,
batch,
tb_logger,
opts
):
x, bl_val = baseline.unwrap_batch(batch)
x = move_to(x, opts.device)
bl_val = move_to(bl_val, opts.device) if bl_val is not None else None
# Evaluate model, get costs and log probabilities
cost, log_likelihood = model(x)
# Evaluate baseline, get baseline loss if any (only for critic)
bl_val, bl_loss = baseline.eval(x, cost) if bl_val is None else (bl_val, 0)
# Calculate loss
reinforce_loss = ((cost - bl_val) * log_likelihood).mean()
loss = reinforce_loss + bl_loss
# Perform backward pass and optimization step
optimizer.zero_grad()
loss.backward()
# Clip gradient norms and get (clipped) gradient norms for logging
grad_norms = clip_grad_norms(optimizer.param_groups, opts.max_grad_norm)
optimizer.step()
# Logging
if step % int(opts.log_step) == 0:
log_values(cost, grad_norms, epoch, batch_id, step,
log_likelihood, reinforce_loss, bl_loss, tb_logger, opts)
上記よりtrain_epoch
関数は内部でtrain_batch
関数を呼び出していることが確認できます。また、train_batch
関数内ではbaseline.eval
によってベースライン(bl_val
)を計算し、reinforce_loss = ((cost - bl_val) * log_likelihood).mean()
のようにlossを計算していることが確認できます。
ベースラインのオブジェクトの作成にあたってはopts.baseline
の値によって様々なクラスを用いますが、opts.baseline == rollout
の場合に対応するRolloutBaseline
クラスの実装は下記より確認できます。
class RolloutBaseline(Baseline):
def __init__(self, model, problem, opts, epoch=0):
super(Baseline, self).__init__()
self.problem = problem
self.opts = opts
self._update_model(model, epoch)
def _update_model(self, model, epoch, dataset=None):
self.model = copy.deepcopy(model)
# Always generate baseline dataset when updating model to prevent overfitting to the baseline dataset
if dataset is not None:
if len(dataset) != self.opts.val_size:
print("Warning: not using saved baseline dataset since val_size does not match")
dataset = None
elif (dataset[0] if self.problem.NAME == 'tsp' else dataset[0]['loc']).size(0) != self.opts.graph_size:
print("Warning: not using saved baseline dataset since graph_size does not match")
dataset = None
if dataset is None:
self.dataset = self.problem.make_dataset(
size=self.opts.graph_size, num_samples=self.opts.val_size, distribution=self.opts.data_distribution)
else:
self.dataset = dataset
print("Evaluating baseline model on evaluation dataset")
self.bl_vals = rollout(self.model, self.dataset, self.opts).cpu().numpy()
self.mean = self.bl_vals.mean()
self.epoch = epoch
def eval(self, x, c):
# Use volatile mode for efficient inference (single batch so we do not use rollout function)
with torch.no_grad():
v, _ = self.model(x)
# There is no loss
return v, 0