Unverified Commit 40d9597a authored by Ahan M R, committed by GitHub

Improve documentation for distributions (#3837)

* partially fixes #3688

* fixes #3688: sd deprecated in favor of sigma
parent e2979a34
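The user-facing effect of the changes below: every distribution that used to take an sd keyword still accepts it, but copies it into sigma and emits a DeprecationWarning. A minimal sketch of that behavior (the variable names and values here are illustrative, not part of the commit):

    import warnings
    import pymc3 as pm

    with pm.Model():
        # Preferred spelling after this change: sigma.
        x = pm.Normal("x", mu=0, sigma=1)

        # The old sd keyword still works, but now emits a DeprecationWarning.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            y = pm.Normal("y", mu=0, sd=1)
        assert any(issubclass(w.category, DeprecationWarning) for w in caught)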
......@@ -476,6 +476,10 @@ class Normal(Continuous):
def __init__(self, mu=0, sigma=None, tau=None, sd=None, **kwargs):
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
self.sigma = self.sd = tt.as_tensor_variable(sigma)
self.tau = tt.as_tensor_variable(tau)
......@@ -636,6 +640,10 @@ class TruncatedNormal(BoundedContinuous):
transform='auto', sd=None, *args, **kwargs):
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
self.sigma = self.sd = tt.as_tensor_variable(sigma)
self.tau = tt.as_tensor_variable(tau)
......@@ -839,7 +847,10 @@ class HalfNormal(PositiveContinuous):
def __init__(self, sigma=None, tau=None, sd=None, *args, **kwargs):
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
super().__init__(*args, **kwargs)
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
......@@ -1232,6 +1243,10 @@ class Beta(UnitContinuous):
super().__init__(*args, **kwargs)
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
alpha, beta = self.get_alpha_beta(alpha, beta, mu, sigma)
self.alpha = alpha = tt.as_tensor_variable(floatX(alpha))
self.beta = beta = tt.as_tensor_variable(floatX(beta))
......@@ -1788,6 +1803,10 @@ class Lognormal(PositiveContinuous):
super().__init__(*args, **kwargs)
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
......@@ -1959,6 +1978,10 @@ class StudentT(Continuous):
super(StudentT, self).__init__(*args, **kwargs)
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
self.nu = nu = tt.as_tensor_variable(floatX(nu))
lam, sigma = get_tau_sigma(tau=lam, sigma=sigma)
self.lam = lam = tt.as_tensor_variable(lam)
......@@ -2519,6 +2542,10 @@ class Gamma(PositiveContinuous):
super().__init__(*args, **kwargs)
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
alpha, beta = self.get_alpha_beta(alpha, beta, mu, sigma)
self.alpha = alpha = tt.as_tensor_variable(floatX(alpha))
......@@ -2677,6 +2704,10 @@ class InverseGamma(PositiveContinuous):
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
alpha, beta = InverseGamma._get_alpha_beta(alpha, beta, mu, sigma)
self.alpha = alpha = tt.as_tensor_variable(floatX(alpha))
......@@ -3032,6 +3063,10 @@ class HalfStudentT(PositiveContinuous):
super().__init__(*args, **kwargs)
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
self.mode = tt.as_tensor_variable(0)
lam, sigma = get_tau_sigma(lam, sigma)
......@@ -3172,6 +3207,10 @@ class ExGaussian(Continuous):
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
self.mu = mu = tt.as_tensor_variable(floatX(mu))
self.sigma = self.sd = sigma = tt.as_tensor_variable(floatX(sigma))
......@@ -3456,6 +3495,10 @@ class SkewNormal(Continuous):
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
self.mu = mu = tt.as_tensor_variable(floatX(mu))
......@@ -3877,6 +3920,10 @@ class Rice(PositiveContinuous):
super().__init__(*args, **kwargs)
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
nu, b, sigma = self.get_nu_b(nu, b, sigma)
self.nu = nu = tt.as_tensor_variable(floatX(nu))
......@@ -4148,6 +4195,10 @@ class LogitNormal(UnitContinuous):
def __init__(self, mu=0, sigma=None, tau=None, sd=None, **kwargs):
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
self.mu = mu = tt.as_tensor_variable(floatX(mu))
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
self.sigma = self.sd = tt.as_tensor_variable(sigma)
......
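Each of the continuous distributions above (Normal, TruncatedNormal, HalfNormal, Beta, Lognormal, StudentT, Gamma, InverseGamma, HalfStudentT, ExGaussian, SkewNormal, Rice, LogitNormal) gets the same shim: sd is remapped to sigma and a DeprecationWarning is raised. A hedged sketch of how that could be exercised in a test; the test name and the choice of HalfNormal are illustrative:

    import pytest
    import pymc3 as pm

    def test_sd_keyword_is_deprecated():
        with pm.Model():
            # sd= is remapped to sigma and should trigger the new warning.
            with pytest.warns(DeprecationWarning):
                pm.HalfNormal("x", sd=1.0)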
......@@ -16,6 +16,7 @@ from collections.abc import Iterable
import numpy as np
import theano
import theano.tensor as tt
import warnings
from pymc3.util import get_variable_name
from ..math import logsumexp
......@@ -610,6 +611,10 @@ class NormalMixture(Mixture):
def __init__(self, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), *args, **kwargs):
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
self.mu = mu = tt.as_tensor_variable(mu)
......
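NormalMixture gains the same shim, and mixture.py now imports warnings to support it. A small usage sketch with the sigma spelling; the component means and scales are made up for illustration:

    import numpy as np
    import pymc3 as pm

    with pm.Model():
        w = pm.Dirichlet("w", a=np.ones(2))
        # sigma, not sd, is now the documented name for the component scales.
        mix = pm.NormalMixture("mix", w=w,
                               mu=np.array([-1.0, 1.0]),
                               sigma=np.array([0.5, 0.5]))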
......@@ -27,7 +27,7 @@ from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
from theano.tensor.slinalg import Cholesky
import pymc3 as pm
from pymc3.theanof import floatX
from pymc3.theanof import floatX, intX
from . import transforms
from pymc3.util import get_variable_name
from .distribution import (Continuous, Discrete, draw_values, generate_samples,
......@@ -327,7 +327,7 @@ class MvNormal(_QuadFormBase):
TensorVariable
"""
quaddist, logdet, ok = self._quaddist(value)
k = value.shape[-1].astype(theano.config.floatX)
k = intX(value.shape[-1]).astype(theano.config.floatX)
norm = - 0.5 * k * pm.floatX(np.log(2 * np.pi))
return bound(norm - 0.5 * quaddist - logdet, ok)
......@@ -441,7 +441,7 @@ class MvStudentT(_QuadFormBase):
TensorVariable
"""
quaddist, logdet, ok = self._quaddist(value)
k = value.shape[-1].astype(theano.config.floatX)
k = intX(value.shape[-1]).astype(theano.config.floatX)
norm = (gammaln((self.nu + k) / 2.)
- gammaln(self.nu / 2.)
......
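In MvNormal.logp and MvStudentT.logp, k is the event dimension taken from value.shape[-1]; the change routes it through intX before casting to floatX so the symbolic shape gets an explicit integer dtype first. A rough numeric check of the -0.5 * k * log(2*pi) normalization term that uses k; the dimension and covariance are chosen purely for illustration:

    import numpy as np
    import pymc3 as pm

    k = 3
    dist = pm.MvNormal.dist(mu=np.zeros(k), cov=np.eye(k), shape=k)
    # At value == mu with an identity covariance, only the norm term remains.
    expected = -0.5 * k * np.log(2 * np.pi)
    np.testing.assert_allclose(dist.logp(np.zeros(k)).eval(), expected)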
......@@ -123,6 +123,10 @@ class AR(distribution.Continuous):
super().__init__(*args, **kwargs)
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
self.sigma = self.sd = tt.as_tensor_variable(sigma)
......@@ -210,6 +214,10 @@ class GaussianRandomWalk(distribution.Continuous):
raise TypeError("GaussianRandomWalk must be supplied a non-zero shape argument!")
if sd is not None:
sigma = sd
warnings.warn(
"sd is deprecated, use sigma instead",
DeprecationWarning
)
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
self.tau = tt.as_tensor_variable(tau)
sigma = tt.as_tensor_variable(sigma)
......
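The same shim is applied to the timeseries distributions AR and GaussianRandomWalk. A short usage sketch with the sigma spelling; the concrete sigma and shape values are illustrative:

    import pymc3 as pm

    with pm.Model():
        # GaussianRandomWalk requires a non-zero shape; sigma replaces sd here.
        walk = pm.GaussianRandomWalk("walk", sigma=0.5, shape=100)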