Monte Carlo Tree Search for solving a simple math problem

Posted on 2019-03-18


In this first post of mine on Qiita, I describe how Monte Carlo tree search (MCTS) can solve a simple math problem, rather than a game such as Chess, Go, or Shogi. The aim is to see how MCTS works and to prepare it for a more general application, which is explained in a later post (see also the follow-up post).

Monte-Carlo Tree Search (MCTS)

MCTS is known as an efficient tree search algorithm for computer Go programs. In this post, we see how MCTS can be applied to finding the minimum of a simple multi-variable function.

MCTS starts from a root node and repeats the procedure below until a termination condition is met:


1) Selection: select one leaf node of the current tree according to a rule.
2) Expansion: add child nodes, each representing a possible action, to the selected node.
3) Playout: take random actions from the expanded nodes and score these nodes based on the outcome of the random actions.
4) Backpropagation: propagate the playout result to all parent nodes up to the root.
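
The "rule" in the selection step is commonly an upper-confidence-bound (UCB) score, which is also what the code later in this post uses: the mean reward of a node plus a bonus that shrinks as the node is visited more often. The following is only a minimal sketch of that scoring; the function name `ucb_score` and the example numbers are made up for illustration.

```python
from math import log, sqrt

def ucb_score(total_reward, visits, parent_visits):
    """Mean reward plus an exploration bonus that shrinks as visits grow."""
    return total_reward / visits + sqrt(log(parent_visits) / visits)

# Hypothetical children of one node: (action, total_reward, visits).
children = [(+1, 6.0, 10), (-1, 2.0, 2), (+2, 0.0, 1)]
parent_visits = 13

# Selection picks the child with the highest score.
best = max(children, key=lambda c: ucb_score(c[1], c[2], parent_visits))
print(best[0])  # -1
```

Here the rarely visited child with action -1 wins over the well-explored child with action +1, because its exploration bonus is still large.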

Example problem

Consider a multi-variable function of the following form, whose minimum we want to find using MCTS:

y = f(x) = (x_1 - a)^2 + (x_2 - b)^2 + (x_3 - c)^2 + (x_4 - d)^2
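
As a quick check of the formula above, it can be evaluated directly. This is a small sketch, not part of the MCTS code; the helper name `f` and its `target` argument are chosen here for illustration, while the numbers are the values of (a, b, c, d) and the initial x used in the code later in this post.

```python
def f(x, target):
    """Sum of squared distances between each x_i and its target value."""
    return sum((xi - ti) ** 2 for xi, ti in zip(x, target))

target = (0.0, 2.0, 1.0, 2.0)      # (a, b, c, d)
x0 = (-12.0, 14.0, -13.0, 9.0)     # initial (x1, x2, x3, x4)

print(f(x0, target))      # 533.0  (144 + 144 + 196 + 49)
print(f(target, target))  # 0.0    (the minimum)
```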

Generally, the form of f(x) is unknown in the kind of problem where MCTS would be employed. In this simple example, we already know that f(x) takes its minimum value at (x1, x2, x3, x4) = (a, b, c, d). We would like MCTS to find a series of actions that reaches this minimum. To cast this problem in a form MCTS can handle, we define discrete actions on x, such as:

  • Add +1 to x1
  • Add -1 to x1
  • Add +1 to x2
  • Add -1 to x2


Starting from arbitrary initial values of x and applying actions repeatedly yields a sequence of actions such as (x1: +1, x2: -1, ...). MCTS tries to find the action sequence that maximizes the cumulative reward.
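
To make the idea of an action sequence concrete, here is a minimal sketch that applies a list of actions to x. It assumes the signed-integer encoding used by the code later in this post (+k adds 1 to x_k, -k subtracts 1 from x_k, 0 does nothing); the helper name `apply_actions` is made up for this example.

```python
def apply_actions(x, actions):
    """Apply signed-index actions: +k adds 1 to x_k, -k subtracts 1, 0 is a no-op."""
    x = list(x)  # work on a copy so the caller's list is untouched
    for a in actions:
        if a != 0:
            x[abs(a) - 1] += 1 if a > 0 else -1
    return x

print(apply_actions([0, 0, 0, 0], [1, -2, -2, 4, 0]))  # [1, -2, 0, 1]
```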

Reinforcement learning

The process of solving this problem can be viewed as reinforcement learning: the relationship y = f(x) plays the role of the "environment", which receives actions from the MCTS agent and returns a "reward" to it.
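
The agent/environment loop can be sketched with a one-dimensional toy version of the problem. Everything here is a simplified illustration: the class `ToyEnv` is hypothetical, a fixed action list stands in for the agent, and the reward is just +1 when f decreases and -1 otherwise (the full code in this post additionally gives a large bonus at the minimum).

```python
class ToyEnv:
    """Hypothetical 1-D environment: minimize (x - a)^2 with +1/-1 steps."""

    def __init__(self, x=3, a=0):
        self.x, self.a = x, a

    def step(self, action):
        before = (self.x - self.a) ** 2
        self.x += action
        after = (self.x - self.a) ** 2
        reward = 1 if after < before else -1  # +1 only when f decreases
        return after, reward, after == 0      # terminal once the minimum is hit

env = ToyEnv(x=3, a=0)
total_reward = 0
for action in [+1, -1, -1, -1, -1]:  # a fixed action list stands in for the agent
    f_value, reward, terminal = env.step(action)
    total_reward += reward
    if terminal:
        break

print(env.x, total_reward)  # 0 3
```

The first action moves away from the minimum and costs 1; the remaining four each earn 1, so the cumulative reward is 3.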


Pseudo code

The main part of the code has the structure below. Since MCTS is probabilistic, it may or may not reach the actual minimum. Therefore, the tree is built multiple times to increase the chance of reaching the true minimum, and the best result is returned.
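
The "build the tree several times and keep the best result" idea is independent of MCTS itself and can be sketched on its own. In this illustration, `best_of_restarts` and `noisy_search` are hypothetical names; `noisy_search` is only a stand-in for one complete tree search and returns a (score, result) pair.

```python
import random

def best_of_restarts(search, num_restarts, seed=0):
    """Run a randomized search several times and keep the highest-scoring result."""
    rng = random.Random(seed)
    best_score, best_result = float("-inf"), None
    for _ in range(num_restarts):
        score, result = search(rng)
        if score > best_score:
            best_score, best_result = score, result
    return best_score, best_result

def noisy_search(rng):
    """Hypothetical stand-in for one tree search: returns (score, result)."""
    x = rng.randint(-5, 5)
    return -(x ** 2), x

score, x = best_of_restarts(noisy_search, num_restarts=20)
print(score, x)
```

More restarts can only keep or improve the best score, at the cost of proportionally more computation.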


Software versions

  • Python 3.5.2
  • Numpy 1.14.2+mkl


from math import inf, log, sqrt
from copy import copy

import numpy as np
import random

class environment:
    def __init__(self):
        self.a = 0.
        self.b = 2.
        self.c = 1.
        self.d = 2.
        self.x1 = -12.
        self.x2 = 14.
        self.x3 = -13.
        self.x4 = 9.
        self.f = self.func()

    def func(self):
        return (self.x1 - self.a)**2+(self.x2 - self.b)**2+(self.x3 - self.c)**2+(self.x4 - self.d)**2

    def step(self, action):
        f_before = self.func()

        if action == 1:
            self.x1 += 1.
        elif action == -1:
            self.x1 -= 1.
        elif action ==  2:
            self.x2 += 1.
        elif action == -2:
            self.x2 -= 1.
        elif action ==  3:
            self.x3 += 1.
        elif action == -3:
            self.x3 -= 1.
        elif action == 4:
            self.x4 += 1.
        elif action == -4:
            self.x4 -= 1.

        f_after = self.func()
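        if f_after == 0:
            # Reaching the exact minimum ends the episode with a large bonus.
            reward = 100
            terminal = True
        else:
            # Otherwise: +1 if f decreased, -1 if it did not.
            reward = 1 if abs(f_before) - abs(f_after) > 0 else -1
            terminal = False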

        return f_after, reward, terminal, [self.x1, self.x2, self.x3, self.x4]

    def action_space(self):
        return [-4, -3, -2, -1, 0, 1, 2, 3, 4]

    def sample(self):
        return np.random.choice([-4, -3, -2, -1, 0, 1, 2, 3, 4])

class Node():
    def __init__(self, parent, action):
        self.num_visits = 1
        self.reward = 0.
        self.children = []
        self.parent = parent
        self.action = action

    def update(self, reward):
        self.reward += reward
        self.num_visits += 1

    def __repr__(self):
        s="Reward/Visits =  %.1f/%.1f (Child %d)"%(self.reward, self.num_visits, len(self.children))
        return s

def ucb(node):
    return node.reward / node.num_visits + sqrt(log(node.parent.num_visits)/node.num_visits)

def reward_rate(node):
    return node.reward / node.num_visits


env = environment()

num_mainloops = 20
max_playout_depth = 5
num_tree_search = 180

best_sum_reward = -inf
best_action_sequence = []
best_f = 0
best_x = []

for _ in range(num_mainloops):
    root = Node(None, None)

    for run_no in range(num_tree_search):
        env_copy = copy(env)

        terminal = False
        sum_reward = 0

        # 1) Selection
        current_node = root
        while len(current_node.children) != 0:
            current_node = max(current_node.children, key = ucb)
            _, reward, terminal, _ = env_copy.step(current_node.action)
            sum_reward += reward

        # 2) Expansion
        if not terminal:
            possible_actions = env_copy.action_space()
            current_node.children = [Node(current_node, action) for action in possible_actions]     

        # Routine for each children hereafter

        for c in current_node.children:
            # 3) Playout
            env_playout = copy(env_copy)
            sum_reward_playout = 0
            action_sequence = []

            _, reward, terminal, _ = env_playout.step(c.action)
            sum_reward_playout += reward

            while not terminal:
                action = env_playout.sample()
                _, reward, terminal, _ = env_playout.step(action)
                sum_reward_playout += reward
                action_sequence.append(action)

                # Stop the random playout once it exceeds the depth limit.
                if len(action_sequence) > max_playout_depth:
                    break

            if terminal:
                print("Terminal reached during a playout. #########")

            # 4) Backpropagate
            c_ = c
            while c_:
                c_.num_visits +=1
                c_.reward += sum_reward + sum_reward_playout
                c_ = c_.parent

    current_node = root
    action_sequence = []
    sum_reward = 0
    env_copy = copy(env)

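    # Follow the child with the best average reward down to a leaf,
    # recording the action taken at each level.
    while len(current_node.children) != 0:
        current_node = max(current_node.children, key = reward_rate)
        action_sequence.append(current_node.action)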

    # Replay the chosen action sequence on a fresh copy of the environment.
    for action in action_sequence:
        _, reward, terminal, _ = env_copy.step(action)
        sum_reward += reward
        if terminal:
            break

    # A no-op step (action 0) reads out the final f and x.
    f, _, _, x = env_copy.step(0)

    print("Action sequence: ", str(action_sequence))
    print("Sum_reward: ", str(sum_reward))

    print("f, x (original): ", env.f , str([env.x1, env.x2, env.x3, env.x4]))
    print("f, x (after MCTS): ", str(f), str(x) )

    if sum_reward > best_sum_reward:
        best_sum_reward = sum_reward
        best_action_sequence = action_sequence
        best_f = f
        best_x = x

print("Best Action sequence: ", str(best_action_sequence))
print("Action sequence length: ", str(len(best_action_sequence)))
print("Best Sum_reward: ", str(best_sum_reward))
print("f, x (original): ", env.f , str([env.x1, env.x2, env.x3, env.x4]))
print("f, x (after MCTS): ", str(best_f), str(best_x) )


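Running the code above prints the best action sequence found by MCTS, followed by two summary lines:
f, x (original): the initial state x and the corresponding value of f(x).
f, x (after MCTS): the final x after all actions in the sequence have been applied, and the corresponding value of f(x).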

Terminal reached during a playout. #########
Terminal reached during a playout. #########
Terminal reached during a playout. #########
Best Action sequence:  [3, -4, -4, 3, 3, 3, -4, -4, 3, 3, 3, -4, -2, -2, -4, -2,
 -2, -4, 0, 3, 1, -3, -2, 1, -4, -4, -3, 3, -1, -2, 3, 1, 1, 4, 1, 4, 2, -2, 1,
-2, 3, -3, 1, 2, -2, -3, 1, 3, -4, -4, -1, -2, -4, 1, 3, -2, 4, 3, 3, 1, -3, -2,
 1, -2, 1, 1, 4, -1, 1, 4, -4, -2, -4, 3, 3, 1, 2, -2, -4, -4, 4, -2, -3, 3, 4,
2, -1, -2, -1, -3, 4, 0, 3, -4, -3, -3, 3, -2, 3, 2, 2, -4, -2, -4, 1, -4, 4, -3
, 1, 3, -1, -2, 4, 3, -4, -1, 2, 4, 2, 1, 4, -4, 4, -3, 1, -1, 4, 0, -3, 2, 1, 4
, 1, -1, 3, 3, -3, -2, 3, 3]
Action sequence length:  140
Best Sum_reward:  141
f, x (original):  533.0 [-12.0, 14.0, -13.0, 9.0]
f, x (after MCTS):  0.0 [0.0, 2.0, 1.0, 2.0]


This code can be freely copied, modified, and used. Please cite this post if you find it useful.


