
[Optimization] Attention, Learn to Solve Routing Problems!: Paper and PyTorch Implementation

Posted at 2025-05-27

This article reviews the paper Attention, Learn to Solve Routing Problems!, a study that tackles the Travelling Salesperson Problem (TSP) and the Vehicle Routing Problem (VRP) with an attention-based, Transformer-like GNN (Graph Neural Network), together with its PyTorch implementation.

Reviewing the Attention, Learn to Solve Routing Problems! Paper

Overview of the Paper

Attention, Learn to Solve Routing Problems! is a study that applies an attention-based, Transformer-like GNN to route optimization.

[Figure: Overview of the encoder processing (Attention, Learn to Solve Routing Problems!, Figure 1)]

Attention, Learn to Solve Routing Problems! uses a GNN-based encoder like the one in the figure above. In combining MHA (Multi-Head Attention) and FF (Feed-Forward) sublayers, its processing closely resembles the Transformer.

[Figure: Overview of the decoder processing (Attention, Learn to Solve Routing Problems!, Figure 2)]

Likewise, the decoder performs the processing shown in the figure above. Since the figures alone do not convey the details, the next section examines the equations that define the encoder and decoder.

Equations Defining the Encoder and Decoder

Encoder Equations

The encoder equations essentially follow those of the Transformer. Below, the formulas are listed starting with those closest to the output.

\begin{align}
\bar{h}^{N} &= \frac{1}{n} \sum_{i=1}^{n} h_{i}^{N} \\
\hat{h}_{i} &= \mathrm{BN}^{l} \left[ h_{i}^{l-1} + \mathrm{MHA}_{i}^{l} \left( h_{1}^{l-1}, \cdots , h_{n}^{l-1} \right) \right] \\
h_{i}^{l} &= \mathrm{BN}^{l} \left[ \hat{h}_{i} + \mathrm{FF}^{l} \left( \hat{h}_{i} \right) \right] \\
l & \in \{ 1, \cdots , N \} \\
h_{i}^{0} &= W^{x} \mathbf{x}_{i} + \mathbf{b}^{x}, \, \mathbf{x}_{i} \in \mathbb{R}^{d_x}, \, h_{i}^{l} \in \mathbb{R}^{d_h}
\end{align}

The encoder is defined roughly as above. The graph embedding $\bar{h}^{N} \in \mathbb{R}^{d_h}$ in the first line is obtained by averaging the node embeddings produced by repeating the MHA and FF operations in the second and third lines. Note that $N$ is the number of GNN layers and $n$ is the number of nodes. The input dimension is $d_x = 2$ in the case of the Travelling Salesperson Problem (TSP), and the dimension of the hidden feature vectors is set to $d_h = 128$ in this work.
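The following is a minimal PyTorch sketch of one encoder layer corresponding to the equations above (EncoderLayerSketch is an illustrative name, not the repository's implementation; $d_h = 128$, 8 heads, and a feed-forward hidden size of 512 follow the paper's settings):

import torch
import torch.nn as nn

class EncoderLayerSketch(nn.Module):
    """One encoder layer: skip connection + BN around MHA, then around FF."""
    def __init__(self, d_h=128, n_heads=8, d_ff=512):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_h, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_h, d_ff), nn.ReLU(), nn.Linear(d_ff, d_h))
        self.bn1 = nn.BatchNorm1d(d_h)
        self.bn2 = nn.BatchNorm1d(d_h)

    def _bn(self, bn, h):
        # BatchNorm1d expects (batch, channels, length), so swap node/feature dims
        return bn(h.transpose(1, 2)).transpose(1, 2)

    def forward(self, h):                          # h: (batch, n, d_h)
        a, _ = self.mha(h, h, h)                   # MHA^l over all nodes
        h_hat = self._bn(self.bn1, h + a)          # \hat{h}_i = BN(h_i^{l-1} + MHA_i^l)
        return self._bn(self.bn2, h_hat + self.ff(h_hat))  # h_i^l = BN(\hat{h}_i + FF(\hat{h}_i))

h = torch.randn(4, 20, 128)                        # h_i^0 for a batch of 4 TSP20 instances
print(EncoderLayerSketch()(h).shape)               # torch.Size([4, 20, 128])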

Decoder Equations

Given an instance $s$ of a problem such as the TSP, the first thing to note when reading the decoder equations is that the order in which nodes are visited (the tour), $\boldsymbol{\pi}$, is defined as follows.

\begin{align}
\boldsymbol{\pi} = (\pi_{1}, \pi_{2}, \cdots , \pi_{n})
\end{align}

The $\pi_{t}, \, t \in \{ 1, \cdots , n \}$ above can be interpreted as the node visited at step $t$. The decoder defines the probability that node $i$ is visited at step $t$ as follows.

\begin{align}
p_{\theta}(\pi_{t}=i | s, \boldsymbol{\pi}_{1:t-1}) &= \frac{e^{u_{(c)i}}}{\sum_{j} e^{u_{(c)j}}}
\end{align}

The $u_{(c)j}$ above is defined as follows.

\begin{align}
u_{(c)j} &= \left\{ \,
    \begin{aligned}
    & C \cdot \tanh{\left( \frac{q_{(c)}^{\mathrm{T}} k_{j}}{\sqrt{d_{k}}} \right)} \quad \mathrm{if} \, j \neq \pi_{t'} \,\, {}^{\forall} t' < t \\
    & - \infty \qquad \mathrm{otherwise}
    \end{aligned}
\right.
\end{align}
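As a rough illustration, the clipped and masked compatibilities can be computed as below (a minimal sketch with illustrative names, not the repository's code); visited nodes receive $-\infty$, so the softmax assigns them probability zero:

import torch

def masked_logits(q_c, K, visited, C=10.0):
    """q_c: (batch, 1, d_k), K: (batch, n, d_k), visited: (batch, n) bool mask."""
    d_k = K.size(-1)
    u = C * torch.tanh(q_c @ K.transpose(-2, -1) / d_k ** 0.5)   # (batch, 1, n)
    return u.masked_fill(visited.unsqueeze(1), float("-inf"))    # mask visited nodes

q_c = torch.randn(2, 1, 128)
K = torch.randn(2, 5, 128)
visited = torch.tensor([[True, False, False, False, False],
                        [False, True, True, False, False]])
p = torch.softmax(masked_logits(q_c, K, visited), dim=-1)        # p_theta(pi_t = i | s, pi_{1:t-1})
print(p[0, 0])                                                   # index 0 has probability 0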

The $-\infty$ in the definition of $u_{(c)j}$ masks nodes that have already been visited so that they are never selected twice, while the $\tanh$ clips the logits to the range $[-C, C]$ (the paper uses $C = 10$). The query $q_{(c)}$ and key $k_{j}$ appearing in $u_{(c)j}$ are computed as follows.

\begin{align}
q_{(c)} &= W^{Q} h_{(c)} \in \mathbb{R}^{d_{h}} \\ 
k_{j} &= W^{K} h_{j} \in \mathbb{R}^{d_{h}} \\
v_{j} &= W^{V} h_{j} \in \mathbb{R}^{d_{h}}
\end{align}

The $h_{(c)} \in \mathbb{R}^{3 d_{h}}$ appearing above is defined as follows.

\begin{align}
h_{(c)} = \left\{ \,
    \begin{aligned}
    & \left[ \bar{h}^{N}, h_{\pi_{t-1}}^{N}, h_{\pi_{1}}^{N} \right] \quad t>1 \\
    & \left[ \bar{h}^{N}, v^{l}, v^{f} \right] \qquad t=1
    \end{aligned}
\right.
\end{align}

The $[\cdot , \cdot , \cdot]$ above denotes vector concatenation (concat); $v^{l}$ and $v^{f}$ are learned placeholder vectors used at the first step, when no node has been visited yet.
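As a rough illustration of the concatenation for $t > 1$ (illustrative tensors; not the repository's code):

import torch

B, n, d_h = 4, 20, 128
h = torch.randn(B, n, d_h)                   # node embeddings h_i^N
pi_first = torch.zeros(B, dtype=torch.long)  # pi_1: first node of each tour
pi_prev = torch.randint(0, n, (B,))          # pi_{t-1}: previously visited node

h_bar = h.mean(dim=1)                        # graph embedding \bar{h}^N
batch_idx = torch.arange(B)
h_c = torch.cat([h_bar, h[batch_idx, pi_prev], h[batch_idx, pi_first]], dim=-1)
print(h_c.shape)                             # torch.Size([4, 384]), i.e. 3 * d_h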

Training the Network with Policy Gradients

For training the network described above, it is difficult to prepare exact ground-truth solutions for the TSP. Training with reinforcement learning based on policy gradients is possible in principle in this situation, and this work uses REINFORCE, one of the policy gradient methods. Policy gradient training of neural networks is also used in RLHF (Reinforcement Learning from Human Feedback) for LLMs, so it is worth keeping both in mind.

The Attention, Learn to Solve Routing Problems! paper defines the objective function and derives its gradient as follows, and performs training with gradient descent.

\begin{align}
L(\theta|s) &= \mathbb{E}_{p_{\theta}(\boldsymbol{\pi}|s)} [L(\boldsymbol{\pi})] \\
\nabla L(\theta|s) &= \mathbb{E}_{p_{\theta}(\boldsymbol{\pi}|s)} \left[ (L(\boldsymbol{\pi}) - b(s)) \nabla \log{p_{\theta} (\boldsymbol{\pi}|s)} \right]
\end{align}

Here $b(s)$ is a baseline: the paper's main choice is a greedy rollout, in which $b(s)$ is the length of the tour obtained by decoding deterministically (greedily) with the best model found so far. The paper also describes an exponential moving-average baseline, written out further below. The gradient $\nabla L(\theta|s)$ itself can be derived using the log-derivative trick.
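Concretely, using $\nabla p_{\theta} = p_{\theta} \nabla \log{p_{\theta}}$ and writing the expectation as a sum over tours:

\begin{align}
\nabla L(\theta|s) &= \sum_{\boldsymbol{\pi}} L(\boldsymbol{\pi}) \nabla p_{\theta}(\boldsymbol{\pi}|s) \\
&= \sum_{\boldsymbol{\pi}} L(\boldsymbol{\pi}) p_{\theta}(\boldsymbol{\pi}|s) \nabla \log{p_{\theta}(\boldsymbol{\pi}|s)} \\
&= \mathbb{E}_{p_{\theta}(\boldsymbol{\pi}|s)} \left[ L(\boldsymbol{\pi}) \nabla \log{p_{\theta}(\boldsymbol{\pi}|s)} \right]
\end{align}

Subtracting the baseline $b(s)$ leaves this expectation unchanged, since $\mathbb{E}_{p_{\theta}(\boldsymbol{\pi}|s)} [\nabla \log{p_{\theta}(\boldsymbol{\pi}|s)}] = 0$, while reducing the variance of the estimator.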

\begin{align}
b(s) &= M_{i} \\
M_{0} &= L(\boldsymbol{\pi}_{0}) \\
M_{i+1} &= \beta M_{i} + (1 - \beta) L(\boldsymbol{\pi}_{i})
\end{align}

The exponential moving-average baseline is computed as above. The formulas in the paper carry no indices; the indices here were added in this article for readability.
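The following is a minimal sketch of this moving-average update (the repository implements it as the ExponentialBaseline class; $\beta = 0.8$ here is an illustrative value):

class ExponentialBaselineSketch:
    """b(s) = M_i, with M_{i+1} = beta * M_i + (1 - beta) * L(pi_i)."""
    def __init__(self, beta=0.8):
        self.beta = beta
        self.M = None

    def update(self, tour_length):
        if self.M is None:
            self.M = tour_length             # M_0 = L(pi_0)
        else:
            self.M = self.beta * self.M + (1 - self.beta) * tour_length
        return self.M

baseline = ExponentialBaselineSketch()
for L_pi in [4.0, 3.8, 3.9]:
    print(baseline.update(L_pi))             # 4.0, 3.96, 3.948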

The full training algorithm is also given in the paper and is worth reviewing alongside the above.

[Figure: Training algorithm using REINFORCE (Attention, Learn to Solve Routing Problems!, Algorithm 1)]

Reviewing the PyTorch Implementation of Attention, Learn to Solve Routing Problems!

Obtaining the PyTorch Implementation and Running a Demo

The implementation of Attention, Learn to Solve Routing Problems! can be obtained from the authors' GitHub repository (wouterkool/attention-learn-to-route). Since no dedicated demo script is provided, modifying attention-learn-to-route/eval.py as follows and running it makes the output easy to inspect.

attention-learn-to-route/eval.py
if __name__ == "__main__":
    ...
    for width in widths:
        for dataset_path in opts.datasets:
            costs, tours, durations = eval_dataset(dataset_path, width, opts.softmax_temperature, opts)

    # Added: print the first 10 costs and tours plus the number of evaluated instances
    print("Cost0-9: {}, Tours0-9:{}, num_estimates: {}".format(costs[0:10], tours[0:10], len(costs)))

・Command

$ python eval.py -f data/tsp/tsp20_test_seed1234.pkl --model pretrained/tsp_20 --decode_strategy greedy

・Output

Average cost: 3.8469574451446533 +- 0.006181415915489197
Average serial duration: 0.0451378002166748 +- 0.0012258773853210062
Average parallel duration: 4.4079883024096486e-05
Calculated total duration: 0:00:00
Cost0-9: (3.66383, 3.7619956, 3.7999754, 4.118915, 3.294455, 3.8607297, 3.8752038, 3.9516256, 4.1289177, 4.0149283), Tours0-9:([16, 2, 15, 12, 4, 9, 17, 6, 1, 13, 19, 18, 3, 0, 14, 10, 7, 5, 11, 8], [6, 16, 11, 4, 10, 18, 2, 15, 9, 8, 7, 13, 17, 1, 5, 3, 19, 0, 14, 12], [12, 17, 11, 1, 19, 10, 3, 8, 6, 16, 14, 9, 13, 5, 0, 7, 4, 15, 18, 2], [8, 6, 11, 15, 18, 0, 17, 1, 19, 9, 5, 3, 10, 12, 4, 7, 16, 2, 14, 13], [15, 14, 5, 3, 10, 9, 11, 2, 1, 4, 19, 13, 8, 6, 7, 16, 0, 17, 12, 18], [7, 19, 2, 0, 11, 16, 1, 9, 18, 6, 12, 5, 15, 3, 14, 4, 10, 13, 17, 8], [13, 9, 14, 18, 12, 17, 10, 7, 15, 5, 11, 8, 4, 3, 1, 19, 2, 0, 16, 6], [4, 14, 19, 3, 6, 8, 18, 10, 5, 13, 17, 0, 16, 11, 2, 7, 12, 1, 9, 15], [2, 19, 8, 12, 9, 0, 10, 18, 15, 6, 13, 17, 5, 7, 4, 14, 16, 1, 3, 11], [15, 5, 14, 2, 17, 8, 11, 18, 3, 13, 12, 16, 6, 19, 1, 7, 10, 9, 0, 4]), num_estimates: 10000

The output shows that running attention-learn-to-route/eval.py loads 10,000 TSP instances from data/tsp/tsp20_test_seed1234.pkl and makes a prediction for each of them.
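To inspect the dataset directly, the pickle file can be loaded as below (a minimal sketch assuming the repository's data format, in which each TSP instance is a list of (x, y) coordinate pairs):

import pickle

with open("data/tsp/tsp20_test_seed1234.pkl", "rb") as f:
    dataset = pickle.load(f)

print(len(dataset))     # 10000 instances
print(len(dataset[0]))  # 20 nodes per instance
print(dataset[0][0])    # (x, y) coordinates of the first node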

Reviewing the Model Implementation

Below, the model implementation is examined based on the contents of attention-learn-to-route/eval.py.

attention-learn-to-route/eval.py
from utils import load_model, move_to

...

def eval_dataset(dataset_path, width, softmax_temp, opts):
    # Even with multiprocessing, we load the model here since it contains the name where to write results
    #print("multiprocessing: {}".format(opts.multiprocessing))
    
    model, _ = load_model(opts.model)

    ...

    results = _eval_dataset(model, dataset, width, softmax_temp, opts, device)

    ...

    costs, tours, durations = zip(*results)  # Not really costs since they should be negative

    ...

    return costs, tours, durations
    
...

def _eval_dataset(model, dataset, width, softmax_temp, opts, device):

    model.to(device)
    model.eval()

    ...

    dataloader = DataLoader(dataset, batch_size=opts.eval_batch_size)

    results = []
    for batch in tqdm(dataloader, disable=opts.no_progress_bar):
        batch = move_to(batch, device)

        start = time.time()
        with torch.no_grad():
            if opts.decode_strategy in ('sample', 'greedy'):
                
                ...
                
                # This returns (batch_size, iter_rep shape)
                sequences, costs = model.sample_many(batch, batch_rep=batch_rep, iter_rep=iter_rep)
                
        ...        
                
        if sequences is None:
            sequences = [None] * batch_size
            costs = [math.inf] * batch_size
        else:
            sequences, costs = get_best(
                sequences.cpu().numpy(), costs.cpu().numpy(),
                ids.cpu().numpy() if ids is not None else None,
                batch_size
            )
            
        duration = time.time() - start
        for seq, cost in zip(sequences, costs):
            if model.problem.NAME == "tsp":
                seq = seq.tolist()  # No need to trim as all are same length
            elif model.problem.NAME in ("cvrp", "sdvrp"):
                seq = np.trim_zeros(seq).tolist() + [0]  # Add depot
            elif model.problem.NAME in ("op", "pctsp"):
                seq = np.trim_zeros(seq)  # We have the convention to exclude the depot
            else:
                assert False, "Unkown problem: {}".format(model.problem.NAME)
            # Note VRP only
            results.append((cost, seq, duration))

    return results

The above shows that the model is loaded using the load_model function implemented in attention-learn-to-route/utils/functions.py. In the example in README.md, the argument opts.model is pretrained/tsp_20.

attention-learn-to-route/utils/functions.py
def load_model(path, epoch=None):
    from nets.attention_model import AttentionModel
    from nets.pointer_network import PointerNetwork

    ...

    model_class = {
        'attention': AttentionModel,
        'pointer': PointerNetwork
    }.get(args.get('model', 'attention'), None)
    assert model_class is not None, "Unknown model: {}".format(model_class)

    model = model_class(
        args['embedding_dim'],
        args['hidden_dim'],
        problem,
        n_encode_layers=args['n_encode_layers'],
        mask_inner=True,
        mask_logits=True,
        normalization=args['normalization'],
        tanh_clipping=args['tanh_clipping'],
        checkpoint_encoder=args.get('checkpoint_encoder', False),
        shrink_size=args.get('shrink_size', None)
    )
    
    ...
    
    return model, args

The above shows that either an AttentionModel or a PointerNetwork is loaded. In the following, this article examines the AttentionModel class implemented in attention-learn-to-route/nets/attention_model.py.

attention-learn-to-route/nets/attention_model.py
from nets.graph_encoder import GraphAttentionEncoder

class AttentionModel(nn.Module):

    def __init__(self,
                 embedding_dim,
                 hidden_dim,
                 problem,
                 n_encode_layers=2,
                 tanh_clipping=10.,
                 mask_inner=True,
                 mask_logits=True,
                 normalization='batch',
                 n_heads=8,
                 checkpoint_encoder=False,
                 shrink_size=None):

        ...

        self.embedder = GraphAttentionEncoder(
            n_heads=n_heads,
            embed_dim=embedding_dim,
            n_layers=self.n_encode_layers,
            normalization=normalization
        )

        ...
        
    def forward(self, input, return_pi=False):
        """
        :param input: (batch_size, graph_size, node_dim) input node features or dictionary with multiple tensors
        :param return_pi: whether to return the output sequences, this is optional as it is not compatible with
        using DataParallel as the results may be of different lengths on different GPUs
        :return:
        """

        if self.checkpoint_encoder and self.training:  # Only checkpoint if we need gradients
            embeddings, _ = checkpoint(self.embedder, self._init_embed(input))
        else:
            embeddings, _ = self.embedder(self._init_embed(input))

        _log_p, pi = self._inner(input, embeddings)

        cost, mask = self.problem.get_costs(input, pi)
        # Log likelyhood is calculated within the model since returning it per action does not work well with
        # DataParallel since sequences can be of different lengths
        ll = self._calc_log_likelihood(_log_p, pi, mask)
        if return_pi:
            return cost, ll, pi

        return cost, ll

The above shows that the encoder is implemented as the GraphAttentionEncoder class in attention-learn-to-route/nets/graph_encoder.py.

attention-learn-to-route/nets/graph_encoder.py
class MultiHeadAttentionLayer(nn.Sequential):

    def __init__(
            self,
            n_heads,
            embed_dim,
            feed_forward_hidden=512,
            normalization='batch',
    ):
        super(MultiHeadAttentionLayer, self).__init__(
            SkipConnection(
                MultiHeadAttention(
                    n_heads,
                    input_dim=embed_dim,
                    embed_dim=embed_dim
                )
            ),
            Normalization(embed_dim, normalization),
            SkipConnection(
                nn.Sequential(
                    nn.Linear(embed_dim, feed_forward_hidden),
                    nn.ReLU(),
                    nn.Linear(feed_forward_hidden, embed_dim)
                ) if feed_forward_hidden > 0 else nn.Linear(embed_dim, embed_dim)
            ),
            Normalization(embed_dim, normalization)
        )
        
class GraphAttentionEncoder(nn.Module):
    def __init__(
            self,
            n_heads,
            embed_dim,
            n_layers,
            node_dim=None,
            normalization='batch',
            feed_forward_hidden=512
    ):
        super(GraphAttentionEncoder, self).__init__()

        # To map input to embedding space
        self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None

        self.layers = nn.Sequential(*(
            MultiHeadAttentionLayer(n_heads, embed_dim, feed_forward_hidden, normalization)
            for _ in range(n_layers)
        ))

    def forward(self, x, mask=None):

        assert mask is None, "TODO mask not yet supported!"

        # Batch multiply to get initial embeddings of nodes
        h = self.init_embed(x.view(-1, x.size(-1))).view(*x.size()[:2], -1) if self.init_embed is not None else x

        h = self.layers(h)

        return (
            h,  # (batch_size, graph_size, embed_dim)
            h.mean(dim=1),  # average to get embedding of graph, (batch_size, embed_dim)
        )

The above shows that the GraphAttentionEncoder class builds the layers of the graph neural network by stacking n_layers MultiHeadAttentionLayer modules. The decoder, in turn, is implemented within the _inner method of the AttentionModel class.

attention-learn-to-route/nets/attention_model.py
class AttentionModel(nn.Module):

    ...
    
    def _inner(self, input, embeddings):

        outputs = []
        sequences = []

        state = self.problem.make_state(input)

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        fixed = self._precompute(embeddings)

        batch_size = state.ids.size(0)

        # Perform decoding steps
        i = 0
        while not (self.shrink_size is None and state.all_finished()):

            if self.shrink_size is not None:
                unfinished = torch.nonzero(state.get_finished() == 0)
                if len(unfinished) == 0:
                    break
                unfinished = unfinished[:, 0]
                # Check if we can shrink by at least shrink_size and if this leaves at least 16
                # (otherwise batch norm will not work well and it is inefficient anyway)
                if 16 <= len(unfinished) <= state.ids.size(0) - self.shrink_size:
                    # Filter states
                    state = state[unfinished]
                    fixed = fixed[unfinished]

            log_p, mask = self._get_log_p(fixed, state)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            selected = self._select_node(log_p.exp()[:, 0, :], mask[:, 0, :])  # Squeeze out steps dimension

            state = state.update(selected)

            # Now make log_p, selected desired output size by 'unshrinking'
            if self.shrink_size is not None and state.ids.size(0) < batch_size:
                log_p_, selected_ = log_p, selected
                log_p = log_p_.new_zeros(batch_size, *log_p_.size()[1:])
                selected = selected_.new_zeros(batch_size)

                log_p[state.ids[:, 0]] = log_p_
                selected[state.ids[:, 0]] = selected_

            # Collect output of step
            outputs.append(log_p[:, 0, :])
            sequences.append(selected)

            i += 1

        # Collected lists, return Tensor
        return torch.stack(outputs, 1), torch.stack(sequences, 1)
        
    def _get_log_p(self, fixed, state, normalize=True):

        # Compute query = context node embedding
        query = fixed.context_node_projected + \
                self.project_step_context(self._get_parallel_step_context(fixed.node_embeddings, state))

        # Compute keys and values for the nodes
        glimpse_K, glimpse_V, logit_K = self._get_attention_node_data(fixed, state)

        # Compute the mask
        mask = state.get_mask()

        # Compute logits (unnormalized log_p)
        log_p, glimpse = self._one_to_many_logits(query, glimpse_K, glimpse_V, logit_K, mask)

        if normalize:
            log_p = torch.log_softmax(log_p / self.temp, dim=-1)

        assert not torch.isnan(log_p).any()

        return log_p, mask

Reviewing the Training Implementation

This section summarizes the implementation of training based on reinforcement learning (policy gradients). First, a brief look at attention-learn-to-route/run.py, which the training commands in README.md invoke.

attention-learn-to-route/run.py
from options import get_options

...

if __name__ == "__main__":
    run(get_options())

The above shows that attention-learn-to-route/run.py reads the training configuration from attention-learn-to-route/options.py.

attention-learn-to-route/options.py
import argparse

...

def get_options(args=None):
    parser = argparse.ArgumentParser(
        description="Attention based model for solving the Travelling Salesman Problem with Reinforcement Learning")

    # Data
    parser.add_argument('--problem', default='tsp', help="The problem to solve, default 'tsp'")
    parser.add_argument('--graph_size', type=int, default=20, help="The size of the problem graph")
    parser.add_argument('--batch_size', type=int, default=512, help='Number of instances per batch during training')
    parser.add_argument('--epoch_size', type=int, default=1280000, help='Number of instances per epoch during training')
    parser.add_argument('--val_size', type=int, default=10000,
                        help='Number of instances used for reporting validation performance')
    parser.add_argument('--val_dataset', type=str, default=None, help='Dataset file to use for validation')
    
    ...
    
    opts = parser.parse_args(args)
    
    ...
    
    return opts

The above shows that the configuration is basically specified using the argparse module. Next, the code in attention-learn-to-route/run.py corresponding to reinforcement learning.

attention-learn-to-route/run.py
from train import train_epoch, validate, get_inner_model
from reinforce_baselines import NoBaseline, ExponentialBaseline, CriticBaseline, RolloutBaseline, WarmupBaseline

...

def run(opts):

    ...

    # Initialize model
    model_class = {
        'attention': AttentionModel,
        'pointer': PointerNetwork
    }.get(opts.model, None)
    assert model_class is not None, "Unknown model: {}".format(model_class)
    model = model_class(
        opts.embedding_dim,
        opts.hidden_dim,
        problem,
        n_encode_layers=opts.n_encode_layers,
        mask_inner=True,
        mask_logits=True,
        normalization=opts.normalization,
        tanh_clipping=opts.tanh_clipping,
        checkpoint_encoder=opts.checkpoint_encoder,
        shrink_size=opts.shrink_size
    ).to(opts.device)

    ...

    if opts.eval_only:
        validate(model, val_dataset, opts)
    else:
        for epoch in range(opts.epoch_start, opts.epoch_start + opts.n_epochs):
            train_epoch(
                model,
                optimizer,
                baseline,
                lr_scheduler,
                epoch,
                val_dataset,
                problem,
                tb_logger,
                opts
            )

Training the network runs the train_epoch function implemented in attention-learn-to-route/train.py.

attention-learn-to-route/train.py
def train_epoch(model, optimizer, baseline, lr_scheduler, epoch, val_dataset, problem, tb_logger, opts):
    print("Start train epoch {}, lr={} for run {}".format(epoch, optimizer.param_groups[0]['lr'], opts.run_name))
    step = epoch * (opts.epoch_size // opts.batch_size)
    start_time = time.time()

    if not opts.no_tensorboard:
        tb_logger.log_value('learnrate_pg0', optimizer.param_groups[0]['lr'], step)

    # Generate new training data for each epoch
    training_dataset = baseline.wrap_dataset(problem.make_dataset(
        size=opts.graph_size, num_samples=opts.epoch_size, distribution=opts.data_distribution))
    training_dataloader = DataLoader(training_dataset, batch_size=opts.batch_size, num_workers=1)

    # Put model in train mode!
    model.train()
    set_decode_type(model, "sampling")

    for batch_id, batch in enumerate(tqdm(training_dataloader, disable=opts.no_progress_bar)):

        train_batch(
            model,
            optimizer,
            baseline,
            epoch,
            batch_id,
            step,
            batch,
            tb_logger,
            opts
        )

        step += 1

    epoch_duration = time.time() - start_time
    print("Finished epoch {}, took {} s".format(epoch, time.strftime('%H:%M:%S', time.gmtime(epoch_duration))))

    ...

    avg_reward = validate(model, val_dataset, opts)

    if not opts.no_tensorboard:
        tb_logger.log_value('val_avg_reward', avg_reward, step)

    baseline.epoch_callback(model, epoch)

    # lr_scheduler should be called at end of epoch
    lr_scheduler.step()


def train_batch(
        model,
        optimizer,
        baseline,
        epoch,
        batch_id,
        step,
        batch,
        tb_logger,
        opts
):
    x, bl_val = baseline.unwrap_batch(batch)
    x = move_to(x, opts.device)
    bl_val = move_to(bl_val, opts.device) if bl_val is not None else None

    # Evaluate model, get costs and log probabilities
    cost, log_likelihood = model(x)

    # Evaluate baseline, get baseline loss if any (only for critic)
    bl_val, bl_loss = baseline.eval(x, cost) if bl_val is None else (bl_val, 0)

    # Calculate loss
    reinforce_loss = ((cost - bl_val) * log_likelihood).mean()
    loss = reinforce_loss + bl_loss

    # Perform backward pass and optimization step
    optimizer.zero_grad()
    loss.backward()
    # Clip gradient norms and get (clipped) gradient norms for logging
    grad_norms = clip_grad_norms(optimizer.param_groups, opts.max_grad_norm)
    optimizer.step()

    # Logging
    if step % int(opts.log_step) == 0:
        log_values(cost, grad_norms, epoch, batch_id, step,
                   log_likelihood, reinforce_loss, bl_loss, tb_logger, opts)

The above shows that train_epoch internally calls the train_batch function, and that within train_batch the baseline (bl_val) is computed with baseline.eval and the loss is computed as reinforce_loss = ((cost - bl_val) * log_likelihood).mean().
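As a quick check of the sign convention with dummy tensors (a minimal sketch, not the repository's code): tours cheaper than the baseline get a negative advantage, so gradient descent on this loss increases their log-likelihood.

import torch

cost = torch.tensor([3.8, 4.2])             # tour lengths L(pi)
bl_val = torch.tensor([4.0, 4.0])           # baseline b(s)
log_likelihood = torch.tensor([-2.3, -2.5], requires_grad=True)

reinforce_loss = ((cost - bl_val) * log_likelihood).mean()
reinforce_loss.backward()
print(log_likelihood.grad)                  # tensor([-0.1000,  0.1000])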

Different classes are used for the baseline object depending on the value of opts.baseline; the implementation of the RolloutBaseline class, which corresponds to opts.baseline == 'rollout', can be seen below.

attention-learn-to-route/reinforce_baselines.py
class RolloutBaseline(Baseline):

    def __init__(self, model, problem, opts, epoch=0):
        super(Baseline, self).__init__()

        self.problem = problem
        self.opts = opts

        self._update_model(model, epoch)

    def _update_model(self, model, epoch, dataset=None):
        self.model = copy.deepcopy(model)
        # Always generate baseline dataset when updating model to prevent overfitting to the baseline dataset

        if dataset is not None:
            if len(dataset) != self.opts.val_size:
                print("Warning: not using saved baseline dataset since val_size does not match")
                dataset = None
            elif (dataset[0] if self.problem.NAME == 'tsp' else dataset[0]['loc']).size(0) != self.opts.graph_size:
                print("Warning: not using saved baseline dataset since graph_size does not match")
                dataset = None

        if dataset is None:
            self.dataset = self.problem.make_dataset(
                size=self.opts.graph_size, num_samples=self.opts.val_size, distribution=self.opts.data_distribution)
        else:
            self.dataset = dataset
        print("Evaluating baseline model on evaluation dataset")
        self.bl_vals = rollout(self.model, self.dataset, self.opts).cpu().numpy()
        self.mean = self.bl_vals.mean()
        self.epoch = epoch

    def eval(self, x, c):
        # Use volatile mode for efficient inference (single batch so we do not use rollout function)
        with torch.no_grad():
            v, _ = self.model(x)

        # There is no loss
        return v, 0