Unverified Commit dc574b7a authored by Osvaldo Martin's avatar Osvaldo Martin Committed by GitHub

improve ABC sampler (#3940)

* Expand ABC features.

* valueerror

* update notebook

* remove unused import update release notes

* fix notebook style and change order params argument
parent c6bba803
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
- `pm.Data` container can now be used as input for other random variables (issue [#3842](https://github.com/pymc-devs/pymc3/issues/3842), fixed by [#3925](https://github.com/pymc-devs/pymc3/pull/3925)). - `pm.Data` container can now be used as input for other random variables (issue [#3842](https://github.com/pymc-devs/pymc3/issues/3842), fixed by [#3925](https://github.com/pymc-devs/pymc3/pull/3925)).
- Plots and Stats API sections now link to ArviZ documentation [#3927](https://github.com/pymc-devs/pymc3/pull/3927) - Plots and Stats API sections now link to ArviZ documentation [#3927](https://github.com/pymc-devs/pymc3/pull/3927)
- Add `SamplerReport` with properties `n_draws`, `t_sampling` and `n_tune` to SMC. `n_tune` is always 0 [#3931](https://github.com/pymc-devs/pymc3/issues/3931). - Add `SamplerReport` with properties `n_draws`, `t_sampling` and `n_tune` to SMC. `n_tune` is always 0 [#3931](https://github.com/pymc-devs/pymc3/issues/3931).
- SMC-ABC: add option to define summary statistics, allow sampling from more complex models, remove redundant distances [#3940](https://github.com/pymc-devs/pymc3/issues/3940)
### Maintenance ### Maintenance
- Tuning results no longer leak into sequentially sampled `Metropolis` chains (see #3733 and #3796). - Tuning results no longer leak into sequentially sampled `Metropolis` chains (see #3733 and #3796).
......
...@@ -19,17 +19,20 @@ __all__ = ["Simulator"] ...@@ -19,17 +19,20 @@ __all__ = ["Simulator"]
class Simulator(NoDistribution): class Simulator(NoDistribution):
def __init__(self, function, *args, **kwargs): def __init__(self, function, *args, params=None, **kwargs):
""" """
This class stores a function defined by the user in python language. This class stores a function defined by the user in python language.
function: function function: function
Simulation function defined by the user. Simulation function defined by the user.
params: list
Parameters passed to function.
*args and **kwargs: *args and **kwargs:
Arguments and keywords arguments that the function takes. Arguments and keywords arguments that the function takes.
""" """
self.function = function self.function = function
self.params = params
observed = self.data observed = self.data
super().__init__(shape=np.prod(observed.shape), dtype=observed.dtype, *args, **kwargs) super().__init__(shape=np.prod(observed.shape), dtype=observed.dtype, *args, **kwargs)
......
...@@ -28,8 +28,8 @@ def sample_smc( ...@@ -28,8 +28,8 @@ def sample_smc(
p_acc_rate=0.99, p_acc_rate=0.99,
threshold=0.5, threshold=0.5,
epsilon=1.0, epsilon=1.0,
dist_func="absolute_error", dist_func="gaussian_kernel",
sum_stat=False, sum_stat="identity",
progressbar=False, progressbar=False,
model=None, model=None,
random_seed=-1, random_seed=-1,
...@@ -71,11 +71,10 @@ def sample_smc( ...@@ -71,11 +71,10 @@ def sample_smc(
epsilon: float epsilon: float
Standard deviation of the gaussian pseudo likelihood. Only works with `kernel = ABC` Standard deviation of the gaussian pseudo likelihood. Only works with `kernel = ABC`
dist_func: str dist_func: str
Distance function. Available options are ``absolute_error`` (default) and Distance function. The only available option is ``gaussian_kernel``
``sum_of_squared_distance``. Only works with ``kernel = ABC`` sum_stat: str or callable
sum_stat: bool Summary statistics. Available options are ``identity``, ``sorted``, ``mean``, ``median``.
Whether to use or not a summary statistics. Defaults to False. Only works with If a callable is passed it should return a number or a 1d numpy array.
``kernel = ABC``
progressbar: bool progressbar: bool
Flag for displaying a progress bar. Defaults to False. Flag for displaying a progress bar. Defaults to False.
model: Model (optional if in ``with`` context)). model: Model (optional if in ``with`` context)).
......
...@@ -31,7 +31,6 @@ from ..step_methods.arraystep import metrop_select ...@@ -31,7 +31,6 @@ from ..step_methods.arraystep import metrop_select
from ..step_methods.metropolis import MultivariateNormalProposal from ..step_methods.metropolis import MultivariateNormalProposal
from ..backends.ndarray import NDArray from ..backends.ndarray import NDArray
from ..backends.base import MultiTrace from ..backends.base import MultiTrace
from ..util import is_transformed_name
EXPERIMENTAL_WARNING = ( EXPERIMENTAL_WARNING = (
"Warning: SMC-ABC methods are experimental step methods and not yet" "Warning: SMC-ABC methods are experimental step methods and not yet"
...@@ -53,7 +52,7 @@ class SMC: ...@@ -53,7 +52,7 @@ class SMC:
threshold=0.5, threshold=0.5,
epsilon=1.0, epsilon=1.0,
dist_func="absolute_error", dist_func="absolute_error",
sum_stat=False, sum_stat="Identity",
progressbar=False, progressbar=False,
model=None, model=None,
random_seed=-1, random_seed=-1,
...@@ -140,6 +139,7 @@ class SMC: ...@@ -140,6 +139,7 @@ class SMC:
self.epsilon, self.epsilon,
simulator.observations, simulator.observations,
simulator.distribution.function, simulator.distribution.function,
[v.name for v in simulator.distribution.params],
self.model, self.model,
self.var_info, self.var_info,
self.variables, self.variables,
...@@ -281,7 +281,7 @@ class SMC: ...@@ -281,7 +281,7 @@ class SMC:
self.priors[draw], self.priors[draw],
self.likelihoods[draw], self.likelihoods[draw],
draw, draw,
*parameters *parameters,
) )
for draw in iterator for draw in iterator
] ]
...@@ -307,7 +307,7 @@ class SMC: ...@@ -307,7 +307,7 @@ class SMC:
size = 0 size = 0
for var in varnames: for var in varnames:
shape, new_size = self.var_info[var] shape, new_size = self.var_info[var]
value.append(self.posterior[i][size: size + new_size].reshape(shape)) value.append(self.posterior[i][size : size + new_size].reshape(shape))
size += new_size size += new_size
strace.record({k: v for k, v in zip(varnames, value)}) strace.record({k: v for k, v in zip(varnames, value)})
return MultiTrace([strace]) return MultiTrace([strace])
...@@ -389,7 +389,16 @@ class PseudoLikelihood: ...@@ -389,7 +389,16 @@ class PseudoLikelihood:
""" """
def __init__( def __init__(
self, epsilon, observations, function, model, var_info, variables, distance, sum_stat self,
epsilon,
observations,
function,
params,
model,
var_info,
variables,
distance,
sum_stat,
): ):
""" """
epsilon: float epsilon: float
...@@ -398,34 +407,48 @@ class PseudoLikelihood: ...@@ -398,34 +407,48 @@ class PseudoLikelihood:
observed data observed data
function: python function function: python function
data simulator data simulator
params: list
names of the variables parameterizing the simulator.
model: PyMC3 model model: PyMC3 model
var_info: dict var_info: dict
generated by ``SMC.initialize_population`` generated by ``SMC.initialize_population``
distance: str distance : str or callable
Distance function. Available options are ``absolute_error`` (default) and Distance function. The only available option is ``gaussian_kernel``
``sum_of_squared_distance``. sum_stat: str or callable
sum_stat: bool Summary statistics. Available options are ``identity``, ``sorted``, ``mean``,
Whether to use or not a summary statistics. ``median``. The user can pass any valid Python function
""" """
self.epsilon = epsilon self.epsilon = epsilon
self.observations = observations
self.function = function self.function = function
self.params = params
self.model = model self.model = model
self.var_info = var_info self.var_info = var_info
self.variables = variables self.variables = variables
self.varnames = [v.name for v in self.variables] self.varnames = [v.name for v in self.variables]
self.unobserved_RVs = [v.name for v in self.model.unobserved_RVs] self.unobserved_RVs = [v.name for v in self.model.unobserved_RVs]
self.kernel = self.gauss_kernel
self.dist_func = distance
self.sum_stat = sum_stat
self.get_unobserved_fn = self.model.fastfn(self.model.unobserved_RVs) self.get_unobserved_fn = self.model.fastfn(self.model.unobserved_RVs)
if distance == "absolute_error": if sum_stat == "identity":
self.dist_func = self.absolute_error self.sum_stat = lambda x: x
elif distance == "sum_of_squared_distance": elif sum_stat == "sorted":
self.dist_func = self.sum_of_squared_distance self.sum_stat = np.sort
elif sum_stat == "mean":
self.sum_stat = np.mean
elif sum_stat == "median":
self.sum_stat = np.median
elif hasattr(sum_stat, "__call__"):
self.sum_stat = sum_stat
else:
raise ValueError(f"The summary statistics {sum_stat} is not implemented")
self.observations = self.sum_stat(observations)
if distance == "gaussian_kernel":
self.distance = self.gaussian_kernel
elif hasattr(distance, "__call__"):
self.distance = distance
else: else:
raise ValueError("Distance metric not understood") raise ValueError(f"The distance metric {distance} is not implemented")
def posterior_to_function(self, posterior): def posterior_to_function(self, posterior):
model = self.model model = self.model
...@@ -436,32 +459,18 @@ class PseudoLikelihood: ...@@ -436,32 +459,18 @@ class PseudoLikelihood:
size = 0 size = 0
for var in self.variables: for var in self.variables:
shape, new_size = var_info[var.name] shape, new_size = var_info[var.name]
varvalues.append(posterior[size: size + new_size].reshape(shape)) varvalues.append(posterior[size : size + new_size].reshape(shape))
size += new_size size += new_size
point = {k: v for k, v in zip(self.varnames, varvalues)} point = {k: v for k, v in zip(self.varnames, varvalues)}
for varname, value in zip(self.unobserved_RVs, self.get_unobserved_fn(point)): for varname, value in zip(self.unobserved_RVs, self.get_unobserved_fn(point)):
if not is_transformed_name(varname): if varname in self.params:
samples[varname] = value samples[varname] = value
return samples return samples
def gauss_kernel(self, value): def gaussian_kernel(self, obs_data, sim_data):
epsilon = self.epsilon return np.sum(-0.5 * ((obs_data - sim_data) / self.epsilon) ** 2)
return (-(value ** 2) / epsilon ** 2 + np.log(1 / (2 * np.pi * epsilon ** 2))) / 2.0
def absolute_error(self, a, b):
if self.sum_stat:
return np.abs(a.mean() - b.mean())
else:
return np.mean(np.atleast_2d(np.abs(a - b)))
def sum_of_squared_distance(self, a, b):
if self.sum_stat:
return np.sum(np.atleast_2d((a.mean() - b.mean()) ** 2))
else:
return np.mean(np.sum(np.atleast_2d((a - b) ** 2)))
def __call__(self, posterior): def __call__(self, posterior):
func_parameters = self.posterior_to_function(posterior) func_parameters = self.posterior_to_function(posterior)
sim_data = self.function(**func_parameters) sim_data = self.sum_stat(self.function(**func_parameters))
value = self.dist_func(self.observations, sim_data) return self.distance(self.observations, sim_data)
return self.kernel(value)
...@@ -98,19 +98,19 @@ class TestSMC(SeededTest): ...@@ -98,19 +98,19 @@ class TestSMC(SeededTest):
class TestSMCABC(SeededTest): class TestSMCABC(SeededTest):
def setup_class(self): def setup_class(self):
super().setup_class() super().setup_class()
self.data = np.sort(np.random.normal(loc=0, scale=1, size=1000)) self.data = np.random.normal(loc=0, scale=1, size=1000)
def normal_sim(a, b): def normal_sim(a, b):
return np.sort(np.random.normal(a, b, 1000)) return np.random.normal(a, b, 1000)
with pm.Model() as self.SMABC_test: with pm.Model() as self.SMABC_test:
a = pm.Normal("a", mu=0, sd=5) a = pm.Normal("a", mu=0, sd=5)
b = pm.HalfNormal("b", sd=2) b = pm.HalfNormal("b", sd=2)
s = pm.Simulator("s", normal_sim, observed=self.data) s = pm.Simulator("s", normal_sim, params=(a, b), observed=self.data)
def test_one_gaussian(self): def test_one_gaussian(self):
with self.SMABC_test: with self.SMABC_test:
trace = pm.sample_smc(draws=2000, kernel="ABC", epsilon=0.1) trace = pm.sample_smc(draws=1000, kernel="ABC", sum_stat="sorted", epsilon=1)
np.testing.assert_almost_equal(self.data.mean(), trace["a"].mean(), decimal=2) np.testing.assert_almost_equal(self.data.mean(), trace["a"].mean(), decimal=2)
np.testing.assert_almost_equal(self.data.std(), trace["b"].mean(), decimal=1) np.testing.assert_almost_equal(self.data.std(), trace["b"].mean(), decimal=1)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment