Unverified Commit aafa00bc authored by Michael Osthege, committed by GitHub

return_inferencedata option for pm.sample (#3911)

* mention arviz functions by name in warning

* convert to InferenceData within the sample function itself
+ convert to InferenceData and attach sampling metadata to it inside sample()
+ pass idata instead of trace to convergence check, to avoid duplicate work
+ directly use arviz diagnostics instead of pymc3 aliases

* fix refactoring bugs

* fix indentation

* add return_inferencedata option
+ kwarg defaults to None
+ None currently behaves like False

* Fix numpy docstring format.

Replaced "<varname>: <type>" with "<varname> : <type>" per numpy guidelines.

Fix spelling typo.

* pass model to from_pymc3 because of deprecation warning

* add test for return_inferencedata option

* advise against keeping warmup draws in a MultiTrace

* mention #3911

* pin to arviz 0.8.0 and address review feedback

* rerun/update notebook to show inferencedata trace

* fix typo

* make all from_pymc3 accessible to the user

* remove duplicate entry and fix wording

* address review feedback

* update arviz to 0.8.1 because of bugfix

* incorporate review feedback
+ more direct use of ArviZ
+ some wording things

* use arviz plot_ppc

* also ignore Visual Studio cache

* fix warmup saving logic and test

* require latest ArviZ patch

* change warning to nudge users towards InferenceData

* update ArviZ minimum version

* address review feedback

* start showing the FutureWarning about return_inferencedata in minor release >=3.10

* require arviz>=0.8.3 for latest bugfix
Co-authored-by: rpgoldman <rpgoldman@goldman-tribe.org>
parent 30d28f44
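For orientation, a minimal sketch of the option this commit introduces. The toy model and the variable name `mu` are illustrative only and not part of the PR:

import arviz as az
import pymc3 as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)

    # New in this PR: return the trace directly as arviz.InferenceData.
    idata = pm.sample(draws=1000, tune=1000, return_inferencedata=True)

    # return_inferencedata defaults to None, which currently behaves like
    # False and returns a MultiTrace, so existing code keeps working.
    trace = pm.sample(draws=1000, tune=1000, return_inferencedata=False)

assert isinstance(idata, az.InferenceData)
assert isinstance(trace, pm.backends.base.MultiTrace)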
@@ -33,7 +33,8 @@ benchmarks/html/
 benchmarks/results/
 .pytest_cache/
 
-# VSCode
+# Visual Studio / VSCode
+.vs/
 .vscode/
 .mypy_cache
...
@@ -11,6 +11,7 @@
 - GP covariance functions can now be exponentiated by a scalar. See PR [#3852](https://github.com/pymc-devs/pymc3/pull/3852)
 - `sample_posterior_predictive` can now feed on `xarray.Dataset` - e.g. from `InferenceData.posterior`. (see [#3846](https://github.com/pymc-devs/pymc3/pull/3846))
 - `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))
+- `pm.sample(..., return_inferencedata=True)` can now directly return the trace as `arviz.InferenceData` (see [#3911](https://github.com/pymc-devs/pymc3/pull/3911))
 - `pm.sample` now has support for adapting dense mass matrix using `QuadPotentialFullAdapt` (see [#3596](https://github.com/pymc-devs/pymc3/pull/3596), [#3705](https://github.com/pymc-devs/pymc3/pull/3705), [#3858](https://github.com/pymc-devs/pymc3/pull/3858), and [#3893](https://github.com/pymc-devs/pymc3/pull/3893)). Use `init="adapt_full"` or `init="jitter+adapt_full"` to use.
 - `Moyal` distribution added (see [#3870](https://github.com/pymc-devs/pymc3/pull/3870)).
 - `pm.LKJCholeskyCov` now automatically computes and returns the unpacked Cholesky decomposition, the correlations and the standard deviations of the covariance matrix (see [#3881](https://github.com/pymc-devs/pymc3/pull/3881)).
@@ -21,6 +22,8 @@
 
 ### Maintenance
 - Tuning results no longer leak into sequentially sampled `Metropolis` chains (see #3733 and #3796).
+- We'll deprecate the `Text` and `SQLite` backends and the `save_trace`/`load_trace` functions, since this is now done with ArviZ. (see [#3902](https://github.com/pymc-devs/pymc3/pull/3902))
+- ArviZ `v0.8.3` is now the minimum required version
 - In named models, `pm.Data` objects now get model-relative names (see [#3843](https://github.com/pymc-devs/pymc3/pull/3843)).
 - `pm.sample` now takes 1000 draws and 1000 tuning samples by default, instead of 500 previously (see [#3855](https://github.com/pymc-devs/pymc3/pull/3855)).
 - Moved argument division out of `NegativeBinomial` `random` method. Fixes [#3864](https://github.com/pymc-devs/pymc3/issues/3864) in the style of [#3509](https://github.com/pymc-devs/pymc3/pull/3509).
@@ -34,7 +37,7 @@
 
 ### Deprecations
 - Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.
-- Deprecated `sd` in version 3.7 has been replaced by `sigma` now raises `DeprecationWarning` on using `sd` in continuous, mixed and timeseries distributions. (see [#3837](https://github.com/pymc-devs/pymc3/pull/3837) and [#3688](https://github.com/pymc-devs/pymc3/issues/3688)).
+- Deprecated `sd` has been replaced by `sigma` (already in version 3.7) in continuous, mixed and timeseries distributions and now raises `DeprecationWarning` when `sd` is used. (see [#3837](https://github.com/pymc-devs/pymc3/pull/3837) and [#3688](https://github.com/pymc-devs/pymc3/issues/3688)).
 - We'll deprecate the `Text` and `SQLite` backends and the `save_trace`/`load_trace` functions, since this is now done with ArviZ. (see [#3902](https://github.com/pymc-devs/pymc3/pull/3902))
 - Dropped some deprecated kwargs and functions (see [#3906](https://github.com/pymc-devs/pymc3/pull/3906))
 - Dropped the outdated 'nuts' initialization method for `pm.sample` (see [#3863](https://github.com/pymc-devs/pymc3/pull/3863)).
...
@@ -55,7 +55,7 @@ def save_trace(trace: MultiTrace, directory: Optional[str]=None, overwrite=False
     """
     warnings.warn(
         'The `save_trace` function will soon be removed.'
-        'Instead, use ArviZ to save/load traces.',
+        'Instead, use `arviz.to_netcdf` to save traces.',
        DeprecationWarning,
    )
@@ -98,7 +98,7 @@ def load_trace(directory: str, model=None) -> MultiTrace:
     """
     warnings.warn(
         'The `load_trace` function will soon be removed.'
-        'Instead, use ArviZ to save/load traces.',
+        'Instead, use `arviz.from_netcdf` to load traces.',
        DeprecationWarning,
    )
    straces = []
...
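The new warnings above point users at ArviZ's netCDF round-trip. A small sketch of that replacement workflow, assuming the toy model and the file name "trace.nc" purely as examples:

import arviz as az
import pymc3 as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    idata = pm.sample(return_inferencedata=True)

# Suggested replacement for save_trace/load_trace:
az.to_netcdf(idata, "trace.nc")           # save the trace to disk
idata_again = az.from_netcdf("trace.nc")  # load it back as InferenceData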
@@ -18,6 +18,7 @@ import enum
 import typing
 
 from ..util import is_transformed_name, get_untransformed_name
+import arviz
 
 
 logger = logging.getLogger('pymc3')
@@ -98,8 +99,8 @@ class SamplerReport:
         if errors:
             raise ValueError('Serious convergence issues during sampling.')
 
-    def _run_convergence_checks(self, trace, model):
-        if trace.nchains == 1:
+    def _run_convergence_checks(self, idata: arviz.InferenceData, model):
+        if idata.posterior.sizes['chain'] == 1:
             msg = ("Only one chain was sampled, this makes it impossible to "
                    "run some convergence checks")
             warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
@@ -107,9 +108,6 @@ class SamplerReport:
             self._add_warnings([warn])
             return
 
-        from pymc3 import rhat, ess
-        from arviz import from_pymc3
-
         valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
         varnames = []
         for rv in model.free_RVs:
@@ -117,12 +115,11 @@ class SamplerReport:
             if is_transformed_name(rv_name):
                 rv_name2 = get_untransformed_name(rv_name)
                 rv_name = rv_name2 if rv_name2 in valid_name else rv_name
-            if rv_name in trace.varnames:
+            if rv_name in idata.posterior:
                 varnames.append(rv_name)
 
-        idata = from_pymc3(trace, log_likelihood=False)
-        self._ess = ess = ess(idata, var_names=varnames)
-        self._rhat = rhat = rhat(idata, var_names=varnames)
+        self._ess = ess = arviz.ess(idata, var_names=varnames)
+        self._rhat = rhat = arviz.rhat(idata, var_names=varnames)
 
         warnings = []
         rhat_max = max(val.max() for val in rhat.values())
@@ -147,7 +144,7 @@ class SamplerReport:
             warnings.append(warn)
 
         eff_min = min(val.min() for val in ess.values())
-        n_samples = len(trace) * trace.nchains
+        n_samples = idata.posterior.sizes['chain'] * idata.posterior.sizes['draw']
         if eff_min < 200 and n_samples >= 500:
             msg = ("The estimated number of effective samples is smaller than "
                    "200 for some parameters.")
...
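The convergence check above now calls ArviZ directly. A standalone sketch of those same calls, using a bundled ArviZ example dataset instead of a fresh PyMC3 trace so it runs on its own (in PyMC3, `idata` would come from pm.sample(..., return_inferencedata=True)):

import arviz as az

# Load an ArviZ example dataset purely to keep the sketch self-contained.
idata = az.load_arviz_data("centered_eight")

ess = az.ess(idata, var_names=["mu"])    # effective sample size
rhat = az.rhat(idata, var_names=["mu"])  # potential scale reduction (R-hat)

# The report derives the total sample count from the posterior dimensions:
n_samples = idata.posterior.sizes["chain"] * idata.posterior.sizes["draw"]
print(float(ess["mu"]), float(rhat["mu"]), n_samples)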
@@ -194,7 +194,7 @@ def load(name, model=None):
     """
     warnings.warn(
         'The `load` function will soon be removed. '
-        'Please use ArviZ to save traces. '
+        'Please use `arviz.from_netcdf` to load traces. '
         'If you have good reasons for using the `load` function, file an issue and tell us about them. ',
        DeprecationWarning,
    )
@@ -239,7 +239,7 @@ def dump(name, trace, chains=None):
     """
     warnings.warn(
         'The `dump` function will soon be removed. '
-        'Please use ArviZ to save traces. '
+        'Please use `arviz.to_netcdf` to save traces. '
         'If you have good reasons for using the `dump` function, file an issue and tell us about them. ',
        DeprecationWarning,
    )
...
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from itertools import combinations
+import packaging
 from typing import Tuple
 
 import numpy as np
@@ -160,6 +161,38 @@ class TestSample(SeededTest):
         assert isinstance(trace.report.t_sampling, float)
         pass
 
+    def test_return_inferencedata(self):
+        with self.model:
+            kwargs = dict(
+                draws=100, tune=50, cores=1,
+                chains=2, step=pm.Metropolis()
+            )
+            v = packaging.version.parse(pm.__version__)
+            if v.major > 3 or v.minor >= 10:
+                with pytest.warns(FutureWarning, match="pass return_inferencedata"):
+                    result = pm.sample(**kwargs)
+
+            # trace with tuning
+            with pytest.warns(UserWarning, match="will be included"):
+                result = pm.sample(**kwargs, return_inferencedata=False, discard_tuned_samples=False)
+            assert isinstance(result, pm.backends.base.MultiTrace)
+            assert len(result) == 150
+
+            # inferencedata with tuning
+            result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=False)
+            assert isinstance(result, az.InferenceData)
+            assert result.posterior.sizes["draw"] == 100
+            assert result.posterior.sizes["chain"] == 2
+            assert len(result._groups_warmup) > 0
+
+            # inferencedata without tuning
+            result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=True)
+            assert isinstance(result, az.InferenceData)
+            assert result.posterior.sizes["draw"] == 100
+            assert result.posterior.sizes["chain"] == 2
+            assert len(result._groups_warmup) == 0
+        pass
+
     @pytest.mark.parametrize('cores', [1, 2])
     def test_sampler_stat_tune(self, cores):
         with self.model:
...
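As the test implies, tuning draws are only kept on the InferenceData when discard_tuned_samples=False. A small sketch of reading them back; the warmup_posterior group name is an assumption about the ArviZ >=0.8 warmup groups, not something spelled out in this diff:

import arviz as az
import pymc3 as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    idata = pm.sample(
        draws=100, tune=50, chains=2, cores=1, step=pm.Metropolis(),
        return_inferencedata=True, discard_tuned_samples=False,
    )

# Assumed ArviZ group name for the tuning draws (see lead-in above):
print(idata.warmup_posterior.sizes["draw"])  # expected: 50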
-arviz>=0.7.0
+arviz>=0.8.3
 theano>=1.0.4
 numpy>=1.13.0
 scipy>=0.18.1
...