Commit b8a67fea authored by Robert P. Goldman

Fix forward reference error.

Local class declared after reference.
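
For context, the failure mode being fixed can be reproduced with a minimal sketch (hypothetical names, not the code in this commit): a class defined inside a function body is an ordinary local binding, so instantiating it before its class statement has executed raises an UnboundLocalError (a kind of NameError) when the function is called. Defining the class before the first reference resolves it.

def build_rows_buggy():
    # BUG: _Row is a local name bound by the class statement below,
    # so referencing it here fails with UnboundLocalError at call time.
    rows = _Row()

    class _Row(dict):
        pass

    return rows


def build_rows_fixed():
    # FIX: declare the local class before the first reference to it.
    class _Row(dict):
        pass

    rows = _Row()
    return rows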
parent ab227bd4
@@ -171,79 +171,83 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
### makes the shape issues just a little easier to deal with.
model = modelcontext(model)
assert model is not None
with model:
if keep_size and samples is not None:
raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments")
if keep_size and not isinstance(trace, MultiTrace):
# arguably this should be just a warning.
raise IncorrectArgumentsError("keep_size argument only applies when sampling from MultiTrace.")
if isinstance(trace, list) and all((isinstance(x, dict) for x in trace)):
_trace = _TraceDict(point_list=trace)
elif isinstance(trace, MultiTrace):
_trace = _TraceDict(multi_trace=trace)
else:
raise TypeError("Unable to generate posterior predictive samples from argument of type %s"%type(trace))
len_trace = len(_trace)
assert isinstance(_trace, _TraceDict)
_samples = [] # type: List[int]
# temporary replacement for more complicated logic.
max_samples: int = len_trace
if samples is None or samples == max_samples:
_samples = [max_samples]
elif samples < max_samples:
warnings.warn("samples parameter is smaller than nchains times ndraws, some draws "
"and/or chains may not be represented in the returned posterior "
"predictive sample")
# if this is less than the number of samples in the trace, take a slice and
# work with that.
_trace = _trace[slice(samples)]
_samples = [samples]
elif samples > max_samples:
full, rem = divmod(samples, max_samples)
_samples = (full * [max_samples]) + ([rem] if rem != 0 else [])
else:
raise IncorrectArgumentsError("Unexpected combination of samples (%s) and max_samples (%d)"%(samples, max_samples))
if keep_size and samples is not None:
raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments")
if keep_size and not isinstance(trace, MultiTrace):
# arguably this should be just a warning.
raise IncorrectArgumentsError("keep_size argument only applies when sampling from MultiTrace.")
if var_names is None:
vars = model.observed_RVs
else:
vars = [model[x] for x in var_names]
if isinstance(trace, list) and all((isinstance(x, dict) for x in trace)):
_trace = _TraceDict(point_list=trace)
elif isinstance(trace, MultiTrace):
_trace = _TraceDict(multi_trace=trace)
else:
raise TypeError("Unable to generate posterior predictive samples from argument of type %s"%type(trace))
len_trace = len(_trace)
assert isinstance(_trace, _TraceDict)
_samples = [] # type: List[int]
# temporary replacement for more complicated logic.
max_samples: int = len_trace
if samples is None or samples == max_samples:
_samples = [max_samples]
elif samples < max_samples:
warnings.warn("samples parameter is smaller than nchains times ndraws, some draws "
"and/or chains may not be represented in the returned posterior "
"predictive sample")
# if this is less than the number of samples in the trace, take a slice and
# work with that.
_trace = _trace[slice(samples)]
_samples = [samples]
elif samples > max_samples:
full, rem = divmod(samples, max_samples)
_samples = (full * [max_samples]) + ([rem] if rem != 0 else [])
else:
raise IncorrectArgumentsError("Unexpected combination of samples (%s) and max_samples (%d)"%(samples, max_samples))
if random_seed is not None:
np.random.seed(random_seed)
if var_names is None:
vars = model.observed_RVs
else:
vars = [model[x] for x in var_names]
if TYPE_CHECKING:
_ETPParent = UserDict[str, np.ndarray] # this is only processed by mypy
else:
_ETPParent = UserDict # this is not seen by mypy but will be executed at runtime.
if random_seed is not None:
np.random.seed(random_seed)
class _ExtendableTrace(_ETPParent):
def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:
for k, v in trace.items():
if k in self.data:
self.data[k] = np.concatenate((self.data[k], v))
else:
self.data[k] = v
ppc_trace = _ExtendableTrace()
for s in _samples:
strace = _trace if s == len_trace else _trace[slice(0, s)]
try:
values = posterior_predictive_draw_values(cast(List[Any], vars), strace, s)
new_trace = {k.name: v for (k, v) in zip(vars, values)} # type: Dict[str, np.ndarray]
ppc_trace.extend_trace(new_trace)
except KeyboardInterrupt:
pass
if keep_size:
assert isinstance(trace, MultiTrace)
return {k: ary.reshape((trace.nchains, len(trace), *ary.shape[1:])) for k, ary in ppc_trace.items() }
else:
return ppc_trace.data # this gets us a Dict[str, np.ndarray] instead of my wrapped equiv.
if TYPE_CHECKING:
_ETPParent = UserDict[str, np.ndarray] # this is only processed by mypy
else:
_ETPParent = UserDict # this is not seen by mypy but will be executed at runtime.
ppc_trace = _ExtendableTrace()
for s in _samples:
strace = _trace if s == len_trace else _trace[slice(0, s)]
try:
values = posterior_predictive_draw_values(cast(List[Any], vars), strace, s)
new_trace = {k.name: v for (k, v) in zip(vars, values)} # type: Dict[str, np.ndarray]
ppc_trace.extend_trace(new_trace)
except KeyboardInterrupt:
pass
if keep_size:
assert isinstance(trace, MultiTrace)
return {k: ary.reshape((trace.nchains, len(trace), *ary.shape[1:])) for k, ary in ppc_trace.items() }
else:
return ppc_trace.data # this gets us a Dict[str, np.ndarray] instead of my wrapped equiv.
class _ExtendableTrace(_ETPParent):
def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:
for k, v in trace.items():
if k in self.data:
self.data[k] = np.concatenate((self.data[k], v))
else:
self.data[k] = v
def posterior_predictive_draw_values(vars: List[Any], trace: _TraceDict, samples: int) -> List[np.ndarray]:
with _PosteriorPredictiveSampler(vars, trace, samples, None) as sampler:
@@ -404,6 +408,7 @@ class _PosteriorPredictiveSampler(AbstractContextManager):
# initialization phase
context = _DrawValuesContext.get_context()
assert isinstance(context, _DrawValuesContext)
with context:
drawn = context.drawn_vars
evaluated = {} # type: Dict[int, Any]
@@ -501,7 +506,8 @@ class _PosteriorPredictiveSampler(AbstractContextManager):
if hasattr(param, 'model') and trace and param.name in trace.varnames:
return trace[param.name]
elif hasattr(param, 'random') and param.random is not None:
model = modelcontext(None)
model = modelcontext(None)
assert isinstance(model, Model)
shape = tuple(_param_shape(param, model)) # type: Tuple[int, ...]
return random_sample(param.random, param, point=trace, size=samples, shape=shape)
elif (hasattr(param, 'distribution') and
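
As an aside, the _samples bookkeeping visible in the first hunk splits a request for more draws than the trace contains into full passes over the trace plus one remainder batch. A standalone sketch of that arithmetic, using made-up numbers rather than values from the commit:

# Hypothetical numbers: ask for 2500 posterior predictive draws from a
# trace holding 1000 samples.
samples, max_samples = 2500, 1000
full, rem = divmod(samples, max_samples)                       # full = 2, rem = 500
batches = (full * [max_samples]) + ([rem] if rem != 0 else [])
assert batches == [1000, 1000, 500]  # two full passes plus one partial pass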