
Monte Carlo Tree Search for solving a simple math problem

Posted at 2019-03-18

Introduction

In this first post of mine on Qiita, I describe Monte Carlo tree search (MCTS) applied to a simple math problem rather than to Chess, Go, Shogi, etc. The aim is to see how MCTS works and to prepare it for a more general application, which is explained in a later post (see also the follow-up post).

Monte-Carlo Tree Search (MCTS)

MCTS is known as an efficient tree search algorithm used in computational Go programs. In this post, we see how MCTS can be applied to finding the minimum of a simple multi-variable function.

MCTS starts from a root node and repeats the procedure below until a terminal condition is reached:

[Figure: the four steps of MCTS (mcts.jpg)]

  1. Selection: select one leaf node according to a rule (the UCB formula shown below).
  2. Expansion: add child nodes representing the possible actions to the selected node.
  3. Playout: take random actions from the expanded nodes and score these nodes based on the outcome of the random actions.
  4. Backpropagation: propagate the playout result to all parent nodes up to the root.
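
For the selection step, the code in this post uses a UCB1-style score (the ucb() function in the Code section), which balances the average reward of a node against an exploration bonus:

UCB = (node reward) / (node visits) + sqrt( ln(parent visits) / (node visits) )

The exploration constant is fixed at 1 here; many implementations expose it as a tunable parameter.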

Example problem

Let's take a multi-variable function of the following form and find its minimum using MCTS:

y = f(x) = (x_1 - a)^2 + (x_2 - b)^2 + (x_3 - c)^2 + (x_4 - d)^2

Generally, the form of f(x) is unknown in the kind of problem where MCTS would be employed. In this simple example, we already know that f(x) takes its minimum value at (x1, x2, x3, x4) = (a, b, c, d). We would like MCTS to find a series of actions that reaches this minimum. To apply MCTS to this problem, we define discrete actions on x:

  • Add +1 to x1
  • Add -1 to x1
  • Add +1 to x2
  • Add -1 to x2
  • ... and so on for x3 and x4

Starting from arbitrary initial values of x and repeating such actions yields a sequence of actions: (x1:+1, x2:-1, ...). MCTS tries to find the action sequence that maximizes the cumulative reward.
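
As a tiny worked example (the targets (a, b, c, d) = (0, 2, 1, 2) match the code below; the starting point is chosen just for this illustration), a single action already reaches the minimum when x starts one step away:

# Worked example: target (a, b, c, d) = (0, 2, 1, 2), start x = (1, 2, 1, 2)
a, b, c, d = 0., 2., 1., 2.
x = [1., 2., 1., 2.]

def f(x):
	return (x[0] - a)**2 + (x[1] - b)**2 + (x[2] - c)**2 + (x[3] - d)**2

print(f(x))		# 1.0 -- one step away from the minimum
x[0] -= 1.		# the single action "x1: -1"
print(f(x))		# 0.0 -- minimum reached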

Reinforcement learning

The process of solving this problem can be viewed as reinforcement learning, where the relationship y = f(x) plays the role of the "environment", receiving actions from the MCTS agent and returning a "reward" to the agent.

[Figure: agent-environment loop (rl.jpg)]
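
To make this loop concrete, here is a minimal random-agent sketch using the environment class from the Code section below (an illustration of the interface, not part of the original program):

# A purely random agent interacting with the environment defined below
env = environment()
for t in range(10):
	action = env.sample()						# agent: pick a random action
	f, reward, terminal, x = env.step(action)	# environment: apply it, return reward
	print(t, action, f, reward)
	if terminal:								# f reached exactly 0
		break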

Pseudo code

The main part of the code has the structure below. Since MCTS is probabilistic, it may or may not reach the actual minimum. Thus, a tree is built multiple times to increase the chance of reaching the true minimum, and the best result over all trees is returned.

[Figure: structure of the main loop (code_structure.jpg)]
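
In outline, the main loop works as follows (a rough sketch reconstructed from the code below, standing in for the figure):

# for each of num_mainloops trees:
#     create a root node
#     repeat num_tree_search times:
#         1) Selection:       descend from the root by UCB to a leaf
#         2) Expansion:       add one child per possible action
#         3) Playout:         run a short random rollout from each new child
#         4) Backpropagation: add the rollout reward to all ancestors
#     Decision: follow the highest reward-rate children to read off an action sequence
# keep the action sequence with the highest cumulative reward over all trees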

Software versions

  • Python 3.5.2
  • Numpy 1.14.2+mkl

Code


from math import *
from copy import copy

import numpy as np
import random


# Environment: holds the target (a, b, c, d), the current x, and evaluates f(x)
class environment:
	def __init__(self):
		self.a = 0.
		self.b = 2.
		self.c = 1.
		self.d = 2.
		self.x1 = -12.
		self.x2 = 14.
		self.x3 = -13.
		self.x4 = 9.
		self.f = self.func()
	
	def func(self):
		return (self.x1 - self.a)**2+(self.x2 - self.b)**2+(self.x3 - self.c)**2+(self.x4 - self.d)**2
		
	def step(self, action):
		# Apply one action (+/-1..+/-4 change x1..x4 by +/-1; 0 is a no-op)
		# and return (f, reward, terminal, x)
		f_before = self.func()
		
		if action == 1:
			self.x1 += 1.
		elif action == -1:
			self.x1 -= 1.
		elif action ==  2:
			self.x2 += 1.
		elif action == -2:
			self.x2 -= 1.
		elif action ==  3:
			self.x3 += 1.
		elif action == -3:
			self.x3 -= 1.
		elif action == 4:
			self.x4 += 1.
		elif action == -4:
			self.x4 -= 1.
		
		f_after = self.func()
		if f_after == 0:
			reward = 100	# the exact minimum was reached
			terminal = True
		else:
			reward = 1 if abs(f_before) - abs(f_after) > 0 else -1	# +1 if f decreased, -1 otherwise
			terminal = False
		
		return f_after, reward, terminal, [self.x1, self.x2, self.x3, self.x4]
	
	def action_space(self):
		# All possible actions; 0 leaves x unchanged
		return [-4, -3, -2, -1, 0, 1, 2, 3, 4]
		
	def sample(self):
		return np.random.choice([-4, -3, -2, -1, 0, 1, 2, 3, 4])

class Node():
	def __init__(self, parent, action):
		self.num_visits = 1		# start at 1 to avoid division by zero in ucb()
		self.reward = 0.
		self.children = []
		self.parent = parent
		self.action = action	# action leading from the parent to this node
	
	def update(self, reward):
		self.reward += reward
		self.num_visits += 1
	
	def __repr__(self):
		s="Reward/Visits =  %.1f/%.1f (Child %d)"%(self.reward, self.num_visits, len(self.children))
		return s

def ucb(node):
	# UCB1 score: mean reward plus an exploration bonus (constant fixed at 1)
	return node.reward / node.num_visits + sqrt(log(node.parent.num_visits) / node.num_visits)

def reward_rate(node):
	# Exploitation-only score used for the final decision
	return node.reward / node.num_visits

SHOW_INTERMEDIATE_RESULTS = False

env = environment()

num_mainloops = 20		# number of independent trees to build
max_playout_depth = 5	# maximum length of a random rollout
num_tree_search = 180	# MCTS iterations per tree

best_sum_reward = -inf
best_action_sequence = []
best_f = 0
best_x = []

for _ in range(num_mainloops):
	root = Node(None, None)

	
	for run_no in range(num_tree_search):
		env_copy = copy(env)
		
		terminal = False
		sum_reward = 0
		
		# 1) Selection
		current_node = root
		while len(current_node.children) != 0:
			current_node = max(current_node.children, key=ucb)
			_, reward, terminal, _ = env_copy.step(current_node.action)
			sum_reward += reward
		
		# 2) Expansion
		if not terminal:
			possible_actions = env_copy.action_space()
			current_node.children = [Node(current_node, action) for action in possible_actions]		
		
		# Playout and backpropagation are done for each new child
		
		for c in current_node.children:
			# 3) Playout
			env_playout = copy(env_copy)
			sum_reward_playout = 0
			action_sequence = []
			
			_, reward, terminal, _ = env_playout.step(c.action)
			sum_reward_playout += reward
			action_sequence.append(c.action)
			
			while not terminal:
				action = env_playout.sample()
				_, reward, terminal, _ = env_playout.step(action)
				sum_reward_playout += reward
				action_sequence.append(action)
				
				if len(action_sequence) > max_playout_depth:
					break
			
			if terminal:
				print("Terminal reached during a playout. #########")
			
			# 4) Backpropagation: update this child and all its ancestors
			c_ = c
			while c_:
				c_.num_visits +=1
				c_.reward += sum_reward + sum_reward_playout
				c_ = c_.parent
		
	# Decision: follow the highest reward-rate child at each level of the tree
	current_node = root
	action_sequence = []
	sum_reward = 0
	env_copy = copy(env)
	
	while len(current_node.children) != 0:
		current_node = max(current_node.children, key=reward_rate)
		action_sequence.append(current_node.action)
	
	for action in action_sequence:
		_, reward, terminal, _ = env_copy.step(action)
		sum_reward += reward
		if terminal:
			break
	
	f, _, _, x = env_copy.step(0)	# action 0 is a no-op: read off the final f and x
	
	if SHOW_INTERMEDIATE_RESULTS:
		print("Action sequence: ", str(action_sequence))
		print("Sum_reward: ", str(sum_reward))

		print("f, x (original): ", env.f, str([env.x1, env.x2, env.x3, env.x4]))
		print("f, x (after MCTS): ", str(f), str(x))
		print("----------")
	
	if sum_reward > best_sum_reward:
		best_sum_reward = sum_reward
		best_action_sequence = action_sequence
		best_f = f
		best_x = x
	
print("Best Action sequence: ", str(best_action_sequence))
print("Action sequence length: ", str(len(best_action_sequence)))
print("Best Sum_reward: ", str(best_sum_reward))
print("f, x (original): ", env.f , str([env.x1, env.x2, env.x3, env.x4]))
print("f, x (after MCTS): ", str(best_f), str(best_x) )

Result

Running the code above prints the best action sequence found by MCTS, where
f, x (original) -- corresponds to the initial state, and
f, x (after MCTS) -- shows the final x after taking all the actions, together with the value of f(x).

Terminal reached during a playout. #########
Terminal reached during a playout. #########
...
Terminal reached during a playout. #########
Best Action sequence:  [3, -4, -4, 3, 3, 3, -4, -4, 3, 3, 3, -4, -2, -2, -4, -2,
 -2, -4, 0, 3, 1, -3, -2, 1, -4, -4, -3, 3, -1, -2, 3, 1, 1, 4, 1, 4, 2, -2, 1,
-2, 3, -3, 1, 2, -2, -3, 1, 3, -4, -4, -1, -2, -4, 1, 3, -2, 4, 3, 3, 1, -3, -2,
 1, -2, 1, 1, 4, -1, 1, 4, -4, -2, -4, 3, 3, 1, 2, -2, -4, -4, 4, -2, -3, 3, 4,
2, -1, -2, -1, -3, 4, 0, 3, -4, -3, -3, 3, -2, 3, 2, 2, -4, -2, -4, 1, -4, 4, -3
, 1, 3, -1, -2, 4, 3, -4, -1, 2, 4, 2, 1, 4, -4, 4, -3, 1, -1, 4, 0, -3, 2, 1, 4
, 1, -1, 3, 3, -3, -2, 3, 3]
Action sequence length:  140
Best Sum_reward:  141
f, x (original):  533.0 [-12.0, 14.0, -13.0, 9.0]
f, x (after MCTS):  0.0 [0.0, 2.0, 1.0, 2.0]
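
As a quick sanity check (a simple calculation, not printed by the program), the shortest possible sequence of +/-1 steps from the initial x to the minimum is the L1 distance between them:

# Shortest possible number of +/-1 steps = L1 distance from start to target
start = [-12., 14., -13., 9.]
target = [0., 2., 1., 2.]
print(sum(abs(s - t) for s, t in zip(start, target)))	# 45.0

So the 140-step sequence found above does reach the minimum f = 0, but it is roughly three times longer than the shortest possible route.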

References

  1. https://www.youtube.com/watch?v=Fbs4lnGLS8M
  2. https://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/
  3. https://en.wikipedia.org/wiki/Monte_Carlo_tree_search

Note

This code can be freely copied, modified, and used. Please cite this post if you find it useful.
