Commit 747db639 authored by Osvaldo Martin's avatar Osvaldo Martin Committed by GitHub

SMC: refactor, speed-up and run multiple chains in parallel for diagnostics (#3981)

* first attempt to vectorize smc kernel

* add ess, remove multiprocessing

* run multiple chains

* remove unused imports

* add more info to report

* minor fix

* test log

* fix type_num error

* remove unused imports update BF notebook

* update notebook with diagnostics

* update notebooks

* update notebook

* update notebook
parent facbdf14
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -14,25 +14,37 @@
import time
import logging
import warnings
from collections.abc import Iterable
import multiprocessing as mp
import numpy as np
from .smc import SMC
from ..model import modelcontext
from ..backends.base import MultiTrace
from ..parallel_sampling import _cpu_count
EXPERIMENTAL_WARNING = (
"Warning: SMC-ABC is an experimental step method and not yet recommended for use in PyMC3!"
)
def sample_smc(
draws=1000,
draws=2000,
kernel="metropolis",
n_steps=25,
parallel=False,
start=None,
cores=None,
tune_steps=True,
p_acc_rate=0.99,
threshold=0.5,
epsilon=1.0,
dist_func="gaussian_kernel",
sum_stat="identity",
progressbar=False,
model=None,
random_seed=-1,
parallel=False,
chains=None,
cores=None,
):
r"""
Sequential Monte Carlo based sampling
......@@ -49,15 +61,9 @@ def sample_smc(
The number of steps of each Markov Chain. If ``tune_steps == True`` ``n_steps`` will be used
for the first stage and for the others it will be determined automatically based on the
acceptance rate and `p_acc_rate`, the max number of steps is ``n_steps``.
parallel: bool
Distribute computations across cores if the number of cores is larger than 1.
Defaults to False.
start: dict, or array of dict
Starting point in parameter space. It should be a list of dict with length `chains`.
When None (default) the starting point is sampled from the prior distribution.
cores: int
The number of chains to run in parallel. If ``None`` (default), it will be automatically
set to the number of CPUs in the system.
tune_steps: bool
Whether to compute the number of steps automatically or not. Defaults to True
p_acc_rate: float
......@@ -75,11 +81,19 @@ def sample_smc(
sum_stat: str or callable
Summary statistics. Available options are ``indentity``, ``sorted``, ``mean``, ``median``.
If a callable is based it should return a number or a 1d numpy array.
progressbar: bool
Flag for displaying a progress bar. Defaults to False.
model: Model (optional if in ``with`` context)).
random_seed: int
random seed
parallel: bool
Distribute computations across cores if the number of cores is larger than 1.
Defaults to False.
cores : int
The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
system, but at most 4.
chains : int
The number of chains to sample. Running independent chains is important for some
convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
is larger.
Notes
-----
......@@ -126,52 +140,126 @@ def sample_smc(
%282007%29133:7%28816%29>`__
"""
_log = logging.getLogger("pymc3")
_log.info("Initializing SMC sampler...")
if cores is None:
cores = _cpu_count()
if chains is None:
chains = max(2, cores)
_log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)")
if random_seed == -1:
random_seed = None
if chains == 1 and isinstance(random_seed, int):
random_seed = [random_seed]
if random_seed is None or isinstance(random_seed, int):
if random_seed is not None:
np.random.seed(random_seed)
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
if not isinstance(random_seed, Iterable):
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
if kernel.lower() == "abc":
warnings.warn(EXPERIMENTAL_WARNING)
if len(modelcontext(model).observed_RVs) != 1:
warnings.warn("SMC-ABC only works properly with models with one observed variable")
params = (
draws,
kernel,
n_steps,
start,
tune_steps,
p_acc_rate,
threshold,
epsilon,
dist_func,
sum_stat,
model,
)
t1 = time.time()
if parallel:
loggers = [_log] + [None] * (chains - 1)
pool = mp.Pool(cores)
results = pool.starmap(
sample_smc_int, [(*params, random_seed[i], i, loggers[i]) for i in range(chains)]
)
pool.close()
pool.join()
else:
results = []
for i in range(chains):
results.append((sample_smc_int(*params, random_seed[i], i, _log)))
traces, log_marginal_likelihoods, betas, accept_ratios, nsteps = zip(*results)
trace = MultiTrace(traces)
trace.report._n_draws = draws
trace.report._n_tune = 0
trace.report._t_sampling = time.time() - t1
trace.report.log_marginal_likelihood = np.array(log_marginal_likelihoods)
trace.report.betas = betas
trace.report.accept_ratios = accept_ratios
trace.report.nsteps = nsteps
return trace
def sample_smc_int(
draws,
kernel,
n_steps,
start,
tune_steps,
p_acc_rate,
threshold,
epsilon,
dist_func,
sum_stat,
model,
random_seed,
chain,
_log,
):
smc = SMC(
draws=draws,
kernel=kernel,
n_steps=n_steps,
parallel=parallel,
start=start,
cores=cores,
tune_steps=tune_steps,
p_acc_rate=p_acc_rate,
threshold=threshold,
epsilon=epsilon,
dist_func=dist_func,
sum_stat=sum_stat,
progressbar=progressbar,
model=model,
random_seed=random_seed,
chain=chain,
)
t1 = time.time()
_log = logging.getLogger("pymc3")
_log.info("Sample initial stage: ...")
stage = 0
betas = []
accept_ratios = []
nsteps = []
smc.initialize_population()
smc.setup_kernel()
smc.initialize_logp()
while smc.beta < 1:
smc.update_weights_beta()
_log.info(
"Stage: {:3d} Beta: {:.3f} Steps: {:3d} Acce: {:.3f}".format(
stage, smc.beta, smc.n_steps, smc.acc_rate
)
)
smc.resample()
if _log is not None:
_log.info(f"Stage: {stage:3d} Beta: {smc.beta:.3f}")
smc.update_proposal()
if stage > 0:
smc.tune()
smc.resample()
smc.mutate()
smc.tune()
stage += 1
betas.append(smc.beta)
accept_ratios.append(smc.acc_rate)
nsteps.append(smc.n_steps)
if smc.parallel and smc.cores > 1:
smc.pool.close()
smc.pool.join()
trace = smc.posterior_to_trace()
trace.report._n_draws = smc.draws
trace.report._n_tune = 0
trace.report._t_sampling = time.time() - t1
return trace
return smc.posterior_to_trace(), smc.log_marginal_likelihood, betas, accept_ratios, nsteps
This diff is collapsed.
......@@ -79,9 +79,9 @@ class TestSMC(SeededTest):
a = pm.Beta("a", alpha, beta)
y = pm.Bernoulli("y", a, observed=data)
trace = pm.sample_smc(2000)
marginals.append(model.marginal_log_likelihood)
marginals.append(trace.report.log_marginal_likelihood)
# compare to the analytical result
assert abs(np.exp(marginals[1] - marginals[0]) - 4.0) <= 1
assert abs(np.exp(np.mean(marginals[1]) - np.mean(marginals[0])) - 4.0) <= 1
def test_start(self):
with pm.Model() as model:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment