diff --git a/pymc3/gp/util.py b/pymc3/gp/util.py index 675679d8aff07529b9188adf50dcfb8ccbd1d06b..499b9729fdb072cda95a37f4f154b0f375f6c3df 100644 --- a/pymc3/gp/util.py +++ b/pymc3/gp/util.py @@ -15,6 +15,7 @@ from scipy.cluster.vq import kmeans import numpy as np import theano.tensor as tt +import warnings cholesky = tt.slinalg.cholesky solve_lower = tt.slinalg.Solve(A_structure='lower_triangular') @@ -83,17 +84,19 @@ def conditioned_vars(varnames): return gp_wrapper -def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds", fill_alpha=0.8, samples_alpha=0.1, fill_kwargs=None, samples_kwargs=None): +def plot_gp_dist(ax, samples:np.ndarray, x:np.ndarray, plot_samples=True, palette="Reds", fill_alpha=0.8, samples_alpha=0.1, fill_kwargs=None, samples_kwargs=None): """ A helper function for plotting 1D GP posteriors from trace Parameters ---------- ax: axes Matplotlib axes. - samples: trace or list of traces - Trace(s) or posterior predictive sample from a GP. - x: array + samples: numpy.ndarray + Array of S posterior predictive sample from a GP. + Expected shape: (S, X) + x: numpy.ndarray Grid of X values corresponding to the samples. + Expected shape: (X,) or (X, 1), or (1, X) plot_samples: bool Plot the GP samples along with posterior (defaults True). palette: str @@ -118,6 +121,12 @@ def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds", fill_alpha=0 fill_kwargs = {} if samples_kwargs is None: samples_kwargs = {} + if np.any(np.isnan(samples)): + warnings.warn( + 'There are `nan` entries in the [samples] arguments. ' + 'The plot will not contain a band!', + UserWarning + ) cmap = plt.get_cmap(palette) percs = np.linspace(51, 99, 40) diff --git a/pymc3/tests/test_gp.py b/pymc3/tests/test_gp.py index 93f44b0c01028b618d58e23e10ba73a14f3037e1..75c2bdbc57a8b6e906feaa9f7649a89d767d6468 100644 --- a/pymc3/tests/test_gp.py +++ b/pymc3/tests/test_gp.py @@ -247,6 +247,7 @@ class TestCovProd: with pytest.raises(ValueError, match=r"cannot combine"): cov = M + pm.gp.cov.ExpQuad(1, 1.) + class TestCovExponentiation: def test_symexp_cov(self): X = np.linspace(0, 1, 10)[:, None] @@ -539,6 +540,7 @@ class TestMatern12: Kd = theano.function([],cov(X, diag=True))() npt.assert_allclose(np.diag(K), Kd, atol=1e-5) + class TestCosine: def test_1d(self): X = np.linspace(0, 1, 10)[:, None] @@ -1142,3 +1144,36 @@ class TestMarginalKron: cov_funcs=self.cov_funcs) with pytest.raises(TypeError): gp1 + gp2 + + +class TestUtil: + def test_plot_gp_dist(self): + """Test that the plotting helper works with the stated input shapes.""" + import matplotlib.pyplot as plt + X = 100 + S = 500 + fig, ax = plt.subplots() + pm.gp.util.plot_gp_dist( + ax, + x=np.linspace(0, 50, X), + samples=np.random.normal(np.arange(X), size=(S, X)) + ) + plt.close() + pass + + def test_plot_gp_dist_warn_nan(self): + """Test that the plotting helper works with the stated input shapes.""" + import matplotlib.pyplot as plt + X = 100 + S = 500 + samples = np.random.normal(np.arange(X), size=(S, X)) + samples[15, 3] = np.nan + fig, ax = plt.subplots() + with pytest.warns(UserWarning): + pm.gp.util.plot_gp_dist( + ax, + x=np.linspace(0, 50, X), + samples=samples + ) + plt.close() + pass