Commit 9b50ea9a authored by Junpeng Lao's avatar Junpeng Lao

Rename property name in NUTS to better reflect the computation

parent 8f74fcb8
...@@ -206,11 +206,12 @@ class NUTS(BaseHMC): ...@@ -206,11 +206,12 @@ class NUTS(BaseHMC):
# A proposal for the next position # A proposal for the next position
Proposal = namedtuple("Proposal", "q, q_grad, energy, log_p_accept, logp") Proposal = namedtuple("Proposal", "q, q_grad, energy, log_p_accept_weighted, logp")
# A subtree of the binary tree built by nuts. # A subtree of the binary tree built by nuts.
Subtree = namedtuple( Subtree = namedtuple(
"Subtree", "left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals" "Subtree",
"left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals",
) )
...@@ -243,7 +244,7 @@ class _Tree: ...@@ -243,7 +244,7 @@ class _Tree:
) )
self.depth = 0 self.depth = 0
self.log_size = 0 self.log_size = 0
self.log_accept_sum = -np.inf self.log_weighted_accept_sum = -np.inf
self.mean_tree_accept = 0.0 self.mean_tree_accept = 0.0
self.n_proposals = 0 self.n_proposals = 0
self.p_sum = start.p.copy() self.p_sum = start.p.copy()
...@@ -291,7 +292,9 @@ class _Tree: ...@@ -291,7 +292,9 @@ class _Tree:
self.proposal = tree.proposal self.proposal = tree.proposal
self.log_size = np.logaddexp(self.log_size, tree.log_size) self.log_size = np.logaddexp(self.log_size, tree.log_size)
self.log_accept_sum = np.logaddexp(self.log_accept_sum, tree.log_accept_sum) self.log_weighted_accept_sum = np.logaddexp(
self.log_weighted_accept_sum, tree.log_weighted_accept_sum
)
self.p_sum[:] += tree.p_sum self.p_sum[:] += tree.p_sum
# Additional turning check only when tree depth > 0 to avoid redundant work # Additional turning check only when tree depth > 0 to avoid redundant work
...@@ -331,13 +334,17 @@ class _Tree: ...@@ -331,13 +334,17 @@ class _Tree:
# e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)}) # e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
# Saturated Metropolis accept probability with Boltzmann weight # Saturated Metropolis accept probability with Boltzmann weight
# if h - H0 < 0 # if h - H0 < 0
log_p_accept = -energy_change + min(0.0, -energy_change) log_p_accept_weighted = -energy_change + min(0.0, -energy_change)
log_size = -energy_change log_size = -energy_change
proposal = Proposal( proposal = Proposal(
right.q, right.q_grad, right.energy, log_p_accept, right.model_logp right.q,
right.q_grad,
right.energy,
log_p_accept_weighted,
right.model_logp,
) )
tree = Subtree( tree = Subtree(
right, right, right.p, proposal, log_size, log_p_accept, 1 right, right, right.p, proposal, log_size, log_p_accept_weighted, 1
) )
return tree, None, False return tree, None, False
else: else:
...@@ -377,7 +384,9 @@ class _Tree: ...@@ -377,7 +384,9 @@ class _Tree:
turning = turning | turning1 | turning2 turning = turning | turning1 | turning2
log_size = np.logaddexp(tree1.log_size, tree2.log_size) log_size = np.logaddexp(tree1.log_size, tree2.log_size)
log_accept_sum = np.logaddexp(tree1.log_accept_sum, tree2.log_accept_sum) log_weighted_accept_sum = np.logaddexp(
tree1.log_weighted_accept_sum, tree2.log_weighted_accept_sum
)
if logbern(tree2.log_size - log_size): if logbern(tree2.log_size - log_size):
proposal = tree2.proposal proposal = tree2.proposal
else: else:
...@@ -385,13 +394,13 @@ class _Tree: ...@@ -385,13 +394,13 @@ class _Tree:
else: else:
p_sum = tree1.p_sum p_sum = tree1.p_sum
log_size = tree1.log_size log_size = tree1.log_size
log_accept_sum = tree1.log_accept_sum log_weighted_accept_sum = tree1.log_weighted_accept_sum
proposal = tree1.proposal proposal = tree1.proposal
n_proposals = tree1.n_proposals + tree2.n_proposals n_proposals = tree1.n_proposals + tree2.n_proposals
tree = Subtree( tree = Subtree(
left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals
) )
return tree, diverging, turning return tree, diverging, turning
...@@ -401,7 +410,9 @@ class _Tree: ...@@ -401,7 +410,9 @@ class _Tree:
# Remove contribution from initial state which is always a perfect # Remove contribution from initial state which is always a perfect
# accept # accept
log_sum_weight = logdiffexp_numpy(self.log_size, 0.0) log_sum_weight = logdiffexp_numpy(self.log_size, 0.0)
self.mean_tree_accept = np.exp(self.log_accept_sum - log_sum_weight) self.mean_tree_accept = np.exp(
self.log_weighted_accept_sum - log_sum_weight
)
return { return {
"depth": self.depth, "depth": self.depth,
"mean_tree_accept": self.mean_tree_accept, "mean_tree_accept": self.mean_tree_accept,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment