From f63e1d37607f9c46de5bd737a3f3d893dfb82d0a Mon Sep 17 00:00:00 2001
From: Phil Elson
Date: Wed, 21 Oct 2020 20:29:17 +0200
Subject: [PATCH] Abstract the model state so that we can mutate it
conveniently.
This rather large change adds a layer between the underlying (immutable) model and the application.
In doing so we can avoid the use of a global state (useful for the purposes of configuring multiple models in the same application later on) and it also unlocks the ability to implement an MVC-like separation of concerns - again, the intention is that when it comes to comparisons, we will just be able to re-use our application views.
I was hoping that ``cara.state`` could have been avoided in lieu of using traitlets, but unfortunately I found a number of limitations which were prohibitive for its use here.
Foremost of which was the lack of first-class dataclass support and the difficulty in needing either to use instances of the model (immutable) or duplicate the model and its structure in a mutable form and use the ``traitlets.Instance`` type.
Instead I opted for doing it myself - the ``cara.state`` module would make a very good standalone project in the future.
---
app/cara.ipynb | 261 ++-------------------------------------
cara/apps.py | 202 ++++++++++++++++++++++++++++++
cara/state.py | 229 ++++++++++++++++++++++++++++++++++
cara/tests/test_apps.py | 9 ++
cara/tests/test_state.py | 161 ++++++++++++++++++++++++
setup.py | 6 +-
6 files changed, 616 insertions(+), 252 deletions(-)
create mode 100644 cara/apps.py
create mode 100644 cara/state.py
create mode 100644 cara/tests/test_apps.py
create mode 100644 cara/tests/test_state.py
diff --git a/app/cara.ipynb b/app/cara.ipynb
index 931d3197..d3cb17e5 100644
--- a/app/cara.ipynb
+++ b/app/cara.ipynb
@@ -13,259 +13,22 @@
"
"
]
},
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "%matplotlib widget\n",
- "import ipywidgets as widgets\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "import typing"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "import cara.models\n",
- "\n",
- "\n",
- "def prepare_model(volume, n_infected=1, n_exposed=10, mask='Type I') -> cara.models.Model:\n",
- " \"\"\"\n",
- " Transform configurable values into a cara model instance.\n",
- " \n",
- " \"\"\"\n",
- " model = cara.models.Model(\n",
- " room=cara.models.Room(volume=volume),\n",
- " ventilation=cara.models.PeriodicWindow(period=120, duration=120, inside_temp=293, outside_temp=283,\n",
- " window_height=1.6, opening_length=0.6, cd_b=0.6),\n",
- " infected=cara.models.InfectedPerson(\n",
- " virus=cara.models.Virus.types['SARS_CoV_2'],\n",
- " present_times=((0, 4), (5, 8)),\n",
- " mask=cara.models.Mask.types[mask],\n",
- " activity=cara.models.Activity.types['Light exercise'],\n",
- " expiration=cara.models.Expiration.types['Unmodulated Vocalization'],\n",
- " ),\n",
- " infected_occupants=n_infected,\n",
- " exposed_occupants=n_exposed,\n",
- " exposed_activity=cara.models.Activity.types['Light exercise'],\n",
- " )\n",
- " return model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Setup our plotting environment.\n",
- "plt.interactive(False)\n",
- "fig_concentration_over_time = plt.figure()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define some useful widget machinery.\n",
- "\n",
- "def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=True):\n",
- " collapsed = widgets.Accordion([widgets.VBox(widgets_to_collapse)])\n",
- " collapsed.set_title(0, title)\n",
- " if start_collapsed:\n",
- " collapsed.selected_index = None\n",
- " return collapsed\n",
- "\n",
- "\n",
- "def widget_group(label_widget_pairs):\n",
- " labels, widgets_ = zip(*label_widget_pairs) \n",
- " labels_w = widgets.VBox(labels)\n",
- " widgets_w = widgets.VBox(widgets_)\n",
- " return widgets.HBox([labels_w, widgets_w])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "eb109c0f63e149d69e763aec5d404db2",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Accordion(children=(VBox(children=(HBox(children=(VBox(children=(Label(value='Room volume'),)), VBox(children=…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "room_volume = widgets.IntSlider(value=75, min=10, max=150)\n",
- "mask_used = widgets.Checkbox(value=True, description='Mask worn')\n",
- "\n",
- "collapsible(\n",
- " [widget_group(\n",
- " [[widgets.Label('Room volume'), room_volume]]\n",
- " )],\n",
- " title='Specification of workplace', start_collapsed=False,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "49ad604786f546f58dba54f1f6e7eded",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Accordion(children=(VBox(children=(HBox(children=(VBox(children=(Label(value='Ventilation type'),)), VBox(chil…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "ventilation_widgets = {\n",
- " 'Natural': widgets.Label('Currently hard-coded to window-example from mathematica notebook'),\n",
- " 'other': widgets.Label('Not yet implemented.')\n",
- "}\n",
- "for name, widget in ventilation_widgets.items():\n",
- " widget.layout.visible = False\n",
- "\n",
- "ventilation_w = widgets.ToggleButtons(\n",
- " options=['Natural'], # cara.models.Ventilation.types.keys(),\n",
- ")\n",
- "def toggle_ventilation(value):\n",
- " for name, widget in ventilation_widgets.items():\n",
- " widget.layout.display = 'none'\n",
- " other = ventilation_widgets['other']\n",
- " widget = ventilation_widgets.get(value, other)\n",
- " widget.layout.visible = True\n",
- " widget.layout.display = 'block'\n",
- "\n",
- "ventilation_w.observe(lambda event: toggle_ventilation(event['new']), 'value')\n",
- "toggle_ventilation(ventilation_w.value)\n",
- "\n",
- "\n",
- "collapsible(\n",
- " [widget_group([[widgets.Label('Ventilation type'), ventilation_w]])]\n",
- " + list(ventilation_widgets.values()),\n",
- " title='Ventilation scheme'\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "8e76a49d0212462d81200a3959dcd3ff",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Accordion(children=(VBox(children=(HBox(children=(Canvas(footer_visible=False, header_visible=False, toolbar=T…"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "import matplotlib.pyplot as plt\n",
- "\n",
- "line = None\n",
- "# plt.ioff()\n",
- "# fig = plt.figure()\n",
- "fig = fig_concentration_over_time\n",
- "\n",
- "def plot_concentrations(_):\n",
- " global line\n",
- " model = prepare_model(room_volume.value)\n",
- "\n",
- " ts = np.arange(0, 10., 0.01)\n",
- " concentration = [model.concentration(t) for t in ts]\n",
- "\n",
- " ax = fig.gca()\n",
- " \n",
- " plt.text(0.5, 0.9, 'Without masks & window open', transform=ax.transAxes, ha='center')\n",
- " if line is None:\n",
- " ax.spines['right'].set_visible(False)\n",
- " ax.spines['top'].set_visible(False)\n",
- " [line] = plt.plot(ts, concentration)\n",
- " ax.set_xlabel('Time (hours)')\n",
- " ax.set_ylabel('Concentration ($q/m^3$)')\n",
- " plt.title('Concentration of infectious quanta aerosols')\n",
- " \n",
- " ax.set_ymargin(0.2)\n",
- " ax.set_ylim(bottom=0)\n",
- " else:\n",
- " line.set_data(ts, concentration)\n",
- " ax.relim()\n",
- " ax.autoscale_view()\n",
- " \n",
- " plt.draw()\n",
- "\n",
- "# print(f'Probability of infection: {np.round(model.[\"P\"], 1)}')\n",
- "# print(f'Expected number of new cases: {prepared[\"R0\"]}')\n",
- " \n",
- "\n",
- "# widgets.interact(\n",
- "# plot_concentrations,\n",
- "# volume=room_volume,\n",
- "# n_exposed=widgets.IntSlider(value=10, min=0, max=25),\n",
- "# n_infected=widgets.IntSlider(value=1, min=1, max=5),\n",
- "# );\n",
- "\n",
- "\n",
- "for observable in [room_volume]:\n",
- " observable.observe(plot_concentrations)\n",
- "\n",
- "plot_concentrations(1)\n",
- "\n",
- "fig.canvas.toolbar_visible = True\n",
- "fig.canvas.toolbar.collapsed = True\n",
- "fig.canvas.footer_visible = False\n",
- "fig.canvas.header_visible = False\n",
- "\n",
- "\n",
- "collapsible([\n",
- " widgets.HBox([\n",
- " fig.canvas,\n",
- " # text_report,\n",
- " ])\n",
- "], 'Report', start_collapsed=False)\n"
- ]
- },
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "scrolled": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
"outputs": [],
- "source": []
+ "source": [
+ "import cara.apps\n",
+ "\n",
+ "app = cara.apps.ExpertApplication()\n",
+ "app.widget"
+ ]
}
],
"metadata": {
diff --git a/cara/apps.py b/cara/apps.py
new file mode 100644
index 00000000..eb2b530a
--- /dev/null
+++ b/cara/apps.py
@@ -0,0 +1,202 @@
+import typing
+import uuid
+
+import ipympl.backend_nbagg
+import ipywidgets as widgets
+import numpy as np
+import matplotlib
+import matplotlib.figure
+
+from cara import models
+from cara import state
+
+
+def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=True):
+ collapsed = widgets.Accordion([widgets.VBox(widgets_to_collapse)])
+ collapsed.set_title(0, title)
+ if start_collapsed:
+ collapsed.selected_index = None
+ return collapsed
+
+
+def widget_group(label_widget_pairs):
+ labels, widgets_ = zip(*label_widget_pairs)
+ labels_w = widgets.VBox(labels)
+ widgets_w = widgets.VBox(widgets_)
+ return widgets.HBox([labels_w, widgets_w])
+
+
+class ConcentrationFigure:
+ def __init__(self):
+ self.figure = matplotlib.figure.Figure(figsize=(9, 6))
+ self.ax = self.figure.add_subplot(1, 1, 1)
+ self.line = None
+
+ def update(self, model: models.Model):
+ resolution = 600
+ ts = np.linspace(0, 10, resolution)
+ concentration = [model.concentration(t) for t in ts]
+ if self.line is None:
+ [self.line] = self.ax.plot(ts, concentration)
+ ax = self.ax
+
+ ax.text(0.5, 0.9, 'Without masks & window open', transform=ax.transAxes, ha='center')
+
+ ax.spines['right'].set_visible(False)
+ ax.spines['top'].set_visible(False)
+
+ ax.set_xlabel('Time (hours)')
+ ax.set_ylabel('Concentration ($q/m^3$)')
+ ax.set_title('Concentration of infectious quanta aerosols')
+ ax.set_ymargin(0.2)
+ ax.set_ylim(bottom=0)
+ else:
+ self.line.set_data(ts, concentration)
+ self.ax.relim()
+ self.ax.autoscale_view()
+ self.figure.canvas.draw()
+
+
+def ipympl_canvas(figure: matplotlib.figure.Figure):
+ # Make a plain matplotlib figure render as a Jupyter widget.
+ matplotlib.interactive(False)
+ ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure)
+ figure.canvas.toolbar_visible = True
+ figure.canvas.toolbar.collapsed = True
+ figure.canvas.footer_visible = False
+ figure.canvas.header_visible = False
+
+
+class WidgetView:
+ def __init__(self, model_state: state.DataclassState):
+ self.model_state = model_state
+ self.model_state.dcs_observe(self.update)
+ #: The widgets that this view produces (inputs and outputs together)
+ self.widget = widgets.VBox([])
+ self.widgets = {}
+ self.plots = []
+ self.construct_widgets()
+ # Trigger the first result.
+ self.update()
+
+ def construct_widgets(self):
+ # Build the input widgets.
+ self._build_widget(self.model_state)
+
+ # And the output widget figure.
+ concentration = ConcentrationFigure()
+ self.plots.append(concentration)
+ ipympl_canvas(concentration.figure)
+
+ self.widgets['results'] = collapsible([
+ widgets.HBox([
+ concentration.figure.canvas,
+ ])
+ ], 'Results', start_collapsed=False)
+
+ # Join inputs and outputs together in a single widget for convenience.
+ self.widget.children += (self.widgets['results'], )
+
+ def prepare_output(self):
+ pass
+
+ def update(self):
+ model = self.model_state.dcs_instance()
+ for plot in self.plots:
+ plot.update(model)
+
+ def _build_widget(self, node):
+ if isinstance(node, state.DataclassState):
+ if node._base == models.Ventilation:
+ self.widget.children += (self._build_ventilation(node), )
+ elif node._base == models.Room:
+ self.widget.children += (self._build_room(node), )
+ else:
+ # Don't do anything with this state, but recurse down in case
+ # its children want widgets.
+ for name, child in node._data.items():
+ self._build_widget(child)
+
+ def _build_room(self, node):
+ room_volume = widgets.IntSlider(value=node.volume, min=10, max=150)
+ mask_used = widgets.Checkbox(value=True, description='Mask worn')
+
+ def on_value_change(change):
+ node.volume = change['new']
+
+ # TODO: Link the state back to the widget, not just the other way around.
+ room_volume.observe(on_value_change, names=['value'])
+
+ widget = collapsible(
+ [widget_group(
+ [[widgets.Label('Room volume'), room_volume]]
+ )],
+ title='Specification of workplace', start_collapsed=False,
+ )
+ return widget
+
+ def _build_ventilation(self, node):
+ ventilation_widgets = {
+ 'Natural': widgets.Label('Currently hard-coded to window-example from mathematica notebook'),
+ 'other': widgets.Label('Not yet implemented.')
+ }
+ for name, widget in ventilation_widgets.items():
+ widget.layout.visible = False
+
+ ventilation_w = widgets.ToggleButtons(
+ options=ventilation_widgets.keys(),
+ )
+
+ def toggle_ventilation(value):
+ for name, widget in ventilation_widgets.items():
+ widget.layout.display = 'none'
+ other = ventilation_widgets['other']
+ widget = ventilation_widgets.get(value, other)
+ widget.layout.visible = True
+ widget.layout.display = 'block'
+
+ ventilation_w.observe(lambda event: toggle_ventilation(event['new']), 'value')
+ toggle_ventilation(ventilation_w.value)
+
+ w = collapsible(
+ [widget_group([[widgets.Label('Ventilation type'), ventilation_w]])]
+ + list(ventilation_widgets.values()),
+ title='Ventilation scheme'
+ )
+ return w
+
+ def present(self):
+ return self.widget
+
+
+baseline_model = models.Model(
+ room=models.Room(volume=75),
+ ventilation=models.PeriodicWindow(
+ period=120, duration=120, inside_temp=293, outside_temp=283, cd_b=0.6,
+ window_height=1.6, opening_length=0.6,
+ ),
+ infected=models.InfectedPerson(
+ virus=models.Virus.types['SARS_CoV_2'],
+ present_times=((0, 4), (5, 8)),
+ mask=models.Mask.types['No mask'],
+ activity=models.Activity.types['Light exercise'],
+ expiration=models.Expiration.types['Unmodulated Vocalization'],
+ ),
+ infected_occupants=1,
+ exposed_occupants=10,
+ exposed_activity=models.Activity.types['Light exercise'],
+)
+
+
+class ExpertApplication:
+ def __init__(self):
+ self.model_state = state.DataclassState(models.Model)
+ self.model_state.dcs_update_from(
+ baseline_model
+ )
+ self.view = WidgetView(self.model_state)
+ # self._widget = widgets.Text("WIP")
+
+ @property
+ def widget(self):
+ return self.view.present()
diff --git a/cara/state.py b/cara/state.py
new file mode 100644
index 00000000..97dc7a0f
--- /dev/null
+++ b/cara/state.py
@@ -0,0 +1,229 @@
+"""
+This module is entirely in support of providing a convenient mutable counterpart
+to frozen dataclasses. Significant effort went into to trying to use traitlets
+for this purpose, but the need to define class-level attributes proved to be a
+limitation that meant we could not mutate the state from one subclass to another
+after the state was instantiated.
+
+This module MUST not import other parts of cara as this would point at a
+leaky abstraction.
+
+"""
+from contextlib import contextmanager
+import dataclasses
+import typing
+
+
+Datamodel_T = typing.Type
+dataclass_instance = typing.Any
+
+
+class StateBuilder:
+ def visit(self, field: dataclasses.Field):
+ builder = self.resolve_builder(field)
+ return builder(field.type)
+
+ def resolve_builder(self, field: dataclasses.Field):
+ method_name = [
+ f'build_name_{field.name}',
+ f'build_type_{field.type.__name__}',
+ ]
+ for name in method_name:
+ method = getattr(self, name, None)
+ if method is not None:
+ return method
+ return self.build_generic
+
+ def build_generic(self, type_to_build: typing.Type):
+ return DataclassState(type_to_build, self)
+
+
+class DataclassState:
+ """
+ Represents the state of a frozen dataclass.
+ No type checking of the attributes is attempted.
+
+ Setting the state can be done with:
+
+ setattr(state, attr, value)
+
+ Accessing the instance that this state represents can be done with:
+
+ state.dcs_instance()
+
+ Changing the type to a subclass of the base that this state represents can be done with:
+
+ state.dcs_set_instance_type(ASubclassOfBase)
+
+ """
+
+ def __init__(self, dataclass: Datamodel_T, state_builder=StateBuilder()):
+ # Note that the constructor does *not* insert any data by default. It
+ # therefore doesn't build nested DataclassState instances when a dataclass contains another.
+ # For that, use the build classmethod.
+ if not dataclasses.is_dataclass(dataclass):
+ raise TypeError("The given class is not a valid dataclass")
+ if not isinstance(dataclass, type):
+ raise TypeError("A dataclass type must be provided, not an instance of one")
+
+ with self._object_setattr():
+ #: The base instance which this state must support.
+ self._base = dataclass
+ #: The actual instance type that this state represents (i.e. may be a
+ #: subclass of _base).
+ self._instance_type = dataclass
+
+ #: The instance of dataclass which this state represents. Undefined until
+ #: sufficient data is provided.
+ self._instance = None
+ self._data = {}
+ self._observers: typing.List[callable] = []
+ self._state_builder = state_builder
+
+ self.dcs_set_instance_type(dataclass)
+
+ def __repr__(self):
+ return f""
+
+ def _instance_attrs(self):
+ return [field.name for field in dataclasses.fields(self._instance_type)]
+
+ def dcs_observe(self, callback: typing.Callable):
+ self._observers.append(callback)
+
+ def dcs_update_from(self, data: dataclass_instance):
+ self.dcs_set_instance_type(data.__class__)
+ for field in dataclasses.fields(data):
+ attr = field.name
+ current_value = self._data.get(attr, None)
+ new_value = getattr(data, attr)
+ if dataclasses.is_dataclass(field.type):
+ assert isinstance(current_value, DataclassState)
+ current_value.dcs_update_from(new_value)
+ else:
+ self._data[attr] = new_value
+
+ def _fire_observers(self):
+ self._instance = None
+ for observer in self._observers:
+ observer()
+
+ @contextmanager
+ def _object_setattr(self):
+ self._use_base_setattr = True
+ yield
+ self._use_base_setattr = False
+
+ def __getattr__(self, name):
+ try:
+ return super().__getattribute__(name)
+ except AttributeError:
+ pass
+ if name in self._data:
+ return self._data[name]
+ elif name in self._instance_attrs():
+ raise ValueError(f"State not yet set for {name}")
+ else:
+ raise AttributeError(f"Attribute {name} does not exist on {self._instance_type.__name__}")
+
+ def __setattr__(self, name, value):
+ if name in self.__dict__ or self.__dict__.get('_use_base_setattr', True):
+ return object.__setattr__(self, name, value)
+ if name in self._instance_attrs():
+ self._dcs_set_value(name, value)
+
+ def _dcs_set_value(self, attr_name, value):
+ valid_attrs = self._instance_attrs()
+ if isinstance(value, DataclassState):
+ # TODO: We need to check that the value is acceptable. (needs
+ # thinking about)
+ # assert value._base ==
+ pass
+ # TODO: Inject some notifications here to tell any holding DataclassState
+ # instances that we've changed.
+
+ if attr_name in valid_attrs:
+ self._data[attr_name] = value
+ self._fire_observers()
+ else:
+ raise AttributeError(f"No attribute {attr_name} on a {self._instance_type.__name__}")
+
+ def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
+ if not dataclasses.is_dataclass(instance_dataclass):
+ raise TypeError("The given class is not a valid dataclass")
+ if not issubclass(instance_dataclass, self._base):
+ raise TypeError(f"The dataclass type provided ({instance_dataclass}) must be a subclass of the base ({self._base})")
+ self._instance_type = instance_dataclass
+
+ self._data.clear()
+ for field in dataclasses.fields(instance_dataclass):
+ if dataclasses.is_dataclass(field.type):
+ self._data[field.name] = self._state_builder.visit(field)
+ self._data[field.name].dcs_observe(self._fire_observers)
+
+ def _instance_state(self):
+ # Note: this method should not validate that the args are complete or
+ # overspecified.
+ kwargs = {}
+ for name, data in self._data.items():
+ if isinstance(data, DataclassState):
+ data = data._instance_state()
+ kwargs[name] = data
+ return kwargs
+
+ def _instance_kwargs(self):
+ # Note: this method should not validate that the args are complete or
+ # overspecified.
+ kwargs = {}
+ for name, data in self._data.items():
+ if isinstance(data, DataclassState):
+ data = data.dcs_instance()
+ kwargs[name] = data
+ return kwargs
+
+ def dcs_instance(self):
+ if self._instance is None:
+ # TODO: Check if we are able to create an instance with our data...
+ self._instance = self._instance_type(**self._instance_kwargs())
+ return self._instance
+
+
+class DataclassStatePredefined(DataclassState):
+ """
+ Only a pre-defined selection of states for the given type are allowed.
+ Selected by name (the keys in the dictionary).
+
+ You can change the chosen state with:
+
+ state.dcs_select(name)
+
+ """
+ def __init__(self, dataclass: Datamodel_T, choices: typing.Dict[typing.Hashable, dataclass_instance]):
+ super().__init__(dataclass=dataclass)
+
+ with self._object_setattr():
+ self._choices = choices
+ self._selected = None
+ # Pick the first choice until we know otherwise.
+ self.dcs_select(list(choices.keys())[0])
+
+ def dcs_select(self, name: typing.Hashable):
+ if name not in self._choices:
+ raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._choices)}')
+ self._selected = name
+ self._instance = self._choices[name]
+
+ def dcs_instance(self):
+ return self._choices[self._selected]
+
+ def __repr__(self):
+ return f""
+
+ def _instance_kwargs(self):
+ raise NotImplementedError("Doesn't make much sense")
+
+ def _instance_state(self):
+ return dataclasses.asdict(self.dcs_instance())
+
+ def _instance_kwargs(self):
+ return dataclasses.asdict(self.dcs_instance())
diff --git a/cara/tests/test_apps.py b/cara/tests/test_apps.py
new file mode 100644
index 00000000..26851489
--- /dev/null
+++ b/cara/tests/test_apps.py
@@ -0,0 +1,9 @@
+import cara.apps
+
+
+def test_app():
+ # To start with, let's just test that the application runs. We don't try to
+ # do anything fancy to verify how it looks etc., we leave that for manual
+ # testing.
+ expert_app = cara.apps.ExpertApplication()
+ assert expert_app.model_state.room.volume == 75
diff --git a/cara/tests/test_state.py b/cara/tests/test_state.py
new file mode 100644
index 00000000..4d228836
--- /dev/null
+++ b/cara/tests/test_state.py
@@ -0,0 +1,161 @@
+from dataclasses import dataclass
+import typing
+from unittest.mock import Mock
+
+import pytest
+
+from cara import state
+
+
+@dataclass
+class DCSimple:
+ attr1: str
+ attr2: int
+
+
+@dataclass
+class DCSimpleSubclass(DCSimple):
+ attr3: float
+
+
+@dataclass
+class DCOverrideSubclass(DCSimple):
+ attr1: float
+
+
+@dataclass
+class DCClassVar(DCSimple):
+ a_class_var: typing.ClassVar[int]
+
+
+@dataclass
+class DCRecursive(DCSimple):
+ simple: DCSimple
+
+
+@dataclass
+class DCNested:
+ simple: DCSimple
+ others: typing.List[DCSimple]
+ vanilla: str = 'Default'
+
+
+@dataclass
+class DCNestedDeep:
+ child: DCNested
+
+
+@pytest.fixture
+def dc_simple():
+ return DCSimple
+
+
+def test_DCS_construct():
+ s = state.DataclassState(DCSimple)
+ assert repr(s) == ''
+
+ with pytest.raises(TypeError, match=r"A dataclass type must be provided, not an instance of one"):
+ state.DataclassState(DCSimple('', 1))
+
+ with pytest.raises(TypeError, match="The given class is not a valid dataclass"):
+ state.DataclassState(None)
+
+
+def test_DCS_construct_nested():
+ s = state.DataclassState(DCNested)
+ assert repr(s) == ""
+
+
+@pytest.mark.xfail
+def test_DCS_subclass():
+ s = state.DataclassState(DCSimple)
+ s.dcs_set_instance_type(DCSimpleSubclass)
+ s.set('attr3', 3.14)
+ assert s._instance_kwargs() == {'attr3': 3.14}
+ s.dcs_set_instance_type(DCSimple)
+ # TODO: Make this fail.
+ assert s._instance_kwargs() == {}
+
+
+def test_DCS_setattr():
+ s = state.DataclassState(DCSimple)
+ s.attr1 = 'Hello world'
+ assert s._instance_kwargs() == {'attr1': 'Hello world'}
+
+
+@pytest.mark.xfail
+def test_DCS_type_check():
+ s = state.DataclassState(DCSimple)
+ with pytest.raises(TypeError):
+ # TODO: Should we make this fail? It involves type-checking / validation.
+ s.attr1 = 1
+
+
+def test_DCS_update_from_instance():
+ s = state.DataclassState(DCSimple)
+ s.dcs_update_from(DCSimple('a1', 2))
+ assert s._instance_type == DCSimple
+ assert s._instance_kwargs() == {'attr1': 'a1', 'attr2': 2}
+
+
+def test_DCS_update_from_instance_subclass():
+ s = state.DataclassState(DCSimple)
+ s.dcs_update_from(DCSimpleSubclass('a1', 2, 3.14))
+ assert s._instance_type == DCSimpleSubclass
+ assert s._instance_kwargs() == {'attr1': 'a1', 'attr2': 2, 'attr3': 3.14}
+
+
+def test_DCS_update_from_instance_nested():
+ s = state.DataclassState(DCNested)
+ nested = DCNested(DCSimpleSubclass('a1', 2, 3.14), [])
+ s.dcs_update_from(nested)
+ assert s.simple.dcs_instance() == nested.simple
+ assert s.dcs_instance() == nested
+
+
+def test_observe_instance_nested():
+ top_level = Mock()
+ nested = Mock()
+
+ s = state.DataclassState(DCNested)
+
+ s.dcs_observe(top_level)
+ s.simple.dcs_observe(nested)
+
+ s.simple.attr1 = 'something new'
+ top_level.assert_called_with()
+ nested.assert_called_with()
+
+ top_level.reset_mock()
+ nested.reset_mock()
+ s.vanilla = 'something new'
+ top_level.assert_called_with()
+ nested.assert_not_called()
+
+
+def test_DCS_predefined():
+ opt1 = DCSimple('a', 1)
+ opt2 = DCSimpleSubclass('b', 2, 3.14)
+ s = state.DataclassStatePredefined(
+ DCSimple, {'option 1': opt1, 'option 2': opt2}
+ )
+ assert s._selected == 'option 1'
+ # TODO: This should fail.
+ s.attr1 = 'can I set it?'
+ assert s.dcs_instance() == opt1
+
+ s.dcs_select('option 2')
+ assert s.dcs_instance() == opt2
+
+ assert repr(s) == ""
+
+ # TODO: This should fail too.
+ s.dcs_update_from(opt1)
+ assert s.dcs_instance() == opt2
+
+
+def test_DCS_non_dataclass_attrs():
+ val = DCClassVar('a', 1)
+ s = state.DataclassState(DCSimple)
+ s.dcs_update_from(val)
+ s.dcs_instance() == val
diff --git a/setup.py b/setup.py
index a28560ad..77ea1551 100644
--- a/setup.py
+++ b/setup.py
@@ -17,17 +17,17 @@ with (HERE / 'README.md').open('rt') as fh:
REQUIREMENTS: dict = {
'core': [
'dataclasses; python_version < "3.7"',
- 'numpy',
- ],
- 'app': [
'ipykernel',
'ipympl',
'ipywidgets',
'matplotlib',
+ 'numpy',
'voila >=0.2.4',
],
+ 'app': [],
'test': [
'pytest',
+ 'pytest-tornasync', # Unused, but needed because of a downstream dependency.
],
'dev': [
'jupyterlab',