Commit b5660603 authored by Robert P. Goldman's avatar Robert P. Goldman

Test to replicate issue 3840.

parent 363afc80
......@@ -13,6 +13,7 @@
# limitations under the License.
from itertools import combinations
from typing import Tuple
import numpy as np
try:
......@@ -693,6 +694,16 @@ def test_exec_nuts_init(method):
assert "a" in start[0] and "b_log__" in start[0]
@pytest.fixture(scope="class")
def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]:
with pm.Model() as pmodel:
n = pm.Normal('n')
trace = pm.sample()
with pmodel:
d = pm.Deterministic('d', n * 4)
return pmodel, trace
class TestSamplePriorPredictive(SeededTest):
def test_ignores_observed(self):
observed = np.random.normal(10, 1, size=200)
......@@ -851,3 +862,21 @@ class TestSamplePriorPredictive(SeededTest):
with model:
prior_trace = pm.sample_prior_predictive(5)
assert prior_trace["x"].shape == (5, 3, 1)
class TestSamplePosteriorPredictive:
def test_point_list_arg_bug_fspp(self, point_list_arg_bug_fixture):
pmodel, trace = point_list_arg_bug_fixture
with pmodel:
pp = pm.fast_sample_posterior_predictive(
[trace[15]],
var_names=['d']
)
def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
pmodel, trace = point_list_arg_bug_fixture
with pmodel:
pp = pm.sample_posterior_predictive(
[trace[15]],
var_names=['d']
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment