Unverified Commit effe5c4a authored by Tirth Patel's avatar Tirth Patel Committed by GitHub

small changes to speed up find_MAP (#3916)

parent ccc78cbc
...@@ -140,11 +140,11 @@ def find_MAP( ...@@ -140,11 +140,11 @@ def find_MAP(
def grad_logp(point): def grad_logp(point):
return nan_to_num(-dlogp_func(point)) return nan_to_num(-dlogp_func(point))
opt_result = fmin(cost_func, bij.map(start), fprime=grad_logp, *args, **kwargs) opt_result = fmin(cost_func, x0, fprime=grad_logp, *args, **kwargs)
else: else:
# Check to see if minimization function uses a starting value # Check to see if minimization function uses a starting value
if "x0" in getargspec(fmin).args: if "x0" in getargspec(fmin).args:
opt_result = fmin(cost_func, bij.map(start), *args, **kwargs) opt_result = fmin(cost_func, x0, *args, **kwargs)
else: else:
opt_result = fmin(cost_func, *args, **kwargs) opt_result = fmin(cost_func, *args, **kwargs)
...@@ -174,6 +174,7 @@ def find_MAP( ...@@ -174,6 +174,7 @@ def find_MAP(
assert isinstance(cost_func.progress, ProgressBar) assert isinstance(cost_func.progress, ProgressBar)
cost_func.progress.total = last_v cost_func.progress.total = last_v
cost_func.progress.update(last_v) cost_func.progress.update(last_v)
print()
vars = get_default_varnames(model.unobserved_RVs, include_transformed) vars = get_default_varnames(model.unobserved_RVs, include_transformed)
mx = {var.name: value for var, value in zip(vars, model.fastfn(vars)(bij.rmap(mx0)))} mx = {var.name: value for var, value in zip(vars, model.fastfn(vars)(bij.rmap(mx0)))}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment