Unverified Commit ad590746 authored by Michael Osthege's avatar Michael Osthege Committed by GitHub

fix & specify type and shapes for plot_gp_dist (#3913)

* fix & specify type and shapes for plot_gp_dist

* warn user about nan samples
closes #3917

* test that UserWarning is triggered when some samples are nan
parent 18f1e513
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from scipy.cluster.vq import kmeans from scipy.cluster.vq import kmeans
import numpy as np import numpy as np
import theano.tensor as tt import theano.tensor as tt
import warnings
cholesky = tt.slinalg.cholesky cholesky = tt.slinalg.cholesky
solve_lower = tt.slinalg.Solve(A_structure='lower_triangular') solve_lower = tt.slinalg.Solve(A_structure='lower_triangular')
...@@ -83,17 +84,19 @@ def conditioned_vars(varnames): ...@@ -83,17 +84,19 @@ def conditioned_vars(varnames):
return gp_wrapper return gp_wrapper
def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds", fill_alpha=0.8, samples_alpha=0.1, fill_kwargs=None, samples_kwargs=None): def plot_gp_dist(ax, samples:np.ndarray, x:np.ndarray, plot_samples=True, palette="Reds", fill_alpha=0.8, samples_alpha=0.1, fill_kwargs=None, samples_kwargs=None):
""" A helper function for plotting 1D GP posteriors from trace """ A helper function for plotting 1D GP posteriors from trace
Parameters Parameters
---------- ----------
ax: axes ax: axes
Matplotlib axes. Matplotlib axes.
samples: trace or list of traces samples: numpy.ndarray
Trace(s) or posterior predictive sample from a GP. Array of S posterior predictive sample from a GP.
x: array Expected shape: (S, X)
x: numpy.ndarray
Grid of X values corresponding to the samples. Grid of X values corresponding to the samples.
Expected shape: (X,) or (X, 1), or (1, X)
plot_samples: bool plot_samples: bool
Plot the GP samples along with posterior (defaults True). Plot the GP samples along with posterior (defaults True).
palette: str palette: str
...@@ -118,6 +121,12 @@ def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds", fill_alpha=0 ...@@ -118,6 +121,12 @@ def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds", fill_alpha=0
fill_kwargs = {} fill_kwargs = {}
if samples_kwargs is None: if samples_kwargs is None:
samples_kwargs = {} samples_kwargs = {}
if np.any(np.isnan(samples)):
warnings.warn(
'There are `nan` entries in the [samples] arguments. '
'The plot will not contain a band!',
UserWarning
)
cmap = plt.get_cmap(palette) cmap = plt.get_cmap(palette)
percs = np.linspace(51, 99, 40) percs = np.linspace(51, 99, 40)
......
...@@ -247,6 +247,7 @@ class TestCovProd: ...@@ -247,6 +247,7 @@ class TestCovProd:
with pytest.raises(ValueError, match=r"cannot combine"): with pytest.raises(ValueError, match=r"cannot combine"):
cov = M + pm.gp.cov.ExpQuad(1, 1.) cov = M + pm.gp.cov.ExpQuad(1, 1.)
class TestCovExponentiation: class TestCovExponentiation:
def test_symexp_cov(self): def test_symexp_cov(self):
X = np.linspace(0, 1, 10)[:, None] X = np.linspace(0, 1, 10)[:, None]
...@@ -539,6 +540,7 @@ class TestMatern12: ...@@ -539,6 +540,7 @@ class TestMatern12:
Kd = theano.function([],cov(X, diag=True))() Kd = theano.function([],cov(X, diag=True))()
npt.assert_allclose(np.diag(K), Kd, atol=1e-5) npt.assert_allclose(np.diag(K), Kd, atol=1e-5)
class TestCosine: class TestCosine:
def test_1d(self): def test_1d(self):
X = np.linspace(0, 1, 10)[:, None] X = np.linspace(0, 1, 10)[:, None]
...@@ -1142,3 +1144,36 @@ class TestMarginalKron: ...@@ -1142,3 +1144,36 @@ class TestMarginalKron:
cov_funcs=self.cov_funcs) cov_funcs=self.cov_funcs)
with pytest.raises(TypeError): with pytest.raises(TypeError):
gp1 + gp2 gp1 + gp2
class TestUtil:
def test_plot_gp_dist(self):
"""Test that the plotting helper works with the stated input shapes."""
import matplotlib.pyplot as plt
X = 100
S = 500
fig, ax = plt.subplots()
pm.gp.util.plot_gp_dist(
ax,
x=np.linspace(0, 50, X),
samples=np.random.normal(np.arange(X), size=(S, X))
)
plt.close()
pass
def test_plot_gp_dist_warn_nan(self):
"""Test that the plotting helper works with the stated input shapes."""
import matplotlib.pyplot as plt
X = 100
S = 500
samples = np.random.normal(np.arange(X), size=(S, X))
samples[15, 3] = np.nan
fig, ax = plt.subplots()
with pytest.warns(UserWarning):
pm.gp.util.plot_gp_dist(
ax,
x=np.linspace(0, 50, X),
samples=samples
)
plt.close()
pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment