Unverified commit 6c5254fe, authored by michaelosthege and committed by GitHub

Include n_tune, n_draws and t_sampling in SamplerReport (#3827)

* include n_tune, n_draws and t_sampling in SamplerReport
* count tune/draw samples instead of trusting parameters (because of KeyboardInterrupt)
* fall back to tune and len(trace) if tune stat is unavailable
* add test for SamplerReport n_tune and n_draws
* clarify that n_tune are not necessarily in the trace
* use actual number of chains to compute totals
* mention new SamplerReport properties in release notes
Co-authored-by: Michael Osthege <m.osthege@fz-juelich.de>
parent b5891be9
......@@ -8,6 +8,7 @@
- `DEMetropolisZ`, an improved variant of `DEMetropolis` brings better parallelization and higher efficiency with fewer chains with a slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info.
- Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section.
- Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`. This alternative is substantially faster for large models.
- `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))
### Maintenance
- Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.
......
......@@ -15,6 +15,7 @@
from collections import namedtuple
import logging
import enum
import typing
from ..util import is_transformed_name, get_untransformed_name
......@@ -51,11 +52,15 @@ _LEVELS = {
class SamplerReport:
"""This object bundles warnings, convergence statistics and metadata of a sampling run."""
def __init__(self):
    """Initialize an empty report.

    Warnings, convergence diagnostics and sampling metadata are filled in
    later by the sampling machinery (see ``pymc3.sampling.sample``).
    """
    # warnings collected per chain — presumably keyed by chain index; verify against caller
    self._chain_warnings = {}
    # warnings that apply to the sampling run as a whole
    self._global_warnings = []
    # convergence diagnostics (effective sample size, R-hat); None until computed
    self._ess = None
    self._rhat = None
    # sampling metadata exposed via the n_tune/n_draws/t_sampling properties;
    # None until pm.sample() records them
    self._n_tune = None
    self._n_draws = None
    self._t_sampling = None
@property
def _warnings(self):
......@@ -68,6 +73,25 @@ class SamplerReport:
return all(_LEVELS[warn.level] < _LEVELS['warn']
for warn in self._warnings)
@property
def n_tune(self) -> typing.Optional[int]:
    """Number of tune iterations - not necessarily kept in trace!

    ``None`` until the value has been recorded by the sampler.
    """
    return self._n_tune
@property
def n_draws(self) -> typing.Optional[int]:
    """Number of draw iterations.

    ``None`` until the value has been recorded by the sampler.
    """
    return self._n_draws
@property
def t_sampling(self) -> typing.Optional[float]:
    """
    Number of seconds that the sampling procedure took.

    (Includes parallelization overhead.)
    ``None`` until the value has been recorded by the sampler.
    """
    return self._t_sampling
def raise_ok(self, level='error'):
errors = [warn for warn in self._warnings
if _LEVELS[warn.level] >= _LEVELS[level]]
......@@ -151,7 +175,6 @@ class SamplerReport:
warn_list.extend(warnings)
def _log_summary(self):
def log_warning(warn):
level = _LEVELS[warn.level]
logger.log(level, warn.message)
......
......@@ -24,6 +24,7 @@ from collections import defaultdict
from copy import copy
import pickle
import logging
import time
import warnings
import numpy as np
......@@ -488,6 +489,7 @@ def sample(
)
parallel = cores > 1 and chains > 1 and not has_population_samplers
t_start = time.time()
if parallel:
_log.info("Multiprocess sampling ({} chains in {} jobs)".format(chains, cores))
_print_step_hierarchy(step)
......@@ -533,8 +535,36 @@ def sample(
_print_step_hierarchy(step)
trace = _sample_many(**sample_args)
discard = tune if discard_tuned_samples else 0
trace = trace[discard:]
t_sampling = time.time() - t_start
# count the number of tune/draw iterations that happened
# ideally via the "tune" statistic, but not all samplers record it!
if 'tune' in trace.stat_names:
stat = trace.get_sampler_stats('tune', chains=0)
# when CompoundStep is used, the stat is 2 dimensional!
if len(stat.shape) == 2:
stat = stat[:,0]
stat = tuple(stat)
n_tune = stat.count(True)
n_draws = stat.count(False)
else:
# these may be wrong when KeyboardInterrupt happened, but they're better than nothing
n_tune = min(tune, len(trace))
n_draws = max(0, len(trace) - n_tune)
if discard_tuned_samples:
trace = trace[n_tune:]
# save metadata in SamplerReport
trace.report._n_tune = n_tune
trace.report._n_draws = n_draws
trace.report._t_sampling = t_sampling
n_chains = len(trace.chains)
_log.info(
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
f'({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) '
f'took {trace.report.t_sampling:.0f} seconds.'
)
if compute_convergence_checks:
if draws - tune < 100:
......
......@@ -142,6 +142,22 @@ class TestSample(SeededTest):
trace = pm.sample(draws=100, tune=50, cores=4)
assert len(trace) == 100
@pytest.mark.parametrize("step_cls", [pm.NUTS, pm.Metropolis, pm.Slice])
@pytest.mark.parametrize("discard", [True, False])
def test_trace_report(self, step_cls, discard):
    """The SamplerReport must expose n_tune, n_draws and t_sampling,
    regardless of the step method used and of whether tuning samples
    were discarded from the trace.
    """
    with self.model:
        # add more variables, because stats are 2D with CompoundStep!
        pm.Uniform('uni')
        trace = pm.sample(
            draws=100, tune=50, cores=1,
            discard_tuned_samples=discard,
            step=step_cls()
        )
        # metadata reflects the requested iteration counts even when the
        # tuning samples are not kept in the trace
        assert trace.report.n_tune == 50
        assert trace.report.n_draws == 100
        assert isinstance(trace.report.t_sampling, float)
@pytest.mark.parametrize('cores', [1, 2])
def test_sampler_stat_tune(self, cores):
with self.model:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment