Unverified Commit ad590746 authored by Michael Osthege's avatar Michael Osthege Committed by GitHub

fix & specify type and shapes for plot_gp_dist (#3913)

* fix & specify type and shapes for plot_gp_dist

* warn user about nan samples
closes #3917

* test that UserWarning is triggered when some samples are nan
parent 18f1e513
......@@ -15,6 +15,7 @@
from scipy.cluster.vq import kmeans
import numpy as np
import theano.tensor as tt
import warnings
cholesky = tt.slinalg.cholesky
solve_lower = tt.slinalg.Solve(A_structure='lower_triangular')
......@@ -83,17 +84,19 @@ def conditioned_vars(varnames):
return gp_wrapper
def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds", fill_alpha=0.8, samples_alpha=0.1, fill_kwargs=None, samples_kwargs=None):
def plot_gp_dist(ax, samples:np.ndarray, x:np.ndarray, plot_samples=True, palette="Reds", fill_alpha=0.8, samples_alpha=0.1, fill_kwargs=None, samples_kwargs=None):
""" A helper function for plotting 1D GP posteriors from trace
Parameters
----------
ax: axes
Matplotlib axes.
samples: trace or list of traces
Trace(s) or posterior predictive sample from a GP.
x: array
samples: numpy.ndarray
Array of S posterior predictive sample from a GP.
Expected shape: (S, X)
x: numpy.ndarray
Grid of X values corresponding to the samples.
Expected shape: (X,) or (X, 1), or (1, X)
plot_samples: bool
Plot the GP samples along with posterior (defaults True).
palette: str
......@@ -118,6 +121,12 @@ def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds", fill_alpha=0
fill_kwargs = {}
if samples_kwargs is None:
samples_kwargs = {}
if np.any(np.isnan(samples)):
warnings.warn(
'There are `nan` entries in the [samples] arguments. '
'The plot will not contain a band!',
UserWarning
)
cmap = plt.get_cmap(palette)
percs = np.linspace(51, 99, 40)
......
......@@ -247,6 +247,7 @@ class TestCovProd:
with pytest.raises(ValueError, match=r"cannot combine"):
cov = M + pm.gp.cov.ExpQuad(1, 1.)
class TestCovExponentiation:
def test_symexp_cov(self):
X = np.linspace(0, 1, 10)[:, None]
......@@ -539,6 +540,7 @@ class TestMatern12:
Kd = theano.function([],cov(X, diag=True))()
npt.assert_allclose(np.diag(K), Kd, atol=1e-5)
class TestCosine:
def test_1d(self):
X = np.linspace(0, 1, 10)[:, None]
......@@ -1142,3 +1144,36 @@ class TestMarginalKron:
cov_funcs=self.cov_funcs)
with pytest.raises(TypeError):
gp1 + gp2
class TestUtil:
def test_plot_gp_dist(self):
"""Test that the plotting helper works with the stated input shapes."""
import matplotlib.pyplot as plt
X = 100
S = 500
fig, ax = plt.subplots()
pm.gp.util.plot_gp_dist(
ax,
x=np.linspace(0, 50, X),
samples=np.random.normal(np.arange(X), size=(S, X))
)
plt.close()
pass
def test_plot_gp_dist_warn_nan(self):
"""Test that the plotting helper works with the stated input shapes."""
import matplotlib.pyplot as plt
X = 100
S = 500
samples = np.random.normal(np.arange(X), size=(S, X))
samples[15, 3] = np.nan
fig, ax = plt.subplots()
with pytest.warns(UserWarning):
pm.gp.util.plot_gp_dist(
ax,
x=np.linspace(0, 50, X),
samples=samples
)
plt.close()
pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment