Unverified Commit 3036da5f authored by michaelosthege's avatar michaelosthege Committed by GitHub

Support xarray input to sample_posterior_predictive (#3846)

* add test for xarray input to sample_posterior_predictive
* support for xarray.Dataset as trace argument to sample_posterior_predictive and fast_sample_posterior_predictive.
* closes #3828
Co-authored-by: default avatarMichael Osthege <zufallsprinzip@hotmail.de>
Co-authored-by: default avatarrpgoldman <rpgoldman@goldman-tribe.org>
parent 29821a54
......@@ -8,6 +8,7 @@
- `DEMetropolisZ`, an improved variant of `DEMetropolis` brings better parallelization and higher efficiency with fewer chains with a slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info.
- Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section.
- Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`. This alternative is substantially faster for large models.
- `sample_posterior_predictive` can now accept an `xarray.Dataset` as its trace argument - e.g. `InferenceData.posterior`. (see [#3846](https://github.com/pymc-devs/pymc3/pull/3846))
- `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))
### Maintenance
......
......@@ -12,12 +12,14 @@ from typing_extensions import Protocol
import numpy as np
import theano
import theano.tensor as tt
from xarray import Dataset
from ..backends.base import MultiTrace #, TraceLike, TraceDict
from .distribution import _DrawValuesContext, _DrawValuesContextBlocker, is_fast_drawable, _compile_theano_function, vectorized_ppc
from ..model import Model, get_named_nodes_and_relations, ObservedRV, MultiObservedRV, modelcontext
from ..exceptions import IncorrectArgumentsError
from ..vartypes import theano_constant
from ..util import dataset_to_point_dict
# Failing tests:
# test_mixture_random_shape::test_mixture_random_shape
#
......@@ -119,7 +121,7 @@ class _TraceDict(_TraceDictParent):
def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.ndarray]]],
def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict[str, np.ndarray]]],
samples: Optional[int]=None,
model: Optional[Model]=None,
var_names: Optional[List[str]]=None,
......@@ -135,7 +137,7 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
Parameters
----------
trace : MultiTrace or List of points
trace : MultiTrace, xarray.Dataset, or List of points (dictionary)
Trace generated from MCMC sampling.
samples : int, optional
Number of posterior predictive samples to generate. Defaults to one posterior predictive
......@@ -168,6 +170,9 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
### greater than the number of samples in the trace parameter, we sample repeatedly. This
### makes the shape issues just a little easier to deal with.
if isinstance(trace, Dataset):
trace = dataset_to_point_dict(trace)
model = modelcontext(model)
assert model is not None
with model:
......
......@@ -30,6 +30,7 @@ import warnings
import numpy as np
import theano.gradient as tg
from theano.tensor import Tensor
import xarray
from .backends.base import BaseTrace, MultiTrace
from .backends.ndarray import NDArray
......@@ -53,6 +54,7 @@ from .util import (
get_untransformed_name,
is_transformed_name,
get_default_varnames,
dataset_to_point_dict,
)
from .vartypes import discrete_types
from .exceptions import IncorrectArgumentsError
......@@ -1520,9 +1522,9 @@ def sample_posterior_predictive(
Parameters
----------
trace: backend, list, or MultiTrace
Trace generated from MCMC sampling. Or a list containing dicts from
find_MAP() or points
trace: backend, list, xarray.Dataset, or MultiTrace
Trace generated from MCMC sampling, or a list of dicts (e.g. points or from find_MAP()),
or an xarray.Dataset (e.g. InferenceData.posterior or InferenceData.prior)
samples: int
Number of posterior predictive samples to generate. Defaults to one posterior predictive
sample per posterior sample, that is, the number of draws times the number of chains. It
......@@ -1556,6 +1558,9 @@ def sample_posterior_predictive(
Dictionary with the variable names as keys, and values numpy arrays containing
posterior predictive samples.
"""
if isinstance(trace, xarray.Dataset):
trace = dataset_to_point_dict(trace)
len_trace = len(trace)
try:
nchain = trace.nchains
......
......@@ -22,6 +22,7 @@ except ImportError:
import mock
import numpy.testing as npt
import arviz as az
import pymc3 as pm
import theano.tensor as tt
from theano import shared
......@@ -880,3 +881,32 @@ class TestSamplePosteriorPredictive:
var_names=['d']
)
def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture):
    """`sample_posterior_predictive` should accept an xarray.Dataset
    (here ``InferenceData.prior``) as its ``trace`` argument.

    Regression test for pymc-devs/pymc3#3828.
    """
    pmodel, trace = point_list_arg_bug_fixture
    # Build an InferenceData object that carries a `prior` Dataset.
    with pmodel:
        prior = pm.sample_prior_predictive(samples=20)
        idat = az.from_pymc3(trace, prior=prior)
    # Pass the xarray.Dataset directly instead of a MultiTrace/point list.
    with pmodel:
        pp = pm.sample_posterior_predictive(
            idat.prior,
            var_names=['d']
        )
def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
    """`sample_posterior_predictive` should accept an xarray.Dataset
    (here ``InferenceData.posterior``) as its ``trace`` argument.

    Regression test for pymc-devs/pymc3#3828.
    """
    pmodel, trace = point_list_arg_bug_fixture
    idat = az.from_pymc3(trace)
    # Pass the posterior Dataset directly instead of the MultiTrace.
    with pmodel:
        pp = pm.sample_posterior_predictive(
            idat.posterior,
            var_names=['d']
        )
def test_sample_from_xarray_posterior_fast(self, point_list_arg_bug_fixture):
    """`fast_sample_posterior_predictive` should accept an xarray.Dataset
    (here ``InferenceData.posterior``) as its ``trace`` argument, mirroring
    the non-fast variant.
    """
    pmodel, trace = point_list_arg_bug_fixture
    idat = az.from_pymc3(trace)
    # Pass the posterior Dataset directly instead of the MultiTrace.
    with pmodel:
        pp = pm.fast_sample_posterior_predictive(
            idat.posterior,
            var_names=['d']
        )
......@@ -14,7 +14,11 @@
import re
import functools
from numpy import asscalar
from typing import List, Dict
import xarray
from numpy import asscalar, ndarray
LATEX_ESCAPE_RE = re.compile(r'(%|_|\$|#|&)', re.MULTILINE)
......@@ -179,3 +183,21 @@ def biwrap(wrapper):
newwrapper = functools.partial(wrapper, *args, **kwargs)
return newwrapper
return enhanced
def dataset_to_point_dict(ds: "xarray.Dataset") -> List[Dict[str, ndarray]]:
    """Convert an xarray.Dataset of MCMC samples into a list of point dicts.

    Parameters
    ----------
    ds : xarray.Dataset
        A dataset with ``chain`` and ``draw`` dimensions, e.g.
        ``InferenceData.posterior`` or ``InferenceData.prior``; each variable
        is expected to be laid out as ``(chain, draw, *var_shape)``.

    Returns
    -------
    points : list of dict
        One dict per (chain, draw) combination, in chain-major order, mapping
        each variable name to the corresponding sample.
    """
    # Grab the raw sample arrays for each variable up front.
    _samples = {vn: ds[vn].values for vn in ds.keys()}
    # Index positionally rather than iterating the chain/draw *coordinate
    # labels*: labels need not start at 0 (e.g. after burn-in or thinning),
    # so using them as array indices would select the wrong samples or raise.
    points = [
        {vn: s[c, d] for vn, s in _samples.items()}
        for c in range(len(ds.chain))
        for d in range(len(ds.draw))
    ]
    return points
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment