Commit fd79d1f2 authored by Tirth Patel's avatar Tirth Patel

fir array_wrap to support any dimensional input

parent 18a2c3bf
......@@ -116,12 +116,17 @@ class Covariance:
return Exponentiated(self, other)
raise ValueError("A covariance function can only be exponentiated by a scalar value")
def __array_wrap__(self, result):
"""
Required to allow radd/rmul by numpy arrays.
"""
result = np.squeeze(result)
if len(result.shape) <= 1:
result = result.reshape(1, 1)
elif len(result.shape) > 2:
raise ValueError(f"cannot combine a covariance function with array of shape {result.shape}")
r, c = result.shape
A = np.zeros((r, c))
for i in range(r):
......
......@@ -165,6 +165,11 @@ class TestCovAdd:
K_true = theano.function([], cov_true(X))()
assert np.allclose(K, K_true)
def test_inv_rightadd(self):
M = np.random.randn(2, 2, 2)
with pytest.raises(ValueError, match=r"cannot combine"):
cov = M + pm.gp.cov.ExpQuad(1, 1.)
class TestCovProd:
def test_symprod_cov(self):
......@@ -237,6 +242,11 @@ class TestCovProd:
npt.assert_allclose(np.diag(K1), K2d, atol=1e-5)
npt.assert_allclose(np.diag(K2), K1d, atol=1e-5)
def test_inv_rightprod(self):
M = np.random.randn(2, 2, 2)
with pytest.raises(ValueError, match=r"cannot combine"):
cov = M + pm.gp.cov.ExpQuad(1, 1.)
class TestCovExponentiation:
def test_symexp_cov(self):
X = np.linspace(0, 1, 10)[:, None]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment