Commit e88c2f94 authored by Luciano Paz's avatar Luciano Paz Committed by GitHub

Set different almost equal tolerance depending on floatX (#3980)

parent d0de7637
......@@ -23,6 +23,7 @@ import pymc3 as pm
from pymc3.distributions import HalfCauchy, Normal, transforms
from pymc3 import Potential, Deterministic
from pymc3.model import ValueGradFunction
from .helpers import select_by_precision
class NewModel(pm.Model):
......@@ -192,17 +193,33 @@ def test_matrix_multiplication():
tune=0,
compute_convergence_checks=False,
progressbar=False)
decimal = select_by_precision(7, 5)
for point in posterior.points():
npt.assert_almost_equal(point['matrix'] @ point['transformed'],
point['rv_rv'])
npt.assert_almost_equal(np.ones((2, 2)) @ point['transformed'],
point['np_rv'])
npt.assert_almost_equal(point['matrix'] @ np.ones(2),
point['rv_np'])
npt.assert_almost_equal(point['matrix'] @ point['rv_rv'],
point['rv_det'])
npt.assert_almost_equal(point['rv_rv'] @ point['transformed'],
point['det_rv'])
npt.assert_almost_equal(
point['matrix'] @ point['transformed'],
point['rv_rv'],
decimal=decimal,
)
npt.assert_almost_equal(
np.ones((2, 2)) @ point['transformed'],
point['np_rv'],
decimal=decimal,
)
npt.assert_almost_equal(
point['matrix'] @ np.ones(2),
point['rv_np'],
decimal=decimal,
)
npt.assert_almost_equal(
point['matrix'] @ point['rv_rv'],
point['rv_det'],
decimal=decimal,
)
npt.assert_almost_equal(
point['rv_rv'] @ point['transformed'],
point['det_rv'],
decimal=decimal,
)
def test_duplicate_vars():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment