Unverified commit 680dadd1, authored by Dan Foreman-Mackey, committed by GitHub

Fixing HMC et al. tuning schedule to reset at the beginning of a chain even on 1 core (#3941)

* fixing #3939

* fixing a bug introduced in QuadPotentialFullAdapt
parent b6a88f06
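
The change follows one pattern across all four files: each adaptation object's mutable tuning state moves out of __init__ into a reset() method, __init__ stores only the configuration and then calls reset(), and the sampler calls reset_tuning() at the start of every chain. A minimal sketch of that pattern follows; the class and attribute names are illustrative, not pymc3's API:

# Sketch of the refactoring pattern applied in this commit: keep
# immutable configuration in __init__, re-derive all per-chain
# tuning state in reset(). Names here are hypothetical.
import numpy as np

class StepSizeAdaptation:
    def __init__(self, initial_step):
        # Store only the configuration ...
        self._initial_step = initial_step
        # ... and delegate all mutable state to reset().
        self.reset()

    def reset(self):
        # Rebuild every piece of per-chain tuning state from the
        # stored configuration, exactly as on first construction.
        self._log_step = np.log(self._initial_step)
        self._count = 1

    def update(self, accept_stat):
        # Placeholder for the dual-averaging update; only the
        # counter matters for this sketch.
        self._count += 1

# Before starting each chain, the sampler calls reset() so chain 2
# does not inherit chain 1's adapted state:
adapt = StepSizeAdaptation(initial_step=0.25)
adapt.update(accept_stat=0.9)
adapt.reset()
assert adapt._count == 1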
@@ -198,6 +198,10 @@ class BaseHMC(arraystep.GradientSharedStep):
         return hmc_step.end.q, [stats]
 
+    def reset_tuning(self, start=None):
+        self.step_adapt.reset()
+        self.reset(start=None)
+
     def reset(self, start=None):
         self.tune = True
         self.potential.reset()
@@ -171,16 +171,24 @@ class QuadPotentialDiagAdapt(QuadPotential):
         self.dtype = dtype
         self._n = n
-        self._var = np.array(initial_diag, dtype=self.dtype, copy=True)
+
+        self._initial_mean = initial_mean
+        self._initial_diag = initial_diag
+        self._initial_weight = initial_weight
+        self.adaptation_window = adaptation_window
+        self.adaptation_window_multiplier = float(adaptation_window_multiplier)
+
+        self.reset()
+
+    def reset(self):
+        self._var = np.array(self._initial_diag, dtype=self.dtype, copy=True)
         self._var_theano = theano.shared(self._var)
-        self._stds = np.sqrt(initial_diag)
+        self._stds = np.sqrt(self._initial_diag)
         self._inv_stds = floatX(1.) / self._stds
         self._foreground_var = _WeightedVariance(
-            self._n, initial_mean, initial_diag, initial_weight, self.dtype)
+            self._n, self._initial_mean, self._initial_diag, self._initial_weight, self.dtype)
         self._background_var = _WeightedVariance(self._n, dtype=self.dtype)
         self._n_samples = 0
 
-        self.adaptation_window = adaptation_window
-        self.adaptation_window_multiplier = float(adaptation_window_multiplier)
-
     def velocity(self, x, out=None):
         """Compute the current velocity at a position in parameter space."""
@@ -275,8 +283,8 @@ class QuadPotentialDiagAdaptGrad(QuadPotentialDiagAdapt):
     This is experimental, and may be removed without prior deprecation.
     """
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def reset(self):
+        super().reset()
         self._grads1 = np.zeros(self._n, dtype=self.dtype)
         self._ngrads1 = 0
         self._grads2 = np.zeros(self._n, dtype=self.dtype)
@@ -518,20 +526,27 @@ class QuadPotentialFullAdapt(QuadPotentialFull):
         self.dtype = dtype
         self._n = n
-        self._cov = np.array(initial_cov, dtype=self.dtype, copy=True)
+
+        self._initial_mean = initial_mean
+        self._initial_cov = initial_cov
+        self._initial_weight = initial_weight
+        self.adaptation_window = int(adaptation_window)
+        self.adaptation_window_multiplier = float(adaptation_window_multiplier)
+        self._update_window = int(update_window)
+
+        self.reset()
+
+    def reset(self):
+        self._previous_update = 0
+        self._cov = np.array(self._initial_cov, dtype=self.dtype, copy=True)
         self._chol = scipy.linalg.cholesky(self._cov, lower=True)
         self._chol_error = None
         self._foreground_cov = _WeightedCovariance(
-            self._n, initial_mean, initial_cov, initial_weight, self.dtype
+            self._n, self._initial_mean, self._initial_cov, self._initial_weight, self.dtype
         )
         self._background_cov = _WeightedCovariance(self._n, dtype=self.dtype)
         self._n_samples = 0
 
-        self.adaptation_window = int(adaptation_window)
-        self.adaptation_window_multiplier = float(adaptation_window_multiplier)
-        self._update_window = int(update_window)
-        self._previous_update = 0
-
     def _update_from_weightvar(self, weightvar):
         weightvar.current_covariance(out=self._cov)
         try:
@@ -20,15 +20,19 @@ from pymc3.backends.report import SamplerWarning, WarningType
 class DualAverageAdaptation:
     def __init__(self, initial_step, target, gamma, k, t0):
-        self._log_step = np.log(initial_step)
-        self._log_bar = self._log_step
+        self._initial_step = initial_step
         self._target = target
-        self._hbar = 0.
         self._k = k
         self._t0 = t0
-        self._count = 1
-        self._mu = np.log(10 * initial_step)
         self._gamma = gamma
+        self.reset()
+
+    def reset(self):
+        self._log_step = np.log(self._initial_step)
+        self._log_bar = self._log_step
+        self._hbar = 0.
+        self._count = 1
+        self._mu = np.log(10 * self._initial_step)
         self._tuned_stats = []
 
     def current(self, tune):
@@ -145,6 +145,15 @@ class TestSample(SeededTest):
         trace = pm.sample(draws=100, tune=50, cores=4)
         assert len(trace) == 100
 
+    def test_reset_tuning(self):
+        with self.model:
+            tune = 50
+            chains = 2
+            start, step = pm.sampling.init_nuts(chains=chains)
+            pm.sample(draws=2, tune=tune, chains=chains, step=step, start=start, cores=1)
+            assert step.potential._n_samples == tune
+            assert step.step_adapt._count == tune + 1
+
     @pytest.mark.parametrize("step_cls", [pm.NUTS, pm.Metropolis, pm.Slice])
     @pytest.mark.parametrize("discard", [True, False])
     def test_trace_report(self, step_cls, discard):
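
For reference, the failure mode from #3939 arises when chains run sequentially: with cores=1 all chains share one step-method instance in a single process, so the second chain previously inherited the first chain's adapted step size and mass matrix instead of starting a fresh tuning schedule. A usage sketch of that scenario with a toy model (the model itself is illustrative; any model exercises the same path):

# With cores=1 the two chains run one after another in this
# process; after this commit the tuning schedule is reset between
# them via reset_tuning().
import pymc3 as pm

with pm.Model():
    pm.Normal("x", mu=0.0, sigma=1.0)
    trace = pm.sample(draws=200, tune=100, chains=2, cores=1)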