Abstract the model state so that we can mutate it conveniently.

This rather large change adds a layer between the underlying (immutable) model and the application.
In doing so we can avoid the use of a global state (useful for the purposes of configuring multiple models in the same application later on) and it also unlocks the ability to implement an MVC-like separation of concerns - again, the intention is that when it comes to comparisons, we will just be able to re-use our application views.

I was hoping that ``cara.state`` could have been avoided in lieu of using traitlets, but unfortunately I found a number of limitations which were prohibitive for its use here.
Foremost of which was the lack of first-class dataclass support and the difficulty in needing either to use instances of the model (immutable) or duplicate the model and its structure in a mutable form and use the ``traitlets.Instance`` type.
Instead I opted for doing it myself - the ``cara.state`` module would make a very good standalone project in the future.
This commit is contained in:
Phil Elson 2020-10-21 20:29:17 +02:00
parent 69730bfb2a
commit f63e1d3760
6 changed files with 616 additions and 252 deletions

View file

@ -13,259 +13,22 @@
"</p>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib widget\n",
"import ipywidgets as widgets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import typing"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import cara.models\n",
"\n",
"\n",
"def prepare_model(volume, n_infected=1, n_exposed=10, mask='Type I') -> cara.models.Model:\n",
" \"\"\"\n",
" Transform configurable values into a cara model instance.\n",
" \n",
" \"\"\"\n",
" model = cara.models.Model(\n",
" room=cara.models.Room(volume=volume),\n",
" ventilation=cara.models.PeriodicWindow(period=120, duration=120, inside_temp=293, outside_temp=283,\n",
" window_height=1.6, opening_length=0.6, cd_b=0.6),\n",
" infected=cara.models.InfectedPerson(\n",
" virus=cara.models.Virus.types['SARS_CoV_2'],\n",
" present_times=((0, 4), (5, 8)),\n",
" mask=cara.models.Mask.types[mask],\n",
" activity=cara.models.Activity.types['Light exercise'],\n",
" expiration=cara.models.Expiration.types['Unmodulated Vocalization'],\n",
" ),\n",
" infected_occupants=n_infected,\n",
" exposed_occupants=n_exposed,\n",
" exposed_activity=cara.models.Activity.types['Light exercise'],\n",
" )\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Setup our plotting environment.\n",
"plt.interactive(False)\n",
"fig_concentration_over_time = plt.figure()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Define some useful widget machinery.\n",
"\n",
"def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=True):\n",
" collapsed = widgets.Accordion([widgets.VBox(widgets_to_collapse)])\n",
" collapsed.set_title(0, title)\n",
" if start_collapsed:\n",
" collapsed.selected_index = None\n",
" return collapsed\n",
"\n",
"\n",
"def widget_group(label_widget_pairs):\n",
" labels, widgets_ = zip(*label_widget_pairs) \n",
" labels_w = widgets.VBox(labels)\n",
" widgets_w = widgets.VBox(widgets_)\n",
" return widgets.HBox([labels_w, widgets_w])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb109c0f63e149d69e763aec5d404db2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Accordion(children=(VBox(children=(HBox(children=(VBox(children=(Label(value='Room volume'),)), VBox(children=…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"room_volume = widgets.IntSlider(value=75, min=10, max=150)\n",
"mask_used = widgets.Checkbox(value=True, description='Mask worn')\n",
"\n",
"collapsible(\n",
" [widget_group(\n",
" [[widgets.Label('Room volume'), room_volume]]\n",
" )],\n",
" title='Specification of workplace', start_collapsed=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "49ad604786f546f58dba54f1f6e7eded",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Accordion(children=(VBox(children=(HBox(children=(VBox(children=(Label(value='Ventilation type'),)), VBox(chil…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ventilation_widgets = {\n",
" 'Natural': widgets.Label('Currently hard-coded to window-example from mathematica notebook'),\n",
" 'other': widgets.Label('Not yet implemented.')\n",
"}\n",
"for name, widget in ventilation_widgets.items():\n",
" widget.layout.visible = False\n",
"\n",
"ventilation_w = widgets.ToggleButtons(\n",
" options=['Natural'], # cara.models.Ventilation.types.keys(),\n",
")\n",
"def toggle_ventilation(value):\n",
" for name, widget in ventilation_widgets.items():\n",
" widget.layout.display = 'none'\n",
" other = ventilation_widgets['other']\n",
" widget = ventilation_widgets.get(value, other)\n",
" widget.layout.visible = True\n",
" widget.layout.display = 'block'\n",
"\n",
"ventilation_w.observe(lambda event: toggle_ventilation(event['new']), 'value')\n",
"toggle_ventilation(ventilation_w.value)\n",
"\n",
"\n",
"collapsible(\n",
" [widget_group([[widgets.Label('Ventilation type'), ventilation_w]])]\n",
" + list(ventilation_widgets.values()),\n",
" title='Ventilation scheme'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8e76a49d0212462d81200a3959dcd3ff",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Accordion(children=(VBox(children=(HBox(children=(Canvas(footer_visible=False, header_visible=False, toolbar=T…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"line = None\n",
"# plt.ioff()\n",
"# fig = plt.figure()\n",
"fig = fig_concentration_over_time\n",
"\n",
"def plot_concentrations(_):\n",
" global line\n",
" model = prepare_model(room_volume.value)\n",
"\n",
" ts = np.arange(0, 10., 0.01)\n",
" concentration = [model.concentration(t) for t in ts]\n",
"\n",
" ax = fig.gca()\n",
" \n",
" plt.text(0.5, 0.9, 'Without masks & window open', transform=ax.transAxes, ha='center')\n",
" if line is None:\n",
" ax.spines['right'].set_visible(False)\n",
" ax.spines['top'].set_visible(False)\n",
" [line] = plt.plot(ts, concentration)\n",
" ax.set_xlabel('Time (hours)')\n",
" ax.set_ylabel('Concentration ($q/m^3$)')\n",
" plt.title('Concentration of infectious quanta aerosols')\n",
" \n",
" ax.set_ymargin(0.2)\n",
" ax.set_ylim(bottom=0)\n",
" else:\n",
" line.set_data(ts, concentration)\n",
" ax.relim()\n",
" ax.autoscale_view()\n",
" \n",
" plt.draw()\n",
"\n",
"# print(f'Probability of infection: {np.round(model.[\"P\"], 1)}')\n",
"# print(f'Expected number of new cases: {prepared[\"R0\"]}')\n",
" \n",
"\n",
"# widgets.interact(\n",
"# plot_concentrations,\n",
"# volume=room_volume,\n",
"# n_exposed=widgets.IntSlider(value=10, min=0, max=25),\n",
"# n_infected=widgets.IntSlider(value=1, min=1, max=5),\n",
"# );\n",
"\n",
"\n",
"for observable in [room_volume]:\n",
" observable.observe(plot_concentrations)\n",
"\n",
"plot_concentrations(1)\n",
"\n",
"fig.canvas.toolbar_visible = True\n",
"fig.canvas.toolbar.collapsed = True\n",
"fig.canvas.footer_visible = False\n",
"fig.canvas.header_visible = False\n",
"\n",
"\n",
"collapsible([\n",
" widgets.HBox([\n",
" fig.canvas,\n",
" # text_report,\n",
" ])\n",
"], 'Report', start_collapsed=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
"source": [
"import cara.apps\n",
"\n",
"app = cara.apps.ExpertApplication()\n",
"app.widget"
]
}
],
"metadata": {

202
cara/apps.py Normal file
View file

@ -0,0 +1,202 @@
import typing
import uuid
import ipympl.backend_nbagg
import ipywidgets as widgets
import numpy as np
import matplotlib
import matplotlib.figure
from cara import models
from cara import state
def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=True):
collapsed = widgets.Accordion([widgets.VBox(widgets_to_collapse)])
collapsed.set_title(0, title)
if start_collapsed:
collapsed.selected_index = None
return collapsed
def widget_group(label_widget_pairs):
labels, widgets_ = zip(*label_widget_pairs)
labels_w = widgets.VBox(labels)
widgets_w = widgets.VBox(widgets_)
return widgets.HBox([labels_w, widgets_w])
class ConcentrationFigure:
def __init__(self):
self.figure = matplotlib.figure.Figure(figsize=(9, 6))
self.ax = self.figure.add_subplot(1, 1, 1)
self.line = None
def update(self, model: models.Model):
resolution = 600
ts = np.linspace(0, 10, resolution)
concentration = [model.concentration(t) for t in ts]
if self.line is None:
[self.line] = self.ax.plot(ts, concentration)
ax = self.ax
ax.text(0.5, 0.9, 'Without masks & window open', transform=ax.transAxes, ha='center')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_xlabel('Time (hours)')
ax.set_ylabel('Concentration ($q/m^3$)')
ax.set_title('Concentration of infectious quanta aerosols')
ax.set_ymargin(0.2)
ax.set_ylim(bottom=0)
else:
self.line.set_data(ts, concentration)
self.ax.relim()
self.ax.autoscale_view()
self.figure.canvas.draw()
def ipympl_canvas(figure: matplotlib.figure.Figure):
# Make a plain matplotlib figure render as a Jupyter widget.
matplotlib.interactive(False)
ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure)
figure.canvas.toolbar_visible = True
figure.canvas.toolbar.collapsed = True
figure.canvas.footer_visible = False
figure.canvas.header_visible = False
class WidgetView:
def __init__(self, model_state: state.DataclassState):
self.model_state = model_state
self.model_state.dcs_observe(self.update)
#: The widgets that this view produces (inputs and outputs together)
self.widget = widgets.VBox([])
self.widgets = {}
self.plots = []
self.construct_widgets()
# Trigger the first result.
self.update()
def construct_widgets(self):
# Build the input widgets.
self._build_widget(self.model_state)
# And the output widget figure.
concentration = ConcentrationFigure()
self.plots.append(concentration)
ipympl_canvas(concentration.figure)
self.widgets['results'] = collapsible([
widgets.HBox([
concentration.figure.canvas,
])
], 'Results', start_collapsed=False)
# Join inputs and outputs together in a single widget for convenience.
self.widget.children += (self.widgets['results'], )
def prepare_output(self):
pass
def update(self):
model = self.model_state.dcs_instance()
for plot in self.plots:
plot.update(model)
def _build_widget(self, node):
if isinstance(node, state.DataclassState):
if node._base == models.Ventilation:
self.widget.children += (self._build_ventilation(node), )
elif node._base == models.Room:
self.widget.children += (self._build_room(node), )
else:
# Don't do anything with this state, but recurse down in case
# its children want widgets.
for name, child in node._data.items():
self._build_widget(child)
def _build_room(self, node):
room_volume = widgets.IntSlider(value=node.volume, min=10, max=150)
mask_used = widgets.Checkbox(value=True, description='Mask worn')
def on_value_change(change):
node.volume = change['new']
# TODO: Link the state back to the widget, not just the other way around.
room_volume.observe(on_value_change, names=['value'])
widget = collapsible(
[widget_group(
[[widgets.Label('Room volume'), room_volume]]
)],
title='Specification of workplace', start_collapsed=False,
)
return widget
def _build_ventilation(self, node):
ventilation_widgets = {
'Natural': widgets.Label('Currently hard-coded to window-example from mathematica notebook'),
'other': widgets.Label('Not yet implemented.')
}
for name, widget in ventilation_widgets.items():
widget.layout.visible = False
ventilation_w = widgets.ToggleButtons(
options=ventilation_widgets.keys(),
)
def toggle_ventilation(value):
for name, widget in ventilation_widgets.items():
widget.layout.display = 'none'
other = ventilation_widgets['other']
widget = ventilation_widgets.get(value, other)
widget.layout.visible = True
widget.layout.display = 'block'
ventilation_w.observe(lambda event: toggle_ventilation(event['new']), 'value')
toggle_ventilation(ventilation_w.value)
w = collapsible(
[widget_group([[widgets.Label('Ventilation type'), ventilation_w]])]
+ list(ventilation_widgets.values()),
title='Ventilation scheme'
)
return w
def present(self):
return self.widget
baseline_model = models.Model(
room=models.Room(volume=75),
ventilation=models.PeriodicWindow(
period=120, duration=120, inside_temp=293, outside_temp=283, cd_b=0.6,
window_height=1.6, opening_length=0.6,
),
infected=models.InfectedPerson(
virus=models.Virus.types['SARS_CoV_2'],
present_times=((0, 4), (5, 8)),
mask=models.Mask.types['No mask'],
activity=models.Activity.types['Light exercise'],
expiration=models.Expiration.types['Unmodulated Vocalization'],
),
infected_occupants=1,
exposed_occupants=10,
exposed_activity=models.Activity.types['Light exercise'],
)
class ExpertApplication:
def __init__(self):
self.model_state = state.DataclassState(models.Model)
self.model_state.dcs_update_from(
baseline_model
)
self.view = WidgetView(self.model_state)
# self._widget = widgets.Text("WIP")
@property
def widget(self):
return self.view.present()

229
cara/state.py Normal file
View file

@ -0,0 +1,229 @@
"""
This module is entirely in support of providing a convenient mutable counterpart
to frozen dataclasses. Significant effort went into to trying to use traitlets
for this purpose, but the need to define class-level attributes proved to be a
limitation that meant we could not mutate the state from one subclass to another
after the state was instantiated.
This module MUST not import other parts of cara as this would point at a
leaky abstraction.
"""
from contextlib import contextmanager
import dataclasses
import typing
Datamodel_T = typing.Type
dataclass_instance = typing.Any
class StateBuilder:
def visit(self, field: dataclasses.Field):
builder = self.resolve_builder(field)
return builder(field.type)
def resolve_builder(self, field: dataclasses.Field):
method_name = [
f'build_name_{field.name}',
f'build_type_{field.type.__name__}',
]
for name in method_name:
method = getattr(self, name, None)
if method is not None:
return method
return self.build_generic
def build_generic(self, type_to_build: typing.Type):
return DataclassState(type_to_build, self)
class DataclassState:
"""
Represents the state of a frozen dataclass.
No type checking of the attributes is attempted.
Setting the state can be done with:
setattr(state, attr, value)
Accessing the instance that this state represents can be done with:
state.dcs_instance()
Changing the type to a subclass of the base that this state represents can be done with:
state.dcs_set_instance_type(ASubclassOfBase)
"""
def __init__(self, dataclass: Datamodel_T, state_builder=StateBuilder()):
# Note that the constructor does *not* insert any data by default. It
# therefore doesn't build nested DataclassState instances when a dataclass contains another.
# For that, use the build classmethod.
if not dataclasses.is_dataclass(dataclass):
raise TypeError("The given class is not a valid dataclass")
if not isinstance(dataclass, type):
raise TypeError("A dataclass type must be provided, not an instance of one")
with self._object_setattr():
#: The base instance which this state must support.
self._base = dataclass
#: The actual instance type that this state represents (i.e. may be a
#: subclass of _base).
self._instance_type = dataclass
#: The instance of dataclass which this state represents. Undefined until
#: sufficient data is provided.
self._instance = None
self._data = {}
self._observers: typing.List[callable] = []
self._state_builder = state_builder
self.dcs_set_instance_type(dataclass)
def __repr__(self):
return f"<state for {self._instance_type.__name__}(**{self._instance_state()})>"
def _instance_attrs(self):
return [field.name for field in dataclasses.fields(self._instance_type)]
def dcs_observe(self, callback: typing.Callable):
self._observers.append(callback)
def dcs_update_from(self, data: dataclass_instance):
self.dcs_set_instance_type(data.__class__)
for field in dataclasses.fields(data):
attr = field.name
current_value = self._data.get(attr, None)
new_value = getattr(data, attr)
if dataclasses.is_dataclass(field.type):
assert isinstance(current_value, DataclassState)
current_value.dcs_update_from(new_value)
else:
self._data[attr] = new_value
def _fire_observers(self):
self._instance = None
for observer in self._observers:
observer()
@contextmanager
def _object_setattr(self):
self._use_base_setattr = True
yield
self._use_base_setattr = False
def __getattr__(self, name):
try:
return super().__getattribute__(name)
except AttributeError:
pass
if name in self._data:
return self._data[name]
elif name in self._instance_attrs():
raise ValueError(f"State not yet set for {name}")
else:
raise AttributeError(f"Attribute {name} does not exist on {self._instance_type.__name__}")
def __setattr__(self, name, value):
if name in self.__dict__ or self.__dict__.get('_use_base_setattr', True):
return object.__setattr__(self, name, value)
if name in self._instance_attrs():
self._dcs_set_value(name, value)
def _dcs_set_value(self, attr_name, value):
valid_attrs = self._instance_attrs()
if isinstance(value, DataclassState):
# TODO: We need to check that the value is acceptable. (needs
# thinking about)
# assert value._base ==
pass
# TODO: Inject some notifications here to tell any holding DataclassState
# instances that we've changed.
if attr_name in valid_attrs:
self._data[attr_name] = value
self._fire_observers()
else:
raise AttributeError(f"No attribute {attr_name} on a {self._instance_type.__name__}")
def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
if not dataclasses.is_dataclass(instance_dataclass):
raise TypeError("The given class is not a valid dataclass")
if not issubclass(instance_dataclass, self._base):
raise TypeError(f"The dataclass type provided ({instance_dataclass}) must be a subclass of the base ({self._base})")
self._instance_type = instance_dataclass
self._data.clear()
for field in dataclasses.fields(instance_dataclass):
if dataclasses.is_dataclass(field.type):
self._data[field.name] = self._state_builder.visit(field)
self._data[field.name].dcs_observe(self._fire_observers)
def _instance_state(self):
# Note: this method should not validate that the args are complete or
# overspecified.
kwargs = {}
for name, data in self._data.items():
if isinstance(data, DataclassState):
data = data._instance_state()
kwargs[name] = data
return kwargs
def _instance_kwargs(self):
# Note: this method should not validate that the args are complete or
# overspecified.
kwargs = {}
for name, data in self._data.items():
if isinstance(data, DataclassState):
data = data.dcs_instance()
kwargs[name] = data
return kwargs
def dcs_instance(self):
if self._instance is None:
# TODO: Check if we are able to create an instance with our data...
self._instance = self._instance_type(**self._instance_kwargs())
return self._instance
class DataclassStatePredefined(DataclassState):
"""
Only a pre-defined selection of states for the given type are allowed.
Selected by name (the keys in the dictionary).
You can change the chosen state with:
state.dcs_select(name)
"""
def __init__(self, dataclass: Datamodel_T, choices: typing.Dict[typing.Hashable, dataclass_instance]):
super().__init__(dataclass=dataclass)
with self._object_setattr():
self._choices = choices
self._selected = None
# Pick the first choice until we know otherwise.
self.dcs_select(list(choices.keys())[0])
def dcs_select(self, name: typing.Hashable):
if name not in self._choices:
raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._choices)}')
self._selected = name
self._instance = self._choices[name]
def dcs_instance(self):
return self._choices[self._selected]
def __repr__(self):
return f"<state for {self._instance_type.__name__}. '{self._selected}' selected>"
def _instance_kwargs(self):
raise NotImplementedError("Doesn't make much sense")
def _instance_state(self):
return dataclasses.asdict(self.dcs_instance())
def _instance_kwargs(self):
return dataclasses.asdict(self.dcs_instance())

9
cara/tests/test_apps.py Normal file
View file

@ -0,0 +1,9 @@
import cara.apps
def test_app():
# To start with, let's just test that the application runs. We don't try to
# do anything fancy to verify how it looks etc., we leave that for manual
# testing.
expert_app = cara.apps.ExpertApplication()
assert expert_app.model_state.room.volume == 75

161
cara/tests/test_state.py Normal file
View file

@ -0,0 +1,161 @@
from dataclasses import dataclass
import typing
from unittest.mock import Mock
import pytest
from cara import state
@dataclass
class DCSimple:
attr1: str
attr2: int
@dataclass
class DCSimpleSubclass(DCSimple):
attr3: float
@dataclass
class DCOverrideSubclass(DCSimple):
attr1: float
@dataclass
class DCClassVar(DCSimple):
a_class_var: typing.ClassVar[int]
@dataclass
class DCRecursive(DCSimple):
simple: DCSimple
@dataclass
class DCNested:
simple: DCSimple
others: typing.List[DCSimple]
vanilla: str = 'Default'
@dataclass
class DCNestedDeep:
child: DCNested
@pytest.fixture
def dc_simple():
return DCSimple
def test_DCS_construct():
s = state.DataclassState(DCSimple)
assert repr(s) == '<state for DCSimple(**{})>'
with pytest.raises(TypeError, match=r"A dataclass type must be provided, not an instance of one"):
state.DataclassState(DCSimple('', 1))
with pytest.raises(TypeError, match="The given class is not a valid dataclass"):
state.DataclassState(None)
def test_DCS_construct_nested():
s = state.DataclassState(DCNested)
assert repr(s) == "<state for DCNested(**{'simple': {}})>"
@pytest.mark.xfail
def test_DCS_subclass():
s = state.DataclassState(DCSimple)
s.dcs_set_instance_type(DCSimpleSubclass)
s.set('attr3', 3.14)
assert s._instance_kwargs() == {'attr3': 3.14}
s.dcs_set_instance_type(DCSimple)
# TODO: Make this fail.
assert s._instance_kwargs() == {}
def test_DCS_setattr():
s = state.DataclassState(DCSimple)
s.attr1 = 'Hello world'
assert s._instance_kwargs() == {'attr1': 'Hello world'}
@pytest.mark.xfail
def test_DCS_type_check():
s = state.DataclassState(DCSimple)
with pytest.raises(TypeError):
# TODO: Should we make this fail? It involves type-checking / validation.
s.attr1 = 1
def test_DCS_update_from_instance():
s = state.DataclassState(DCSimple)
s.dcs_update_from(DCSimple('a1', 2))
assert s._instance_type == DCSimple
assert s._instance_kwargs() == {'attr1': 'a1', 'attr2': 2}
def test_DCS_update_from_instance_subclass():
s = state.DataclassState(DCSimple)
s.dcs_update_from(DCSimpleSubclass('a1', 2, 3.14))
assert s._instance_type == DCSimpleSubclass
assert s._instance_kwargs() == {'attr1': 'a1', 'attr2': 2, 'attr3': 3.14}
def test_DCS_update_from_instance_nested():
s = state.DataclassState(DCNested)
nested = DCNested(DCSimpleSubclass('a1', 2, 3.14), [])
s.dcs_update_from(nested)
assert s.simple.dcs_instance() == nested.simple
assert s.dcs_instance() == nested
def test_observe_instance_nested():
top_level = Mock()
nested = Mock()
s = state.DataclassState(DCNested)
s.dcs_observe(top_level)
s.simple.dcs_observe(nested)
s.simple.attr1 = 'something new'
top_level.assert_called_with()
nested.assert_called_with()
top_level.reset_mock()
nested.reset_mock()
s.vanilla = 'something new'
top_level.assert_called_with()
nested.assert_not_called()
def test_DCS_predefined():
opt1 = DCSimple('a', 1)
opt2 = DCSimpleSubclass('b', 2, 3.14)
s = state.DataclassStatePredefined(
DCSimple, {'option 1': opt1, 'option 2': opt2}
)
assert s._selected == 'option 1'
# TODO: This should fail.
s.attr1 = 'can I set it?'
assert s.dcs_instance() == opt1
s.dcs_select('option 2')
assert s.dcs_instance() == opt2
assert repr(s) == "<state for DCSimple. 'option 2' selected>"
# TODO: This should fail too.
s.dcs_update_from(opt1)
assert s.dcs_instance() == opt2
def test_DCS_non_dataclass_attrs():
val = DCClassVar('a', 1)
s = state.DataclassState(DCSimple)
s.dcs_update_from(val)
s.dcs_instance() == val

View file

@ -17,17 +17,17 @@ with (HERE / 'README.md').open('rt') as fh:
REQUIREMENTS: dict = {
'core': [
'dataclasses; python_version < "3.7"',
'numpy',
],
'app': [
'ipykernel',
'ipympl',
'ipywidgets',
'matplotlib',
'numpy',
'voila >=0.2.4',
],
'app': [],
'test': [
'pytest',
'pytest-tornasync', # Unused, but needed because of a downstream dependency.
],
'dev': [
'jupyterlab',