Commit e2a90c03 authored by Oriol (ZBook)'s avatar Oriol (ZBook) Committed by Thomas Wiecki

use from_pymc3(..., log_likelihood=False) and update requirements

parent 7301027c
......@@ -87,7 +87,7 @@ class SamplerReport:
def t_sampling(self) -> typing.Optional[float]:
"""
Number of seconds that the sampling procedure took.
(Includes parallelization overhead.)
"""
return self._t_sampling
......@@ -108,6 +108,7 @@ class SamplerReport:
return
from pymc3 import rhat, ess
from arviz import from_pymc3
valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
varnames = []
......@@ -119,8 +120,9 @@ class SamplerReport:
if rv_name in trace.varnames:
varnames.append(rv_name)
self._ess = ess = ess(trace, var_names=varnames)
self._rhat = rhat = rhat(trace, var_names=varnames)
idata = from_pymc3(trace, log_likelihood=False)
self._ess = ess = ess(idata, var_names=varnames)
self._rhat = rhat = rhat(idata, var_names=varnames)
warnings = []
rhat_max = max(val.max() for val in rhat.values())
......
arviz>=0.4.1
arviz>=0.7.0
theano>=1.0.4
numpy>=1.13.0
scipy>=0.18.1
......@@ -7,4 +7,4 @@ patsy>=0.5.1
fastprogress>=0.2.0
h5py>=2.7.0
typing-extensions>=3.7.4
contextvars; python_version < '3.7'
contextvars; python_version < '3.7'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment