-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnode.py
More file actions
70 lines (60 loc) · 2.48 KB
/
node.py
File metadata and controls
70 lines (60 loc) · 2.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
class TreeNode:
    """A node in the MCTS tree.

    Each node keeps track of its own value Q, prior probability P, and its
    visit-count-adjusted prior score u.
    """

    def __init__(self, parent, prior_p):
        """Create a node with no children and zeroed statistics.

        parent: the TreeNode directly above this one, or None for the root.
        prior_p: the prior probability stored as this node's P.
        """
        self._parent = parent
        self._children = {}  # a map from action to TreeNode
        self._n_visits = 0
        self._Q = 0
        self._u = 0
        self._P = prior_p

    def expand(self, actions, probs):
        """Expand tree by creating new children.

        actions: an iterable of actions (used as dictionary keys).
        probs: an iterable of prior probabilities, paired with ``actions``
            position by position.

        Actions that already have a child are left untouched, so existing
        statistics are never overwritten. The two iterables are paired with
        zip(), so surplus entries in the longer one are ignored.
        """
        for action, prob in zip(actions, probs):
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

    def select(self, c_puct):
        """Select the child that gives maximum action value Q plus bonus u(P).

        c_puct: forwarded to each child's get_value().

        Return: A tuple of (next_node, action).
        Raises ValueError (from max()) if this node has no children.
        """
        action, node = max(
            self._children.items(),
            key=lambda act_node: act_node[1].get_value(c_puct),
        )
        # NOTE: the node comes first in the returned tuple, the action second.
        return node, action

    def update(self, leaf_value):
        """Update node values from leaf evaluation.

        leaf_value: the value of subtree evaluation from the current player's
            perspective.
        """
        # Count visit.
        self._n_visits += 1
        # Update Q, a running average of values for all visits.
        self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits

    def update_recursive(self, leaf_value):
        """Like a call to update(), but applied to this node and all ancestors.

        The sign of ``leaf_value`` is flipped at every step towards the root,
        so this node receives ``leaf_value``, its parent ``-leaf_value``, its
        grandparent ``leaf_value`` again, and so on.
        """
        # Walk the parent chain with a loop rather than recursion so that the
        # path length is not bounded by the interpreter's recursion limit.
        # update() only touches the node it is called on, so the visiting
        # order does not affect the result.
        node, value = self, leaf_value
        while node is not None:
            node.update(value)
            node, value = node._parent, -value

    def get_value(self, c_puct):
        """Calculate and return the value for this node.

        It is a combination of leaf evaluations Q, and this node's prior
        adjusted for its visit count, u. The freshly computed u is also
        stored on the node.

        c_puct: a number in (0, inf) controlling the relative impact of
            value Q, and prior probability P, on this node's score.

        The parent's visit count is read, so this must only be called on a
        node that has a parent; on the root it raises AttributeError.
        """
        self._u = (
            c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits)
        )
        return self._Q + self._u

    def is_leaf(self):
        """Check if leaf node (i.e. no nodes below this have been expanded)."""
        return not self._children

    def is_root(self):
        """Check if root node (i.e. this node has no parent)."""
        return self._parent is None