Unverified commit 76f3a7b7, authored by Adrian Seyboldt and committed by GitHub

Allow specification of dims instead of shape (#3551)

* Allow specification of dims instead of shape

* Add pm.TidyData

* Create coords for pm.Data(ndarray)

* empty commit to trigger CI

* Apply suggestions from code review
Co-authored-by: Alexandre ANDORRA <andorra.alexandre@gmail.com>

* apply black formatting

* address review comments & formatting

* Add demonstration of named coordinates/dims

* don't require dim names to be identifiers

* sort imports

* raise ShapeError instead of ValueError

* formatting

* robustify Dtype and ShapeError

* Removed TidyData and refined dims and coords implementation

* Changed name of kwarg export_dims and improved docstrings

* Add link to ArviZ in docstrings

* Removed TidyData from __all__

* Polished Data container NB

* Fixed line break in data.py

* Fix inference of coords for dataframes

* Refined Data container NB

* Updated getting started NB with new dims and coords features

* Reran getting started NB

* Blackified NBs

* rerun with ArviZ branch

* use np.shape to be compatible with tuples/lists

* add tests for named coordinate handling

* Extended tests for data container
Co-authored-by: Michael Osthege <m.osthege@fz-juelich.de>
Co-authored-by: Michael Osthege <michael.osthege@outlook.com>
Co-authored-by: Alexandre ANDORRA <andorra.alexandre@gmail.com>
parent 8a8beab6
......@@ -17,6 +17,7 @@
- `pm.LKJCholeskyCov` now automatically computes and returns the unpacked Cholesky decomposition, the correlations and the standard deviations of the covariance matrix (see [#3881](https://github.com/pymc-devs/pymc3/pull/3881)).
- `pm.Data` container can now be used for index variables, i.e. with integer data and not only floats (issue [#3813](https://github.com/pymc-devs/pymc3/issues/3813), fixed by [#3925](https://github.com/pymc-devs/pymc3/pull/3925)).
- `pm.Data` container can now be used as input for other random variables (issue [#3842](https://github.com/pymc-devs/pymc3/issues/3842), fixed by [#3925](https://github.com/pymc-devs/pymc3/pull/3925)).
- Allow users to specify coordinates and dimension names instead of numerical shapes when specifying a model. This makes interoperability with ArviZ easier. ([see #3551](https://github.com/pymc-devs/pymc3/pull/3551))
- Plots and Stats API sections now link to ArviZ documentation [#3927](https://github.com/pymc-devs/pymc3/pull/3927)
- Add `SamplerReport` with properties `n_draws`, `t_sampling` and `n_tune` to SMC. `n_tune` is always 0 [#3931](https://github.com/pymc-devs/pymc3/issues/3931).
- SMC-ABC: add option to define summary statistics, allow sampling from more complex models, remove redundant distances [#3940](https://github.com/pymc-devs/pymc3/issues/3940)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -56,17 +56,32 @@ class Distribution:
"a 'with model:' block, or use the '.dist' syntax "
"for a standalone distribution.")
if isinstance(name, string_types):
data = kwargs.pop('observed', None)
cls.data = data
if isinstance(data, ObservedRV) or isinstance(data, FreeRV):
raise TypeError("observed needs to be data but got: {}".format(type(data)))
total_size = kwargs.pop('total_size', None)
dist = cls.dist(*args, **kwargs)
return model.Var(name, dist, data, total_size)
else:
if not isinstance(name, string_types):
raise TypeError("Name needs to be a string but got: {}".format(name))
data = kwargs.pop('observed', None)
cls.data = data
if isinstance(data, ObservedRV) or isinstance(data, FreeRV):
raise TypeError("observed needs to be data but got: {}".format(type(data)))
total_size = kwargs.pop('total_size', None)
dims = kwargs.pop('dims', None)
has_shape = 'shape' in kwargs
shape = kwargs.pop('shape', None)
if dims is not None:
if shape is not None:
raise ValueError("Specify only one of 'dims' or 'shape'")
if isinstance(dims, string_types):
dims = (dims,)
shape = model.shape_from_dims(dims)
# Some distributions do not accept shape=None
if has_shape or shape is not None:
dist = cls.dist(*args, **kwargs, shape=shape)
else:
dist = cls.dist(*args, **kwargs)
return model.Var(name, dist, data, total_size, dims=dims)
def __getnewargs__(self):
    """Return the one-element tuple ``(_Unpickling,)``.

    Per the pickle protocol, the tuple returned here is what gets passed
    to ``__new__`` when the object is re-created.
    """
    return (_Unpickling,)
......@@ -77,7 +92,7 @@ class Distribution:
return dist
def __init__(self, shape, dtype, testval=None, defaults=(),
transform=None, broadcastable=None):
transform=None, broadcastable=None, dims=None):
self.shape = np.atleast_1d(shape)
if False in (np.floor(self.shape) == self.shape):
raise TypeError("Expected int elements in shape")
......@@ -467,8 +482,10 @@ class DensityDist(Distribution):
)
return samples
else:
raise ValueError("Distribution was not passed any random method "
"Define a custom random method and pass it as kwarg random")
raise ValueError(
"Distribution was not passed any random method. "
"Define a custom random method and pass it as kwarg random"
)
class _DrawValuesContext(metaclass=ContextMeta, context_class='_DrawValuesContext'):
......
......@@ -44,8 +44,12 @@ class ImputationWarning(UserWarning):
class ShapeError(Exception):
    """Error that the shape of a variable is incorrect.

    Parameters
    ----------
    message : str
        Human-readable description of the problem.
    actual : optional
        The shape that was encountered. Appended to the message when it is
        not ``None``.
    expected : optional
        The shape (or a description of the shapes) that would have been
        accepted. Appended to the message when it is not ``None``.
    """

    def __init__(self, message, actual=None, expected=None):
        # Compare against None instead of relying on truthiness: an empty
        # shape ``()`` is falsy but is still a value worth reporting.
        if actual is not None and expected is not None:
            text = '{} (actual {} != expected {})'.format(message, actual, expected)
        elif actual is not None:
            text = '{} (actual {})'.format(message, actual)
        elif expected is not None:
            text = '{} (expected {})'.format(message, expected)
        else:
            text = message
        super().__init__(text)
......@@ -53,7 +57,11 @@ class ShapeError(Exception):
class DtypeError(TypeError):
    """Error that the dtype of a variable is incorrect.

    Parameters
    ----------
    message : str
        Human-readable description of the problem.
    actual : optional
        The dtype that was encountered. Appended to the message when it is
        not ``None``.
    expected : optional
        The dtype (or a description of the dtypes) that would have been
        accepted. Appended to the message when it is not ``None``.
    """

    def __init__(self, message, actual=None, expected=None):
        # Compare against None instead of relying on truthiness so that
        # falsy-but-present values are still included in the message.
        if actual is not None and expected is not None:
            text = '{} (actual {} != expected {})'.format(message, actual, expected)
        elif actual is not None:
            text = '{} (actual {})'.format(message, actual)
        elif expected is not None:
            text = '{} (expected {})'.format(message, expected)
        else:
            text = message
        super().__init__(text)
This diff is collapsed.
......@@ -15,6 +15,7 @@
import pymc3 as pm
from .helpers import SeededTest
import numpy as np
import pandas as pd
import pytest
......@@ -195,6 +196,54 @@ class TestData(SeededTest):
text = 'obs [label="obs ~ Normal" style=filled]'
assert text in g.source
def test_explicit_coords(self):
    """Coords given to ``pm.Model`` are kept and tied to a ``pm.Data`` variable through ``dims``."""
    n_rows, n_cols = 5, 7
    values = np.random.uniform(size=(n_rows, n_cols))
    row_labels = [f"R{r+1}" for r in range(n_rows)]
    column_labels = [f"C{c+1}" for c in range(n_cols)]

    # pass coordinates explicitly, use numpy array in Data container
    with pm.Model(coords={"rows": row_labels, "columns": column_labels}) as pmodel:
        pm.Data('observations', values, dims=("rows", "columns"))

    # both coordinate axes must be registered with their labels unchanged
    assert "rows" in pmodel.coords
    assert pmodel.coords["rows"] == ['R1', 'R2', 'R3', 'R4', 'R5']
    assert "columns" in pmodel.coords
    assert pmodel.coords["columns"] == ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']
    # the data variable must be mapped to its dimension names
    assert pmodel.RV_dims == {'observations': ('rows', 'columns')}
def test_implicit_coords_series(self):
    """With ``export_index_as_coords=True`` a Series index shows up in the model coords."""
    n_days = 22
    sales_counts = np.random.randint(low=0, high=30, size=n_days)
    date_index = pd.date_range(start="2020-05-01", periods=n_days, freq="24H", name="date")
    ser_sales = pd.Series(data=sales_counts, index=date_index, name="sales")

    with pm.Model() as pmodel:
        pm.Data("sales", ser_sales, dims="date", export_index_as_coords=True)

    assert "date" in pmodel.coords
    assert len(pmodel.coords["date"]) == n_days
    # a single string passed as ``dims`` is recorded as a one-element tuple
    assert pmodel.RV_dims == {'sales': ('date',)}
def test_implicit_coords_dataframe(self):
    """With ``export_index_as_coords=True`` both DataFrame axes show up in the model coords."""
    n_rows, n_cols = 5, 7

    # build the frame column by column, one random draw per column
    df_data = pd.DataFrame()
    for col_number in range(1, n_cols + 1):
        df_data[f'Column {col_number}'] = np.random.normal(size=(n_rows,))
    df_data.index.name = 'rows'
    df_data.columns.name = 'columns'

    # infer coordinates from index and columns of the DataFrame
    with pm.Model() as pmodel:
        pm.Data('observations', df_data, dims=("rows", "columns"), export_index_as_coords=True)

    for dim_name in ("rows", "columns"):
        assert dim_name in pmodel.coords
    assert pmodel.RV_dims == {'observations': ('rows', 'columns')}
def test_data_naming():
"""
......
......@@ -100,25 +100,37 @@ class TestUpdateStartVals(SeededTest):
class TestExceptions:
    """Checks for the message formatting of ``ShapeError`` and ``DtypeError``."""

    def test_shape_error(self):
        """``ShapeError`` can be raised and embeds whichever of actual/expected is given."""
        # message only, raised from a pre-built instance
        err = pm.exceptions.ShapeError('Without shapes.')
        with pytest.raises(pm.exceptions.ShapeError):
            raise err

        # both shapes passed positionally
        err = pm.exceptions.ShapeError('With shapes.', (4,3), (5,3))
        assert 'actual (4, 3)' in err.args[0]
        assert 'expected (5, 3)' in err.args[0]
        with pytest.raises(pm.exceptions.ShapeError):
            raise err

        with pytest.raises(pm.exceptions.ShapeError) as exinfo:
            raise pm.exceptions.ShapeError('Just the message.')
        assert 'Just' in exinfo.value.args[0]

        # only one of the two details supplied
        with pytest.raises(pm.exceptions.ShapeError) as exinfo:
            raise pm.exceptions.ShapeError('With shapes.', actual=(2,3))
        assert '(2, 3)' in exinfo.value.args[0]

        with pytest.raises(pm.exceptions.ShapeError) as exinfo:
            raise pm.exceptions.ShapeError('With shapes.', expected='(2,3) or (5,6)')
        assert '(5,6)' in exinfo.value.args[0]

        # an empty (falsy) actual shape must not suppress the expected part
        with pytest.raises(pm.exceptions.ShapeError) as exinfo:
            raise pm.exceptions.ShapeError('With shapes.', actual=(), expected='(5,4) or (?,?,6)')
        assert '(?,?,6)' in exinfo.value.args[0]

    def test_dtype_error(self):
        """``DtypeError`` can be raised and embeds whichever of actual/expected is given."""
        # message only, raised from a pre-built instance
        err = pm.exceptions.DtypeError('Without dtypes.')
        with pytest.raises(pm.exceptions.DtypeError):
            raise err

        # both dtypes passed positionally
        err = pm.exceptions.DtypeError('With shapes.', np.float64, np.float32)
        assert 'float64' in err.args[0]
        assert 'float32' in err.args[0]
        with pytest.raises(pm.exceptions.DtypeError):
            raise err

        with pytest.raises(pm.exceptions.DtypeError) as exinfo:
            raise pm.exceptions.DtypeError('Just the message.')
        assert 'Just' in exinfo.value.args[0]

        # only one of the two details supplied
        with pytest.raises(pm.exceptions.DtypeError) as exinfo:
            raise pm.exceptions.DtypeError('With types.', actual=str)
        assert 'str' in exinfo.value.args[0]

        with pytest.raises(pm.exceptions.DtypeError) as exinfo:
            raise pm.exceptions.DtypeError('With types.', expected=float)
        assert 'float' in exinfo.value.args[0]

        # both details supplied by keyword
        with pytest.raises(pm.exceptions.DtypeError) as exinfo:
            raise pm.exceptions.DtypeError('With types.', actual=int, expected=str)
        assert 'int' in exinfo.value.args[0] and 'str' in exinfo.value.args[0]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment