Unverified Commit 74b77889 authored by rpgoldman's avatar rpgoldman Committed by GitHub

Merge pull request #3841 from rpgoldman/iss3840

Fix computation of samples argument in sample_posterior_predictive
Solves #3840 
parents 363afc80 839206b6
......@@ -1568,7 +1568,13 @@ def sample_posterior_predictive(
raise IncorrectArgumentsError("Should not specify both keep_size and size argukments")
if samples is None:
samples = sum(len(v) for v in trace._straces.values())
if isinstance(trace, MultiTrace):
samples = sum(len(v) for v in trace._straces.values())
elif isinstance(trace, list) and all((isinstance(x, dict) for x in trace)):
# this is a list of points
samples = len(trace)
else:
raise ValueError("Do not know how to compute number of samples for trace argument of type %s"%type(trace))
if samples < len_trace * nchain:
warnings.warn(
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from itertools import combinations
from typing import Tuple
import numpy as np
try:
......@@ -693,6 +694,16 @@ def test_exec_nuts_init(method):
assert "a" in start[0] and "b_log__" in start[0]
@pytest.fixture(scope="class")
def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]:
with pm.Model() as pmodel:
n = pm.Normal('n')
trace = pm.sample()
with pmodel:
d = pm.Deterministic('d', n * 4)
return pmodel, trace
class TestSamplePriorPredictive(SeededTest):
def test_ignores_observed(self):
observed = np.random.normal(10, 1, size=200)
......@@ -851,3 +862,21 @@ class TestSamplePriorPredictive(SeededTest):
with model:
prior_trace = pm.sample_prior_predictive(5)
assert prior_trace["x"].shape == (5, 3, 1)
class TestSamplePosteriorPredictive:
def test_point_list_arg_bug_fspp(self, point_list_arg_bug_fixture):
pmodel, trace = point_list_arg_bug_fixture
with pmodel:
pp = pm.fast_sample_posterior_predictive(
[trace[15]],
var_names=['d']
)
def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
pmodel, trace = point_list_arg_bug_fixture
with pmodel:
pp = pm.sample_posterior_predictive(
[trace[15]],
var_names=['d']
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment