Abstract the model state so that we can mutate it conveniently.
This rather large change adds a layer between the underlying (immutable) model and the application. In doing so we can avoid the use of a global state (useful for the purposes of configuring multiple models in the same application later on) and it also unlocks the ability to implement an MVC-like separation of concerns - again, the intention is that when it comes to comparisons, we will just be able to re-use our application views. I was hoping that ``cara.state`` could have been avoided in lieu of using traitlets, but unfortunately I found a number of limitations which were prohibitive for its use here. Foremost of which was the lack of first-class dataclass support and the difficulty in needing either to use instances of the model (immutable) or duplicate the model and its structure in a mutable form and use the ``traitlets.Instance`` type. Instead I opted for doing it myself - the ``cara.state`` module would make a very good standalone project in the future.
This commit is contained in:
parent
69730bfb2a
commit
f63e1d3760
6 changed files with 616 additions and 252 deletions
261
app/cara.ipynb
261
app/cara.ipynb
|
|
@ -13,259 +13,22 @@
|
|||
"</p>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib widget\n",
|
||||
"import ipywidgets as widgets\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"import typing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cara.models\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def prepare_model(volume, n_infected=1, n_exposed=10, mask='Type I') -> cara.models.Model:\n",
|
||||
" \"\"\"\n",
|
||||
" Transform configurable values into a cara model instance.\n",
|
||||
" \n",
|
||||
" \"\"\"\n",
|
||||
" model = cara.models.Model(\n",
|
||||
" room=cara.models.Room(volume=volume),\n",
|
||||
" ventilation=cara.models.PeriodicWindow(period=120, duration=120, inside_temp=293, outside_temp=283,\n",
|
||||
" window_height=1.6, opening_length=0.6, cd_b=0.6),\n",
|
||||
" infected=cara.models.InfectedPerson(\n",
|
||||
" virus=cara.models.Virus.types['SARS_CoV_2'],\n",
|
||||
" present_times=((0, 4), (5, 8)),\n",
|
||||
" mask=cara.models.Mask.types[mask],\n",
|
||||
" activity=cara.models.Activity.types['Light exercise'],\n",
|
||||
" expiration=cara.models.Expiration.types['Unmodulated Vocalization'],\n",
|
||||
" ),\n",
|
||||
" infected_occupants=n_infected,\n",
|
||||
" exposed_occupants=n_exposed,\n",
|
||||
" exposed_activity=cara.models.Activity.types['Light exercise'],\n",
|
||||
" )\n",
|
||||
" return model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Setup our plotting environment.\n",
|
||||
"plt.interactive(False)\n",
|
||||
"fig_concentration_over_time = plt.figure()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define some useful widget machinery.\n",
|
||||
"\n",
|
||||
"def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=True):\n",
|
||||
" collapsed = widgets.Accordion([widgets.VBox(widgets_to_collapse)])\n",
|
||||
" collapsed.set_title(0, title)\n",
|
||||
" if start_collapsed:\n",
|
||||
" collapsed.selected_index = None\n",
|
||||
" return collapsed\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def widget_group(label_widget_pairs):\n",
|
||||
" labels, widgets_ = zip(*label_widget_pairs) \n",
|
||||
" labels_w = widgets.VBox(labels)\n",
|
||||
" widgets_w = widgets.VBox(widgets_)\n",
|
||||
" return widgets.HBox([labels_w, widgets_w])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "eb109c0f63e149d69e763aec5d404db2",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Accordion(children=(VBox(children=(HBox(children=(VBox(children=(Label(value='Room volume'),)), VBox(children=…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"room_volume = widgets.IntSlider(value=75, min=10, max=150)\n",
|
||||
"mask_used = widgets.Checkbox(value=True, description='Mask worn')\n",
|
||||
"\n",
|
||||
"collapsible(\n",
|
||||
" [widget_group(\n",
|
||||
" [[widgets.Label('Room volume'), room_volume]]\n",
|
||||
" )],\n",
|
||||
" title='Specification of workplace', start_collapsed=False,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "49ad604786f546f58dba54f1f6e7eded",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Accordion(children=(VBox(children=(HBox(children=(VBox(children=(Label(value='Ventilation type'),)), VBox(chil…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ventilation_widgets = {\n",
|
||||
" 'Natural': widgets.Label('Currently hard-coded to window-example from mathematica notebook'),\n",
|
||||
" 'other': widgets.Label('Not yet implemented.')\n",
|
||||
"}\n",
|
||||
"for name, widget in ventilation_widgets.items():\n",
|
||||
" widget.layout.visible = False\n",
|
||||
"\n",
|
||||
"ventilation_w = widgets.ToggleButtons(\n",
|
||||
" options=['Natural'], # cara.models.Ventilation.types.keys(),\n",
|
||||
")\n",
|
||||
"def toggle_ventilation(value):\n",
|
||||
" for name, widget in ventilation_widgets.items():\n",
|
||||
" widget.layout.display = 'none'\n",
|
||||
" other = ventilation_widgets['other']\n",
|
||||
" widget = ventilation_widgets.get(value, other)\n",
|
||||
" widget.layout.visible = True\n",
|
||||
" widget.layout.display = 'block'\n",
|
||||
"\n",
|
||||
"ventilation_w.observe(lambda event: toggle_ventilation(event['new']), 'value')\n",
|
||||
"toggle_ventilation(ventilation_w.value)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"collapsible(\n",
|
||||
" [widget_group([[widgets.Label('Ventilation type'), ventilation_w]])]\n",
|
||||
" + list(ventilation_widgets.values()),\n",
|
||||
" title='Ventilation scheme'\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "8e76a49d0212462d81200a3959dcd3ff",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Accordion(children=(VBox(children=(HBox(children=(Canvas(footer_visible=False, header_visible=False, toolbar=T…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"line = None\n",
|
||||
"# plt.ioff()\n",
|
||||
"# fig = plt.figure()\n",
|
||||
"fig = fig_concentration_over_time\n",
|
||||
"\n",
|
||||
"def plot_concentrations(_):\n",
|
||||
" global line\n",
|
||||
" model = prepare_model(room_volume.value)\n",
|
||||
"\n",
|
||||
" ts = np.arange(0, 10., 0.01)\n",
|
||||
" concentration = [model.concentration(t) for t in ts]\n",
|
||||
"\n",
|
||||
" ax = fig.gca()\n",
|
||||
" \n",
|
||||
" plt.text(0.5, 0.9, 'Without masks & window open', transform=ax.transAxes, ha='center')\n",
|
||||
" if line is None:\n",
|
||||
" ax.spines['right'].set_visible(False)\n",
|
||||
" ax.spines['top'].set_visible(False)\n",
|
||||
" [line] = plt.plot(ts, concentration)\n",
|
||||
" ax.set_xlabel('Time (hours)')\n",
|
||||
" ax.set_ylabel('Concentration ($q/m^3$)')\n",
|
||||
" plt.title('Concentration of infectious quanta aerosols')\n",
|
||||
" \n",
|
||||
" ax.set_ymargin(0.2)\n",
|
||||
" ax.set_ylim(bottom=0)\n",
|
||||
" else:\n",
|
||||
" line.set_data(ts, concentration)\n",
|
||||
" ax.relim()\n",
|
||||
" ax.autoscale_view()\n",
|
||||
" \n",
|
||||
" plt.draw()\n",
|
||||
"\n",
|
||||
"# print(f'Probability of infection: {np.round(model.[\"P\"], 1)}')\n",
|
||||
"# print(f'Expected number of new cases: {prepared[\"R0\"]}')\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"# widgets.interact(\n",
|
||||
"# plot_concentrations,\n",
|
||||
"# volume=room_volume,\n",
|
||||
"# n_exposed=widgets.IntSlider(value=10, min=0, max=25),\n",
|
||||
"# n_infected=widgets.IntSlider(value=1, min=1, max=5),\n",
|
||||
"# );\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for observable in [room_volume]:\n",
|
||||
" observable.observe(plot_concentrations)\n",
|
||||
"\n",
|
||||
"plot_concentrations(1)\n",
|
||||
"\n",
|
||||
"fig.canvas.toolbar_visible = True\n",
|
||||
"fig.canvas.toolbar.collapsed = True\n",
|
||||
"fig.canvas.footer_visible = False\n",
|
||||
"fig.canvas.header_visible = False\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"collapsible([\n",
|
||||
" widgets.HBox([\n",
|
||||
" fig.canvas,\n",
|
||||
" # text_report,\n",
|
||||
" ])\n",
|
||||
"], 'Report', start_collapsed=False)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"scrolled": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"import cara.apps\n",
|
||||
"\n",
|
||||
"app = cara.apps.ExpertApplication()\n",
|
||||
"app.widget"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
|||
202
cara/apps.py
Normal file
202
cara/apps.py
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
import typing
|
||||
import uuid
|
||||
|
||||
import ipympl.backend_nbagg
|
||||
import ipywidgets as widgets
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import matplotlib.figure
|
||||
|
||||
from cara import models
|
||||
from cara import state
|
||||
|
||||
|
||||
def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=True):
|
||||
collapsed = widgets.Accordion([widgets.VBox(widgets_to_collapse)])
|
||||
collapsed.set_title(0, title)
|
||||
if start_collapsed:
|
||||
collapsed.selected_index = None
|
||||
return collapsed
|
||||
|
||||
|
||||
def widget_group(label_widget_pairs):
|
||||
labels, widgets_ = zip(*label_widget_pairs)
|
||||
labels_w = widgets.VBox(labels)
|
||||
widgets_w = widgets.VBox(widgets_)
|
||||
return widgets.HBox([labels_w, widgets_w])
|
||||
|
||||
|
||||
class ConcentrationFigure:
|
||||
def __init__(self):
|
||||
self.figure = matplotlib.figure.Figure(figsize=(9, 6))
|
||||
self.ax = self.figure.add_subplot(1, 1, 1)
|
||||
self.line = None
|
||||
|
||||
def update(self, model: models.Model):
|
||||
resolution = 600
|
||||
ts = np.linspace(0, 10, resolution)
|
||||
concentration = [model.concentration(t) for t in ts]
|
||||
if self.line is None:
|
||||
[self.line] = self.ax.plot(ts, concentration)
|
||||
ax = self.ax
|
||||
|
||||
ax.text(0.5, 0.9, 'Without masks & window open', transform=ax.transAxes, ha='center')
|
||||
|
||||
ax.spines['right'].set_visible(False)
|
||||
ax.spines['top'].set_visible(False)
|
||||
|
||||
ax.set_xlabel('Time (hours)')
|
||||
ax.set_ylabel('Concentration ($q/m^3$)')
|
||||
ax.set_title('Concentration of infectious quanta aerosols')
|
||||
ax.set_ymargin(0.2)
|
||||
ax.set_ylim(bottom=0)
|
||||
else:
|
||||
self.line.set_data(ts, concentration)
|
||||
self.ax.relim()
|
||||
self.ax.autoscale_view()
|
||||
self.figure.canvas.draw()
|
||||
|
||||
|
||||
def ipympl_canvas(figure: matplotlib.figure.Figure):
|
||||
# Make a plain matplotlib figure render as a Jupyter widget.
|
||||
matplotlib.interactive(False)
|
||||
ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure)
|
||||
figure.canvas.toolbar_visible = True
|
||||
figure.canvas.toolbar.collapsed = True
|
||||
figure.canvas.footer_visible = False
|
||||
figure.canvas.header_visible = False
|
||||
|
||||
|
||||
class WidgetView:
|
||||
def __init__(self, model_state: state.DataclassState):
|
||||
self.model_state = model_state
|
||||
self.model_state.dcs_observe(self.update)
|
||||
#: The widgets that this view produces (inputs and outputs together)
|
||||
self.widget = widgets.VBox([])
|
||||
self.widgets = {}
|
||||
self.plots = []
|
||||
self.construct_widgets()
|
||||
# Trigger the first result.
|
||||
self.update()
|
||||
|
||||
def construct_widgets(self):
|
||||
# Build the input widgets.
|
||||
self._build_widget(self.model_state)
|
||||
|
||||
# And the output widget figure.
|
||||
concentration = ConcentrationFigure()
|
||||
self.plots.append(concentration)
|
||||
ipympl_canvas(concentration.figure)
|
||||
|
||||
self.widgets['results'] = collapsible([
|
||||
widgets.HBox([
|
||||
concentration.figure.canvas,
|
||||
])
|
||||
], 'Results', start_collapsed=False)
|
||||
|
||||
# Join inputs and outputs together in a single widget for convenience.
|
||||
self.widget.children += (self.widgets['results'], )
|
||||
|
||||
def prepare_output(self):
|
||||
pass
|
||||
|
||||
def update(self):
|
||||
model = self.model_state.dcs_instance()
|
||||
for plot in self.plots:
|
||||
plot.update(model)
|
||||
|
||||
def _build_widget(self, node):
|
||||
if isinstance(node, state.DataclassState):
|
||||
if node._base == models.Ventilation:
|
||||
self.widget.children += (self._build_ventilation(node), )
|
||||
elif node._base == models.Room:
|
||||
self.widget.children += (self._build_room(node), )
|
||||
else:
|
||||
# Don't do anything with this state, but recurse down in case
|
||||
# its children want widgets.
|
||||
for name, child in node._data.items():
|
||||
self._build_widget(child)
|
||||
|
||||
def _build_room(self, node):
|
||||
room_volume = widgets.IntSlider(value=node.volume, min=10, max=150)
|
||||
mask_used = widgets.Checkbox(value=True, description='Mask worn')
|
||||
|
||||
def on_value_change(change):
|
||||
node.volume = change['new']
|
||||
|
||||
# TODO: Link the state back to the widget, not just the other way around.
|
||||
room_volume.observe(on_value_change, names=['value'])
|
||||
|
||||
widget = collapsible(
|
||||
[widget_group(
|
||||
[[widgets.Label('Room volume'), room_volume]]
|
||||
)],
|
||||
title='Specification of workplace', start_collapsed=False,
|
||||
)
|
||||
return widget
|
||||
|
||||
def _build_ventilation(self, node):
|
||||
ventilation_widgets = {
|
||||
'Natural': widgets.Label('Currently hard-coded to window-example from mathematica notebook'),
|
||||
'other': widgets.Label('Not yet implemented.')
|
||||
}
|
||||
for name, widget in ventilation_widgets.items():
|
||||
widget.layout.visible = False
|
||||
|
||||
ventilation_w = widgets.ToggleButtons(
|
||||
options=ventilation_widgets.keys(),
|
||||
)
|
||||
|
||||
def toggle_ventilation(value):
|
||||
for name, widget in ventilation_widgets.items():
|
||||
widget.layout.display = 'none'
|
||||
other = ventilation_widgets['other']
|
||||
widget = ventilation_widgets.get(value, other)
|
||||
widget.layout.visible = True
|
||||
widget.layout.display = 'block'
|
||||
|
||||
ventilation_w.observe(lambda event: toggle_ventilation(event['new']), 'value')
|
||||
toggle_ventilation(ventilation_w.value)
|
||||
|
||||
w = collapsible(
|
||||
[widget_group([[widgets.Label('Ventilation type'), ventilation_w]])]
|
||||
+ list(ventilation_widgets.values()),
|
||||
title='Ventilation scheme'
|
||||
)
|
||||
return w
|
||||
|
||||
def present(self):
|
||||
return self.widget
|
||||
|
||||
|
||||
baseline_model = models.Model(
|
||||
room=models.Room(volume=75),
|
||||
ventilation=models.PeriodicWindow(
|
||||
period=120, duration=120, inside_temp=293, outside_temp=283, cd_b=0.6,
|
||||
window_height=1.6, opening_length=0.6,
|
||||
),
|
||||
infected=models.InfectedPerson(
|
||||
virus=models.Virus.types['SARS_CoV_2'],
|
||||
present_times=((0, 4), (5, 8)),
|
||||
mask=models.Mask.types['No mask'],
|
||||
activity=models.Activity.types['Light exercise'],
|
||||
expiration=models.Expiration.types['Unmodulated Vocalization'],
|
||||
),
|
||||
infected_occupants=1,
|
||||
exposed_occupants=10,
|
||||
exposed_activity=models.Activity.types['Light exercise'],
|
||||
)
|
||||
|
||||
|
||||
class ExpertApplication:
|
||||
def __init__(self):
|
||||
self.model_state = state.DataclassState(models.Model)
|
||||
self.model_state.dcs_update_from(
|
||||
baseline_model
|
||||
)
|
||||
self.view = WidgetView(self.model_state)
|
||||
# self._widget = widgets.Text("WIP")
|
||||
|
||||
@property
|
||||
def widget(self):
|
||||
return self.view.present()
|
||||
229
cara/state.py
Normal file
229
cara/state.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
"""
|
||||
This module is entirely in support of providing a convenient mutable counterpart
|
||||
to frozen dataclasses. Significant effort went into to trying to use traitlets
|
||||
for this purpose, but the need to define class-level attributes proved to be a
|
||||
limitation that meant we could not mutate the state from one subclass to another
|
||||
after the state was instantiated.
|
||||
|
||||
This module MUST not import other parts of cara as this would point at a
|
||||
leaky abstraction.
|
||||
|
||||
"""
|
||||
from contextlib import contextmanager
|
||||
import dataclasses
|
||||
import typing
|
||||
|
||||
|
||||
Datamodel_T = typing.Type
|
||||
dataclass_instance = typing.Any
|
||||
|
||||
|
||||
class StateBuilder:
|
||||
def visit(self, field: dataclasses.Field):
|
||||
builder = self.resolve_builder(field)
|
||||
return builder(field.type)
|
||||
|
||||
def resolve_builder(self, field: dataclasses.Field):
|
||||
method_name = [
|
||||
f'build_name_{field.name}',
|
||||
f'build_type_{field.type.__name__}',
|
||||
]
|
||||
for name in method_name:
|
||||
method = getattr(self, name, None)
|
||||
if method is not None:
|
||||
return method
|
||||
return self.build_generic
|
||||
|
||||
def build_generic(self, type_to_build: typing.Type):
|
||||
return DataclassState(type_to_build, self)
|
||||
|
||||
|
||||
class DataclassState:
|
||||
"""
|
||||
Represents the state of a frozen dataclass.
|
||||
No type checking of the attributes is attempted.
|
||||
|
||||
Setting the state can be done with:
|
||||
|
||||
setattr(state, attr, value)
|
||||
|
||||
Accessing the instance that this state represents can be done with:
|
||||
|
||||
state.dcs_instance()
|
||||
|
||||
Changing the type to a subclass of the base that this state represents can be done with:
|
||||
|
||||
state.dcs_set_instance_type(ASubclassOfBase)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dataclass: Datamodel_T, state_builder=StateBuilder()):
|
||||
# Note that the constructor does *not* insert any data by default. It
|
||||
# therefore doesn't build nested DataclassState instances when a dataclass contains another.
|
||||
# For that, use the build classmethod.
|
||||
if not dataclasses.is_dataclass(dataclass):
|
||||
raise TypeError("The given class is not a valid dataclass")
|
||||
if not isinstance(dataclass, type):
|
||||
raise TypeError("A dataclass type must be provided, not an instance of one")
|
||||
|
||||
with self._object_setattr():
|
||||
#: The base instance which this state must support.
|
||||
self._base = dataclass
|
||||
#: The actual instance type that this state represents (i.e. may be a
|
||||
#: subclass of _base).
|
||||
self._instance_type = dataclass
|
||||
|
||||
#: The instance of dataclass which this state represents. Undefined until
|
||||
#: sufficient data is provided.
|
||||
self._instance = None
|
||||
self._data = {}
|
||||
self._observers: typing.List[callable] = []
|
||||
self._state_builder = state_builder
|
||||
|
||||
self.dcs_set_instance_type(dataclass)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<state for {self._instance_type.__name__}(**{self._instance_state()})>"
|
||||
|
||||
def _instance_attrs(self):
|
||||
return [field.name for field in dataclasses.fields(self._instance_type)]
|
||||
|
||||
def dcs_observe(self, callback: typing.Callable):
|
||||
self._observers.append(callback)
|
||||
|
||||
def dcs_update_from(self, data: dataclass_instance):
|
||||
self.dcs_set_instance_type(data.__class__)
|
||||
for field in dataclasses.fields(data):
|
||||
attr = field.name
|
||||
current_value = self._data.get(attr, None)
|
||||
new_value = getattr(data, attr)
|
||||
if dataclasses.is_dataclass(field.type):
|
||||
assert isinstance(current_value, DataclassState)
|
||||
current_value.dcs_update_from(new_value)
|
||||
else:
|
||||
self._data[attr] = new_value
|
||||
|
||||
def _fire_observers(self):
|
||||
self._instance = None
|
||||
for observer in self._observers:
|
||||
observer()
|
||||
|
||||
@contextmanager
|
||||
def _object_setattr(self):
|
||||
self._use_base_setattr = True
|
||||
yield
|
||||
self._use_base_setattr = False
|
||||
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
return super().__getattribute__(name)
|
||||
except AttributeError:
|
||||
pass
|
||||
if name in self._data:
|
||||
return self._data[name]
|
||||
elif name in self._instance_attrs():
|
||||
raise ValueError(f"State not yet set for {name}")
|
||||
else:
|
||||
raise AttributeError(f"Attribute {name} does not exist on {self._instance_type.__name__}")
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in self.__dict__ or self.__dict__.get('_use_base_setattr', True):
|
||||
return object.__setattr__(self, name, value)
|
||||
if name in self._instance_attrs():
|
||||
self._dcs_set_value(name, value)
|
||||
|
||||
def _dcs_set_value(self, attr_name, value):
|
||||
valid_attrs = self._instance_attrs()
|
||||
if isinstance(value, DataclassState):
|
||||
# TODO: We need to check that the value is acceptable. (needs
|
||||
# thinking about)
|
||||
# assert value._base ==
|
||||
pass
|
||||
# TODO: Inject some notifications here to tell any holding DataclassState
|
||||
# instances that we've changed.
|
||||
|
||||
if attr_name in valid_attrs:
|
||||
self._data[attr_name] = value
|
||||
self._fire_observers()
|
||||
else:
|
||||
raise AttributeError(f"No attribute {attr_name} on a {self._instance_type.__name__}")
|
||||
|
||||
def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
|
||||
if not dataclasses.is_dataclass(instance_dataclass):
|
||||
raise TypeError("The given class is not a valid dataclass")
|
||||
if not issubclass(instance_dataclass, self._base):
|
||||
raise TypeError(f"The dataclass type provided ({instance_dataclass}) must be a subclass of the base ({self._base})")
|
||||
self._instance_type = instance_dataclass
|
||||
|
||||
self._data.clear()
|
||||
for field in dataclasses.fields(instance_dataclass):
|
||||
if dataclasses.is_dataclass(field.type):
|
||||
self._data[field.name] = self._state_builder.visit(field)
|
||||
self._data[field.name].dcs_observe(self._fire_observers)
|
||||
|
||||
def _instance_state(self):
|
||||
# Note: this method should not validate that the args are complete or
|
||||
# overspecified.
|
||||
kwargs = {}
|
||||
for name, data in self._data.items():
|
||||
if isinstance(data, DataclassState):
|
||||
data = data._instance_state()
|
||||
kwargs[name] = data
|
||||
return kwargs
|
||||
|
||||
def _instance_kwargs(self):
|
||||
# Note: this method should not validate that the args are complete or
|
||||
# overspecified.
|
||||
kwargs = {}
|
||||
for name, data in self._data.items():
|
||||
if isinstance(data, DataclassState):
|
||||
data = data.dcs_instance()
|
||||
kwargs[name] = data
|
||||
return kwargs
|
||||
|
||||
def dcs_instance(self):
|
||||
if self._instance is None:
|
||||
# TODO: Check if we are able to create an instance with our data...
|
||||
self._instance = self._instance_type(**self._instance_kwargs())
|
||||
return self._instance
|
||||
|
||||
|
||||
class DataclassStatePredefined(DataclassState):
|
||||
"""
|
||||
Only a pre-defined selection of states for the given type are allowed.
|
||||
Selected by name (the keys in the dictionary).
|
||||
|
||||
You can change the chosen state with:
|
||||
|
||||
state.dcs_select(name)
|
||||
|
||||
"""
|
||||
def __init__(self, dataclass: Datamodel_T, choices: typing.Dict[typing.Hashable, dataclass_instance]):
|
||||
super().__init__(dataclass=dataclass)
|
||||
|
||||
with self._object_setattr():
|
||||
self._choices = choices
|
||||
self._selected = None
|
||||
# Pick the first choice until we know otherwise.
|
||||
self.dcs_select(list(choices.keys())[0])
|
||||
|
||||
def dcs_select(self, name: typing.Hashable):
|
||||
if name not in self._choices:
|
||||
raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._choices)}')
|
||||
self._selected = name
|
||||
self._instance = self._choices[name]
|
||||
|
||||
def dcs_instance(self):
|
||||
return self._choices[self._selected]
|
||||
|
||||
def __repr__(self):
|
||||
return f"<state for {self._instance_type.__name__}. '{self._selected}' selected>"
|
||||
|
||||
def _instance_kwargs(self):
|
||||
raise NotImplementedError("Doesn't make much sense")
|
||||
|
||||
def _instance_state(self):
|
||||
return dataclasses.asdict(self.dcs_instance())
|
||||
|
||||
def _instance_kwargs(self):
|
||||
return dataclasses.asdict(self.dcs_instance())
|
||||
9
cara/tests/test_apps.py
Normal file
9
cara/tests/test_apps.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import cara.apps
|
||||
|
||||
|
||||
def test_app():
|
||||
# To start with, let's just test that the application runs. We don't try to
|
||||
# do anything fancy to verify how it looks etc., we leave that for manual
|
||||
# testing.
|
||||
expert_app = cara.apps.ExpertApplication()
|
||||
assert expert_app.model_state.room.volume == 75
|
||||
161
cara/tests/test_state.py
Normal file
161
cara/tests/test_state.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
from dataclasses import dataclass
|
||||
import typing
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from cara import state
|
||||
|
||||
|
||||
@dataclass
|
||||
class DCSimple:
|
||||
attr1: str
|
||||
attr2: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class DCSimpleSubclass(DCSimple):
|
||||
attr3: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class DCOverrideSubclass(DCSimple):
|
||||
attr1: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class DCClassVar(DCSimple):
|
||||
a_class_var: typing.ClassVar[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DCRecursive(DCSimple):
|
||||
simple: DCSimple
|
||||
|
||||
|
||||
@dataclass
|
||||
class DCNested:
|
||||
simple: DCSimple
|
||||
others: typing.List[DCSimple]
|
||||
vanilla: str = 'Default'
|
||||
|
||||
|
||||
@dataclass
|
||||
class DCNestedDeep:
|
||||
child: DCNested
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dc_simple():
|
||||
return DCSimple
|
||||
|
||||
|
||||
def test_DCS_construct():
|
||||
s = state.DataclassState(DCSimple)
|
||||
assert repr(s) == '<state for DCSimple(**{})>'
|
||||
|
||||
with pytest.raises(TypeError, match=r"A dataclass type must be provided, not an instance of one"):
|
||||
state.DataclassState(DCSimple('', 1))
|
||||
|
||||
with pytest.raises(TypeError, match="The given class is not a valid dataclass"):
|
||||
state.DataclassState(None)
|
||||
|
||||
|
||||
def test_DCS_construct_nested():
|
||||
s = state.DataclassState(DCNested)
|
||||
assert repr(s) == "<state for DCNested(**{'simple': {}})>"
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_DCS_subclass():
|
||||
s = state.DataclassState(DCSimple)
|
||||
s.dcs_set_instance_type(DCSimpleSubclass)
|
||||
s.set('attr3', 3.14)
|
||||
assert s._instance_kwargs() == {'attr3': 3.14}
|
||||
s.dcs_set_instance_type(DCSimple)
|
||||
# TODO: Make this fail.
|
||||
assert s._instance_kwargs() == {}
|
||||
|
||||
|
||||
def test_DCS_setattr():
|
||||
s = state.DataclassState(DCSimple)
|
||||
s.attr1 = 'Hello world'
|
||||
assert s._instance_kwargs() == {'attr1': 'Hello world'}
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_DCS_type_check():
|
||||
s = state.DataclassState(DCSimple)
|
||||
with pytest.raises(TypeError):
|
||||
# TODO: Should we make this fail? It involves type-checking / validation.
|
||||
s.attr1 = 1
|
||||
|
||||
|
||||
def test_DCS_update_from_instance():
|
||||
s = state.DataclassState(DCSimple)
|
||||
s.dcs_update_from(DCSimple('a1', 2))
|
||||
assert s._instance_type == DCSimple
|
||||
assert s._instance_kwargs() == {'attr1': 'a1', 'attr2': 2}
|
||||
|
||||
|
||||
def test_DCS_update_from_instance_subclass():
|
||||
s = state.DataclassState(DCSimple)
|
||||
s.dcs_update_from(DCSimpleSubclass('a1', 2, 3.14))
|
||||
assert s._instance_type == DCSimpleSubclass
|
||||
assert s._instance_kwargs() == {'attr1': 'a1', 'attr2': 2, 'attr3': 3.14}
|
||||
|
||||
|
||||
def test_DCS_update_from_instance_nested():
|
||||
s = state.DataclassState(DCNested)
|
||||
nested = DCNested(DCSimpleSubclass('a1', 2, 3.14), [])
|
||||
s.dcs_update_from(nested)
|
||||
assert s.simple.dcs_instance() == nested.simple
|
||||
assert s.dcs_instance() == nested
|
||||
|
||||
|
||||
def test_observe_instance_nested():
|
||||
top_level = Mock()
|
||||
nested = Mock()
|
||||
|
||||
s = state.DataclassState(DCNested)
|
||||
|
||||
s.dcs_observe(top_level)
|
||||
s.simple.dcs_observe(nested)
|
||||
|
||||
s.simple.attr1 = 'something new'
|
||||
top_level.assert_called_with()
|
||||
nested.assert_called_with()
|
||||
|
||||
top_level.reset_mock()
|
||||
nested.reset_mock()
|
||||
s.vanilla = 'something new'
|
||||
top_level.assert_called_with()
|
||||
nested.assert_not_called()
|
||||
|
||||
|
||||
def test_DCS_predefined():
|
||||
opt1 = DCSimple('a', 1)
|
||||
opt2 = DCSimpleSubclass('b', 2, 3.14)
|
||||
s = state.DataclassStatePredefined(
|
||||
DCSimple, {'option 1': opt1, 'option 2': opt2}
|
||||
)
|
||||
assert s._selected == 'option 1'
|
||||
# TODO: This should fail.
|
||||
s.attr1 = 'can I set it?'
|
||||
assert s.dcs_instance() == opt1
|
||||
|
||||
s.dcs_select('option 2')
|
||||
assert s.dcs_instance() == opt2
|
||||
|
||||
assert repr(s) == "<state for DCSimple. 'option 2' selected>"
|
||||
|
||||
# TODO: This should fail too.
|
||||
s.dcs_update_from(opt1)
|
||||
assert s.dcs_instance() == opt2
|
||||
|
||||
|
||||
def test_DCS_non_dataclass_attrs():
|
||||
val = DCClassVar('a', 1)
|
||||
s = state.DataclassState(DCSimple)
|
||||
s.dcs_update_from(val)
|
||||
s.dcs_instance() == val
|
||||
6
setup.py
6
setup.py
|
|
@ -17,17 +17,17 @@ with (HERE / 'README.md').open('rt') as fh:
|
|||
REQUIREMENTS: dict = {
|
||||
'core': [
|
||||
'dataclasses; python_version < "3.7"',
|
||||
'numpy',
|
||||
],
|
||||
'app': [
|
||||
'ipykernel',
|
||||
'ipympl',
|
||||
'ipywidgets',
|
||||
'matplotlib',
|
||||
'numpy',
|
||||
'voila >=0.2.4',
|
||||
],
|
||||
'app': [],
|
||||
'test': [
|
||||
'pytest',
|
||||
'pytest-tornasync', # Unused, but needed because of a downstream dependency.
|
||||
],
|
||||
'dev': [
|
||||
'jupyterlab',
|
||||
|
|
|
|||
Loading…
Reference in a new issue