Commit 9b50ea9a authored by Junpeng Lao's avatar Junpeng Lao

Rename property name in NUTS to better reflect the computation

parent 8f74fcb8
......@@ -206,11 +206,12 @@ class NUTS(BaseHMC):
# A proposal for the next position
Proposal = namedtuple("Proposal", "q, q_grad, energy, log_p_accept, logp")
Proposal = namedtuple("Proposal", "q, q_grad, energy, log_p_accept_weighted, logp")
# A subtree of the binary tree built by nuts.
Subtree = namedtuple(
"Subtree", "left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals"
"Subtree",
"left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals",
)
......@@ -243,7 +244,7 @@ class _Tree:
)
self.depth = 0
self.log_size = 0
self.log_accept_sum = -np.inf
self.log_weighted_accept_sum = -np.inf
self.mean_tree_accept = 0.0
self.n_proposals = 0
self.p_sum = start.p.copy()
......@@ -291,7 +292,9 @@ class _Tree:
self.proposal = tree.proposal
self.log_size = np.logaddexp(self.log_size, tree.log_size)
self.log_accept_sum = np.logaddexp(self.log_accept_sum, tree.log_accept_sum)
self.log_weighted_accept_sum = np.logaddexp(
self.log_weighted_accept_sum, tree.log_weighted_accept_sum
)
self.p_sum[:] += tree.p_sum
# Additional turning check only when tree depth > 0 to avoid redundant work
......@@ -331,13 +334,17 @@ class _Tree:
# e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
# Saturated Metropolis accept probability with Boltzmann weight
# if h - H0 < 0
log_p_accept = -energy_change + min(0.0, -energy_change)
log_p_accept_weighted = -energy_change + min(0.0, -energy_change)
log_size = -energy_change
proposal = Proposal(
right.q, right.q_grad, right.energy, log_p_accept, right.model_logp
right.q,
right.q_grad,
right.energy,
log_p_accept_weighted,
right.model_logp,
)
tree = Subtree(
right, right, right.p, proposal, log_size, log_p_accept, 1
right, right, right.p, proposal, log_size, log_p_accept_weighted, 1
)
return tree, None, False
else:
......@@ -377,7 +384,9 @@ class _Tree:
turning = turning | turning1 | turning2
log_size = np.logaddexp(tree1.log_size, tree2.log_size)
log_accept_sum = np.logaddexp(tree1.log_accept_sum, tree2.log_accept_sum)
log_weighted_accept_sum = np.logaddexp(
tree1.log_weighted_accept_sum, tree2.log_weighted_accept_sum
)
if logbern(tree2.log_size - log_size):
proposal = tree2.proposal
else:
......@@ -385,13 +394,13 @@ class _Tree:
else:
p_sum = tree1.p_sum
log_size = tree1.log_size
log_accept_sum = tree1.log_accept_sum
log_weighted_accept_sum = tree1.log_weighted_accept_sum
proposal = tree1.proposal
n_proposals = tree1.n_proposals + tree2.n_proposals
tree = Subtree(
left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals
left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals
)
return tree, diverging, turning
......@@ -401,7 +410,9 @@ class _Tree:
# Remove contribution from initial state which is always a perfect
# accept
log_sum_weight = logdiffexp_numpy(self.log_size, 0.0)
self.mean_tree_accept = np.exp(self.log_accept_sum - log_sum_weight)
self.mean_tree_accept = np.exp(
self.log_weighted_accept_sum - log_sum_weight
)
return {
"depth": self.depth,
"mean_tree_accept": self.mean_tree_accept,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment