Commit 7a4d47c0 authored by Junpeng Lao's avatar Junpeng Lao

move function to pymc3.math (2)

parent 0ebedb02
......@@ -21,33 +21,12 @@ from ..arraystep import Competence
from .base_hmc import BaseHMC, HMCStepData, DivergenceInfo
from .integration import IntegrationError
from pymc3.backends.report import SamplerWarning, WarningType
from pymc3.math import logbern, log1mexp_numpy, logdiffexp_numpy
from pymc3.theanof import floatX
from pymc3.vartypes import continuous_types
__all__ = ["NUTS"]
def logbern(log_p):
if np.isnan(log_p):
raise FloatingPointError("log_p can't be nan.")
return np.log(nr.uniform()) < log_p
def log1mexp_numpy(x):
"""Return log(1 - exp(-x)).
This function is numerically more stable than the naive approach.
For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
"""
return np.where(
x < 0.683,
np.log(-np.expm1(-x)),
np.log1p(-np.exp(-x)))
def logdiffexp_numpy(a, b):
"""log(exp(a) - exp(b))"""
return a + log1mexp_numpy(a - b)
__all__ = ["NUTS"]
class NUTS(BaseHMC):
......@@ -422,7 +401,7 @@ class _Tree:
if self.log_size > 0:
# Remove contribution from initial state which is always a perfect
# accept
log_sum_weight = logdiffexp_numpy(self.log_size, 0.)
log_sum_weight = logdiffexp_numpy(self.log_size, 0.0)
self.mean_tree_accept = np.exp(self.log_accept_sum - log_sum_weight)
return {
"depth": self.depth,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment