Unverified Commit c34ae3f5 authored by Alexandre ANDORRA's avatar Alexandre ANDORRA Committed by GitHub

Check that concentration parameters of Dirichlet distribution are all > 0 (#3853)

* Added check that a>0 in Dirichlet

* Cast a as array for tests

* Test a>0 only when a not an RV and convert to array when list

* Added test for init of Dirichlet with negative values

* Added release note

* Resolved conflict in release notes

* Escaped parenthesis in match regexp
parent 0456f397
......@@ -20,6 +20,7 @@
- `pm.sample` now takes 1000 draws and 1000 tuning samples by default, instead of 500 previously (see [#3855](https://github.com/pymc-devs/pymc3/pull/3855)).
- Dropped the outdated 'nuts' initialization method for `pm.sample` (see [#3863](https://github.com/pymc-devs/pymc3/pull/3863)).
- Moved argument division out of `NegativeBinomial` `random` method. Fixes [#3864](https://github.com/pymc-devs/pymc3/issues/3864) in the style of [#3509](https://github.com/pymc-devs/pymc3/pull/3509).
- The Dirichlet distribution now raises a ValueError when it's initialized with <= 0 values (see [#3853](https://github.com/pymc-devs/pymc3/pull/3853)).
## PyMC3 3.8 (November 29 2019)
......
......@@ -488,6 +488,16 @@ class Dirichlet(Continuous):
def __init__(self, a, transform=transforms.stick_breaking,
*args, **kwargs):
if not isinstance(a, pm.model.TensorVariable):
if not isinstance(a, list) and not isinstance(a, np.ndarray):
raise TypeError(
'The vector of concentration parameters (a) must be a python list '
'or numpy array.')
a = np.array(a)
if (a <= 0).any():
raise ValueError("All concentration parameters (a) must be > 0.")
shape = np.atleast_1d(a.shape)[-1]
kwargs.setdefault("shape", shape)
......
......@@ -944,17 +944,43 @@ class TestMatchesScipy(SeededTest):
@pytest.mark.parametrize('n', [2, 3])
def test_dirichlet(self, n):
self.pymc3_matches_scipy(Dirichlet, Simplex(
n), {'a': Vector(Rplus, n)}, dirichlet_logpdf)
self.pymc3_matches_scipy(
Dirichlet,
Simplex(n),
{'a': Vector(Rplus, n)},
dirichlet_logpdf
)
@pytest.mark.parametrize('n', [3, 4])
def test_dirichlet_init_fail(self, n):
with Model():
with pytest.raises(
ValueError,
match=r"All concentration parameters \(a\) must be > 0."
):
_ = Dirichlet('x', a=np.zeros(n), shape=n)
with pytest.raises(
ValueError,
match=r"All concentration parameters \(a\) must be > 0."
):
_ = Dirichlet('x', a=np.array([-1.] * n), shape=n)
def test_dirichlet_2D(self):
self.pymc3_matches_scipy(Dirichlet, MultiSimplex(2, 2),
{'a': Vector(Vector(Rplus, 2), 2)}, dirichlet_logpdf)
self.pymc3_matches_scipy(
Dirichlet,
MultiSimplex(2, 2),
{'a': Vector(Vector(Rplus, 2), 2)},
dirichlet_logpdf
)
@pytest.mark.parametrize('n', [2, 3])
def test_multinomial(self, n):
self.pymc3_matches_scipy(Multinomial, Vector(Nat, n), {'p': Simplex(n), 'n': Nat},
multinomial_logpdf)
self.pymc3_matches_scipy(
Multinomial,
Vector(Nat, n),
{'p': Simplex(n), 'n': Nat},
multinomial_logpdf
)
@pytest.mark.parametrize('p,n', [
[[.25, .25, .25, .25], 1],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment