util.py 4.98 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright 2020 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

15
from scipy.cluster.vq import kmeans
Bill Engels's avatar
Bill Engels committed
16
import numpy as np
Bill's avatar
Bill committed
17
import theano.tensor as tt
18
import warnings
Bill's avatar
Bill committed
19

gBokiau's avatar
gBokiau committed
20
# Module-level aliases for theano linear-algebra ops.
cholesky = tt.slinalg.cholesky
# ``Solve`` ops specialised on the structure of the ``A`` matrix being solved.
solve_lower = tt.slinalg.Solve(A_structure='lower_triangular')
solve_upper = tt.slinalg.Solve(A_structure='upper_triangular')
solve = tt.slinalg.Solve(A_structure='general')


def infer_shape(X, n_points=None):
    """Return the number of points (the size of the first dimension) of ``X``.

    Parameters
    ----------
    X: object with a ``shape`` attribute
        Input whose first dimension is the number of points.  Only
        inspected when ``n_points`` is None.
    n_points: int, optional
        If given, it is returned unchanged.

    Returns
    -------
    int
        ``n_points`` if provided, otherwise ``int(X.shape[0])``.

    Raises
    ------
    TypeError
        If ``X.shape[0]`` cannot be converted to an int (e.g. it is None);
        pass ``n_points`` explicitly in that case.
    """
    if n_points is None:
        try:
            # ``np.int`` was merely an alias of the builtin ``int`` and was
            # removed in NumPy 1.24, so use ``int`` directly.
            n_points = int(X.shape[0])
        except TypeError as err:
            raise TypeError("Cannot infer 'shape', provide as an argument") from err
    return n_points


def stabilize(K):
    """Return ``K`` with a small jitter added along its diagonal.

    The jitter is ``1e-6`` times an identity matrix built to match ``K``
    via ``tt.identity_like``.
    """
    jitter = 1e-6
    eye = tt.identity_like(K)
    return K + jitter * eye


def kmeans_inducing_points(n_inducing, X):
    """Choose inducing-point locations by running K-means on ``X``.

    Each column of ``X`` is scaled to unit standard deviation before
    clustering, and the resulting centroids are mapped back to the original
    scale.

    Parameters
    ----------
    n_inducing: int
        Number of centroids requested from ``scipy.cluster.vq.kmeans``.
    X: theano TensorConstant, numpy.ndarray, tuple or list
        Input locations; each row is treated as one observation by
        ``kmeans``.

    Returns
    -------
    numpy.ndarray
        Centroids in the original (unscaled) coordinates.  ``kmeans`` drops
        centroids that end up with no members, so fewer than ``n_inducing``
        rows may be returned.

    Raises
    ------
    TypeError
        If ``X`` is not one of the supported types.
    """
    if isinstance(X, (np.ndarray, tuple, list)):
        X = np.asarray(X)
    elif isinstance(X, tt.TensorConstant):
        X = X.value
    else:
        raise TypeError(("To use K-means initialization, "
                         "please provide X as a type that "
                         "can be cast to np.ndarray, instead "
                         "of {}".format(type(X))))
    # Whiten X: divide every column by its standard deviation.  Columns whose
    # std is very small (zero) are left unscaled to avoid dividing by zero.
    # ``np.where`` is used instead of masked in-place assignment because
    # ``np.std`` returns a scalar (which cannot be item-assigned) for 1-D X.
    scaling = np.std(X, 0)
    scaling = np.where(scaling <= 1e-6, 1.0, scaling)
    Xw = X / scaling
    Xu, _ = kmeans(Xw, n_inducing)
    return Xu * scaling


def conditioned_vars(varnames):
    """Class decorator adding validated properties for conditioned-on attrs.

    For every ``name`` in ``varnames`` the decorated class gets a property
    ``name`` whose value is stored in the attribute ``_name``.  Reading the
    property while ``_name`` is missing or None raises ``AttributeError``.

    Parameters
    ----------
    varnames: iterable of str
        Public attribute names to turn into properties.

    Returns
    -------
    callable
        A decorator that installs the properties on a class and returns
        that same class.
    """
    def gp_wrapper(cls):
        def make_getter(name):
            def getter(self):
                value = getattr(self, name, None)
                if value is None:
                    raise AttributeError(("'{}' not set.  Provide as argument "
                                          "to condition, or call 'prior' "
                                          "first".format(name.lstrip("_"))))
                return value
            return getter

        def make_setter(name):
            def setter(self, val):
                setattr(self, name, val)
            return setter

        for name in varnames:
            # The factory functions bind ``name`` per iteration; closures
            # defined directly in this loop would all see the last name.
            getter = make_getter('_' + name)
            setter = make_setter('_' + name)
            setattr(cls, name, property(getter, setter))
        return cls
    return gp_wrapper


def plot_gp_dist(ax, samples: np.ndarray, x: np.ndarray, plot_samples=True,
                 palette="Reds", fill_alpha=0.8, samples_alpha=0.1,
                 fill_kwargs=None, samples_kwargs=None, num_samples=30):
    """A helper function for plotting 1D GP posteriors from trace.

    Parameters
    ----------
    ax: axes
        Matplotlib axes.
    samples: numpy.ndarray
        Array of S posterior predictive samples from a GP.
        Expected shape: (S, X)
    x: numpy.ndarray
        Grid of X values corresponding to the samples.
        Expected shape: (X,) or (X, 1), or (1, X)
    plot_samples: bool
        Plot the GP samples along with posterior (defaults True).
    palette: str
        Palette for coloring output (defaults to "Reds").
    fill_alpha: float
        Alpha value for the posterior interval fill (defaults to 0.8).
    samples_alpha: float
        Alpha value for the sample lines (defaults to 0.1).
    fill_kwargs: dict
        Additional arguments for posterior interval fill (fill_between).
        Entries override the defaults set here (``color``, ``alpha``).
    samples_kwargs: dict
        Additional keyword arguments for samples plot.  Entries override
        the defaults set here (``color``, ``lw``, ``alpha``).
    num_samples: int
        Number of sample paths drawn (with replacement) and plotted when
        ``plot_samples`` is True (defaults to 30).

    Returns
    -------
    ax: Matplotlib axes
    """
    import matplotlib.pyplot as plt

    if fill_kwargs is None:
        fill_kwargs = {}
    if samples_kwargs is None:
        samples_kwargs = {}
    if np.any(np.isnan(samples)):
        warnings.warn(
            'There are `nan` entries in the [samples] arguments. '
            'The plot will not contain a band!',
            UserWarning
        )

    cmap = plt.get_cmap(palette)
    # Upper percentiles 51..99; each band spans [100 - p, p].
    percs = np.linspace(51, 99, 40)
    # Map each percentile onto [0, 1] to pick a colormap value.
    colors = (percs - np.min(percs)) / (np.max(percs) - np.min(percs))
    samples = samples.T  # (S, X) -> (X, S): percentiles are taken over samples
    x = x.flatten()
    # Widest band (p=99) is drawn first; each narrower band is painted on
    # top of it using a higher colormap value.
    for i, p in enumerate(percs[::-1]):
        upper = np.percentile(samples, p, axis=1)
        lower = np.percentile(samples, 100 - p, axis=1)
        color_val = colors[i]
        # Merge instead of passing both explicitly, so user-supplied kwargs
        # override the defaults rather than raising a duplicate-kwarg error.
        band_kwargs = {"color": cmap(color_val), "alpha": fill_alpha, **fill_kwargs}
        ax.fill_between(x, upper, lower, **band_kwargs)
    if plot_samples:
        # plot a few samples, chosen at random with replacement
        idx = np.random.randint(0, samples.shape[1], num_samples)
        line_kwargs = {"color": cmap(0.9), "lw": 1, "alpha": samples_alpha,
                       **samples_kwargs}
        ax.plot(x, samples[:, idx], **line_kwargs)

    return ax