Commit aafa00bc authored by Michael Osthege, committed by GitHub

return_inferencedata option for pm.sample (#3911)

* mention arviz functions by name in warning

* convert to InferenceData already in sample function
+ convert to InferenceData and save metadata to it already in sample()
+ pass idata instead of trace to convergence check, to avoid duplicate work
+ directly use arviz diagnostics instead of pymc3 aliases

* fix refactoring bugs

* fix indentation

* add return_inferencedata option
+ set to None
+ defaults to False

* Fix numpy docstring format.

Replaced "<varname>: <type>" with "<varname> : <type>" per numpy guidelines.

Fix spelling typo.

* pass model to from_pymc3 because of deprecation warning

* add test for return_inferencedata option

* advise against keeping warmup draws in a MultiTrace

* mention #3911

* pin to arviz 0.8.0 and address review feedback

* rerun/update notebook to show inferencedata trace

* fix typo

* make all from_pymc3 accessible to the user

* remove duplicate entry, and wording

* address review feedback

* update arviz to 0.8.1 because of bugfix

* incorporate review feedback
+ more direct use of ArviZ
+ some wording things

* use arviz plot_ppc

* also ignore Visual Studio cache

* fix warmup saving logic and test

* require latest ArviZ patch

* change warning to nudge users towards InferenceData

* update ArviZ minimum version

* address review feedback

* start showing the FutureWarning about return_inferencedata in minor release >=3.10

* require arviz>=0.8.3 for latest bugfix
Co-authored-by: rpgoldman <rpgoldman@goldman-tribe.org>
parent 30d28f44
......@@ -33,7 +33,8 @@ benchmarks/html/
benchmarks/results/
.pytest_cache/
# VSCode
# Visual Studio / VSCode
.vs/
.vscode/
.mypy_cache
......
......@@ -11,6 +11,7 @@
- GP covariance functions can now be exponentiated by a scalar. See PR [#3852](https://github.com/pymc-devs/pymc3/pull/3852)
- `sample_posterior_predictive` can now feed on `xarray.Dataset` - e.g. from `InferenceData.posterior`. (see [#3846](https://github.com/pymc-devs/pymc3/pull/3846))
- `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))
- `pm.sample(..., return_inferencedata=True)` can now directly return the trace as `arviz.InferenceData` (see [#3911](https://github.com/pymc-devs/pymc3/pull/3911))
- `pm.sample` now has support for adapting dense mass matrix using `QuadPotentialFullAdapt` (see [#3596](https://github.com/pymc-devs/pymc3/pull/3596), [#3705](https://github.com/pymc-devs/pymc3/pull/3705), [#3858](https://github.com/pymc-devs/pymc3/pull/3858), and [#3893](https://github.com/pymc-devs/pymc3/pull/3893)). Use `init="adapt_full"` or `init="jitter+adapt_full"` to use.
- `Moyal` distribution added (see [#3870](https://github.com/pymc-devs/pymc3/pull/3870)).
- `pm.LKJCholeskyCov` now automatically computes and returns the unpacked Cholesky decomposition, the correlations and the standard deviations of the covariance matrix (see [#3881](https://github.com/pymc-devs/pymc3/pull/3881)).
......@@ -21,6 +22,8 @@
### Maintenance
- Tuning results no longer leak into sequentially sampled `Metropolis` chains (see #3733 and #3796).
- We'll deprecate the `Text` and `SQLite` backends and the `save_trace`/`load_trace` functions, since this is now done with ArviZ. (see [#3902](https://github.com/pymc-devs/pymc3/pull/3902))
- ArviZ `v0.8.3` is now the minimum required version
- In named models, `pm.Data` objects now get model-relative names (see [#3843](https://github.com/pymc-devs/pymc3/pull/3843)).
- `pm.sample` now takes 1000 draws and 1000 tuning samples by default, instead of 500 previously (see [#3855](https://github.com/pymc-devs/pymc3/pull/3855)).
- Moved argument division out of `NegativeBinomial` `random` method. Fixes [#3864](https://github.com/pymc-devs/pymc3/issues/3864) in the style of [#3509](https://github.com/pymc-devs/pymc3/pull/3509).
......@@ -34,7 +37,7 @@
### Deprecations
- Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.
- Deprecated `sd` in version 3.7 has been replaced by `sigma` now raises `DeprecationWarning` on using `sd` in continuous, mixed and timeseries distributions. (see [#3837](https://github.com/pymc-devs/pymc3/pull/3837) and [#3688](https://github.com/pymc-devs/pymc3/issues/3688)).
- Deprecated `sd` has been replaced by `sigma` (already in version 3.7) in continuous, mixed and timeseries distributions and now raises `DeprecationWarning` when `sd` is used. (see [#3837](https://github.com/pymc-devs/pymc3/pull/3837) and [#3688](https://github.com/pymc-devs/pymc3/issues/3688)).
- We'll deprecate the `Text` and `SQLite` backends and the `save_trace`/`load_trace` functions, since this is now done with ArviZ. (see [#3902](https://github.com/pymc-devs/pymc3/pull/3902))
- Dropped some deprecated kwargs and functions (see [#3906](https://github.com/pymc-devs/pymc3/pull/3906))
- Dropped the outdated 'nuts' initialization method for `pm.sample` (see [#3863](https://github.com/pymc-devs/pymc3/pull/3863)).
......
......@@ -55,7 +55,7 @@ def save_trace(trace: MultiTrace, directory: Optional[str]=None, overwrite=False
"""
warnings.warn(
'The `save_trace` function will soon be removed.'
'Instead, use ArviZ to save/load traces.',
'Instead, use `arviz.to_netcdf` to save traces.',
DeprecationWarning,
)
......@@ -98,7 +98,7 @@ def load_trace(directory: str, model=None) -> MultiTrace:
"""
warnings.warn(
'The `load_trace` function will soon be removed.'
'Instead, use ArviZ to save/load traces.',
'Instead, use `arviz.from_netcdf` to load traces.',
DeprecationWarning,
)
straces = []
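One detail in the two warnings above is worth checking: Python joins adjacent string literals with no separator, so the two fragments appear to run together as `removed.Instead`. The sketch below uses hypothetical stand-in functions (`save_trace_stub` and `save_trace_stub_spaced` are not part of PyMC3) that only reproduce the warning text. It shows the effect and one possible remedy, a trailing space like the one the `load`/`dump` warnings already carry.

```python
import warnings


def save_trace_stub():
    """Hypothetical stand-in that only reproduces the warning text from the diff."""
    warnings.warn(
        'The `save_trace` function will soon be removed.'
        'Instead, use `arviz.to_netcdf` to save traces.',
        DeprecationWarning,
    )


def save_trace_stub_spaced():
    """Same warning, with a trailing space after the first sentence."""
    warnings.warn(
        'The `save_trace` function will soon be removed. '
        'Instead, use `arviz.to_netcdf` to save traces.',
        DeprecationWarning,
    )


def captured_message(func) -> str:
    """Call func and return the text of the first warning it emits."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure DeprecationWarning is recorded
        func()
    return str(caught[0].message)


if __name__ == "__main__":
    print(captured_message(save_trace_stub))
    print(captured_message(save_trace_stub_spaced))
```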
......
......@@ -18,6 +18,7 @@ import enum
import typing
from ..util import is_transformed_name, get_untransformed_name
import arviz
logger = logging.getLogger('pymc3')
......@@ -98,8 +99,8 @@ class SamplerReport:
if errors:
raise ValueError('Serious convergence issues during sampling.')
def _run_convergence_checks(self, trace, model):
if trace.nchains == 1:
def _run_convergence_checks(self, idata:arviz.InferenceData, model):
if idata.posterior.sizes['chain'] == 1:
msg = ("Only one chain was sampled, this makes it impossible to "
"run some convergence checks")
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
......@@ -107,9 +108,6 @@ class SamplerReport:
self._add_warnings([warn])
return
from pymc3 import rhat, ess
from arviz import from_pymc3
valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
varnames = []
for rv in model.free_RVs:
......@@ -117,12 +115,11 @@ class SamplerReport:
if is_transformed_name(rv_name):
rv_name2 = get_untransformed_name(rv_name)
rv_name = rv_name2 if rv_name2 in valid_name else rv_name
if rv_name in trace.varnames:
if rv_name in idata.posterior:
varnames.append(rv_name)
idata = from_pymc3(trace, log_likelihood=False)
self._ess = ess = ess(idata, var_names=varnames)
self._rhat = rhat = rhat(idata, var_names=varnames)
self._ess = ess = arviz.ess(idata, var_names=varnames)
self._rhat = rhat = arviz.rhat(idata, var_names=varnames)
warnings = []
rhat_max = max(val.max() for val in rhat.values())
......@@ -147,7 +144,7 @@ class SamplerReport:
warnings.append(warn)
eff_min = min(val.min() for val in ess.values())
n_samples = len(trace) * trace.nchains
n_samples = idata.posterior.sizes['chain'] * idata.posterior.sizes['draw']
if eff_min < 200 and n_samples >= 500:
msg = ("The estimated number of effective samples is smaller than "
"200 for some parameters.")
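The effective-sample-size rule in the hunk above can be sketched in isolation. This is a simplified illustration, not the `SamplerReport` code: the function name `low_ess_message` and the plain-dict input are assumptions made for the example. Only the thresholds (200 and 500) and the `chain * draw` product are taken from the diff.

```python
from typing import Dict, Optional, Sequence


def low_ess_message(
    ess_by_var: Dict[str, Sequence[float]],
    n_chains: int,
    n_draws: int,
) -> Optional[str]:
    """Return a warning message if the smallest ESS is low, else None.

    `ess_by_var` is a hypothetical stand-in for the ESS result:
    it maps each variable name to its per-element ESS values.
    """
    # Smallest effective sample size across all variables and elements.
    eff_min = min(min(values) for values in ess_by_var.values())
    # Total number of posterior draws, mirroring sizes['chain'] * sizes['draw'].
    n_samples = n_chains * n_draws
    # Thresholds as they appear in the diff: only warn once enough draws exist.
    if eff_min < 200 and n_samples >= 500:
        return ("The estimated number of effective samples is smaller than "
                "200 for some parameters.")
    return None


if __name__ == "__main__":
    print(low_ess_message({"mu": [150.0, 900.0]}, n_chains=2, n_draws=500))
    print(low_ess_message({"mu": [150.0]}, n_chains=2, n_draws=100))
```

With fewer than 500 total draws the check stays silent even for a small ESS, which is why the second call returns `None`.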
......
......@@ -194,7 +194,7 @@ def load(name, model=None):
"""
warnings.warn(
'The `load` function will soon be removed. '
'Please use ArviZ to save traces. '
'Please use `arviz.from_netcdf` to load traces. '
'If you have good reasons for using the `load` function, file an issue and tell us about them. ',
DeprecationWarning,
)
......@@ -239,7 +239,7 @@ def dump(name, trace, chains=None):
"""
warnings.warn(
'The `dump` function will soon be removed. '
'Please use ArviZ to save traces. '
'Please use `arviz.to_netcdf` to save traces. '
'If you have good reasons for using the `dump` function, file an issue and tell us about them. ',
DeprecationWarning,
)
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from itertools import combinations
import packaging
from typing import Tuple
import numpy as np
......@@ -160,6 +161,38 @@ class TestSample(SeededTest):
assert isinstance(trace.report.t_sampling, float)
pass
def test_return_inferencedata(self):
with self.model:
kwargs = dict(
draws=100, tune=50, cores=1,
chains=2, step=pm.Metropolis()
)
v = packaging.version.parse(pm.__version__)
if v.major > 3 or v.minor >= 10:
with pytest.warns(FutureWarning, match="pass return_inferencedata"):
result = pm.sample(**kwargs)
# trace with tuning
with pytest.warns(UserWarning, match="will be included"):
result = pm.sample(**kwargs, return_inferencedata=False, discard_tuned_samples=False)
assert isinstance(result, pm.backends.base.MultiTrace)
assert len(result) == 150
# inferencedata with tuning
result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=False)
assert isinstance(result, az.InferenceData)
assert result.posterior.sizes["draw"] == 100
assert result.posterior.sizes["chain"] == 2
assert len(result._groups_warmup) > 0
# inferencedata without tuning
result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=True)
assert isinstance(result, az.InferenceData)
assert result.posterior.sizes["draw"] == 100
assert result.posterior.sizes["chain"] == 2
assert len(result._groups_warmup) == 0
pass
@pytest.mark.parametrize('cores', [1, 2])
def test_sampler_stat_tune(self, cores):
with self.model:
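The body of `pm.sample` is not shown in this diff, so the following is only a sketch of the behaviour that `test_return_inferencedata` and the commit message describe: the option defaults to `None`, resolves to `False`, and triggers a `FutureWarning` from minor release 3.10 onwards. The helper names `parse_major_minor` and `resolve_return_inferencedata`, and the exact warning text, are assumptions for illustration. The only detail taken from the test is that the message contains "pass return_inferencedata".

```python
import warnings
from typing import Optional, Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from a dotted version string such as '3.10.0'.

    Assumes the first two components are plain integers.
    """
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def resolve_return_inferencedata(value: Optional[bool], version: str) -> bool:
    """Hypothetical helper: map the tri-state option to a concrete bool."""
    if value is not None:
        # The caller chose explicitly, so there is nothing to warn about.
        return value
    major, minor = parse_major_minor(version)
    # Same gate as in the test: warn in any 3.x release with minor >= 10, or later majors.
    if major > 3 or minor >= 10:
        warnings.warn(
            "The default return type may change in a future release; "
            "pass return_inferencedata=True or return_inferencedata=False "
            "to be safe and silence this warning.",
            FutureWarning,
        )
    # Unset currently falls back to returning a MultiTrace.
    return False


if __name__ == "__main__":
    print(resolve_return_inferencedata(None, "3.9.0"))
    print(resolve_return_inferencedata(True, "3.10.0"))
```

Keeping `None` as the sentinel lets the function distinguish "not specified" from an explicit `False`, so only users who rely on the default see the warning.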
......
arviz>=0.7.0
arviz>=0.8.3
theano>=1.0.4
numpy>=1.13.0
scipy>=0.18.1
......