Unverified Commit effe5c4a authored by Tirth Patel's avatar Tirth Patel Committed by GitHub

small changes to speed up find_MAP (#3916)

parent ccc78cbc
......@@ -140,11 +140,11 @@ def find_MAP(
def grad_logp(point):
return nan_to_num(-dlogp_func(point))
opt_result = fmin(cost_func, bij.map(start), fprime=grad_logp, *args, **kwargs)
opt_result = fmin(cost_func, x0, fprime=grad_logp, *args, **kwargs)
else:
# Check to see if minimization function uses a starting value
if "x0" in getargspec(fmin).args:
opt_result = fmin(cost_func, bij.map(start), *args, **kwargs)
opt_result = fmin(cost_func, x0, *args, **kwargs)
else:
opt_result = fmin(cost_func, *args, **kwargs)
......@@ -174,6 +174,7 @@ def find_MAP(
assert isinstance(cost_func.progress, ProgressBar)
cost_func.progress.total = last_v
cost_func.progress.update(last_v)
print()
vars = get_default_varnames(model.unobserved_RVs, include_transformed)
mx = {var.name: value for var, value in zip(vars, model.fastfn(vars)(bij.rmap(mx0)))}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment