Merge branch 'comparison-tab' into 'master'
Expert app scenario comparisons See merge request cara/cara!89
This commit is contained in:
commit
d05708ea37
2 changed files with 303 additions and 112 deletions
|
|
@ -13,7 +13,7 @@ from cara import state
|
|||
from cara import data
|
||||
|
||||
|
||||
def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=True):
|
||||
def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=False):
|
||||
collapsed = widgets.Accordion([widgets.VBox(widgets_to_collapse)])
|
||||
collapsed.set_title(0, title)
|
||||
if start_collapsed:
|
||||
|
|
@ -28,13 +28,66 @@ def widget_group(label_widget_pairs):
|
|||
return widgets.HBox([labels_w, widgets_w])
|
||||
|
||||
|
||||
class ConcentrationFigure:
|
||||
#: A scenario is a name and a (mutable) model.
|
||||
ScenarioType = typing.Tuple[str, state.DataclassState]
|
||||
|
||||
|
||||
class View:
|
||||
"""
|
||||
A thing which exposes a ``.widget`` attribute which is a view on some
|
||||
data. This view is essentially a complex combination of widgets, along with
|
||||
some event handling capabilities, which may or may not be sent back up to
|
||||
the underlying controller.
|
||||
|
||||
We strive hard to keep "Model" data out of the View (and try to avoid
|
||||
storing it at all on the View itself), instead relying on being able
|
||||
to notify, and receive notifications, of important events from the Controller.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Controller:
|
||||
"""
|
||||
The singleton thing which is the top-level Application.
|
||||
|
||||
It is responsible for owning the Model data and the Views, and
|
||||
orchestrating event messages to each if the Model/View change.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def ipympl_canvas(figure):
|
||||
matplotlib.interactive(False)
|
||||
ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure)
|
||||
figure.canvas.toolbar_visible = True
|
||||
figure.canvas.toolbar.collapsed = True
|
||||
figure.canvas.footer_visible = False
|
||||
figure.canvas.header_visible = False
|
||||
return figure.canvas
|
||||
|
||||
|
||||
class ExposureModelResult(View):
|
||||
def __init__(self):
|
||||
self.figure = matplotlib.figure.Figure(figsize=(9, 6))
|
||||
ipympl_canvas(self.figure)
|
||||
self.html_output = widgets.HTML()
|
||||
self.ax = self.figure.add_subplot(1, 1, 1)
|
||||
self.line = None
|
||||
|
||||
def update(self, model: models.ConcentrationModel):
|
||||
@property
|
||||
def widget(self):
|
||||
return widgets.VBox([
|
||||
self.html_output,
|
||||
self.figure.canvas,
|
||||
])
|
||||
|
||||
def update(self, model: models.ExposureModel):
|
||||
self.update_plot(model.concentration_model)
|
||||
self.update_textual_result(model)
|
||||
|
||||
def update_plot(self, model: models.ConcentrationModel):
|
||||
resolution = 600
|
||||
ts = np.linspace(sorted(model.infected.presence.transition_times())[0],
|
||||
sorted(model.infected.presence.transition_times())[-1], resolution)
|
||||
|
|
@ -60,70 +113,75 @@ class ConcentrationFigure:
|
|||
self.ax.set_ylim(bottom=0., top=top)
|
||||
self.figure.canvas.draw()
|
||||
|
||||
def update_textual_result(self, model: models.ExposureModel):
|
||||
lines = []
|
||||
P = model.infection_probability()
|
||||
lines.append(f'Emission rate (quanta/hr): {model.concentration_model.infected.emission_rate_when_present()}')
|
||||
lines.append(f'Probability of infection: {np.round(P, 0)}%')
|
||||
|
||||
def ipympl_canvas(figure: matplotlib.figure.Figure):
|
||||
# Make a plain matplotlib figure render as a Jupyter widget.
|
||||
matplotlib.interactive(False)
|
||||
ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure)
|
||||
figure.canvas.toolbar_visible = True
|
||||
figure.canvas.toolbar.collapsed = True
|
||||
figure.canvas.footer_visible = False
|
||||
figure.canvas.header_visible = False
|
||||
lines.append(f'Number of exposed: {model.exposed.number}')
|
||||
|
||||
new_cases = np.round(model.expected_new_cases(), 1)
|
||||
lines.append(f'Number of expected new cases: {new_cases}')
|
||||
|
||||
R0 = np.round(model.reproduction_number(), 1)
|
||||
lines.append(f'Reproduction number (R0): {R0}')
|
||||
|
||||
self.html_output.value = '<br>\n'.join(lines)
|
||||
|
||||
|
||||
class WidgetView:
|
||||
class ExposureComparissonResult(View):
|
||||
def __init__(self):
|
||||
self.figure = matplotlib.figure.Figure(figsize=(9, 6))
|
||||
ipympl_canvas(self.figure)
|
||||
self.ax = self.initialize_axes()
|
||||
|
||||
@property
|
||||
def widget(self):
|
||||
# Workaround to a bug with ipymlp, which doesn't work well with tabs
|
||||
# unless the widget is wrapped in a container (it is seen on all tabs otherwise!).
|
||||
return widgets.HBox([self.figure.canvas])
|
||||
|
||||
def initialize_axes(self) -> matplotlib.figure.Axes:
|
||||
ax = self.figure.add_subplot(1, 1, 1)
|
||||
ax.spines['right'].set_visible(False)
|
||||
ax.spines['top'].set_visible(False)
|
||||
ax.set_xlabel('Time (hours)')
|
||||
ax.set_ylabel('Concentration ($q/m^3$)')
|
||||
ax.set_title('Concentration of infectious quanta aerosols')
|
||||
return ax
|
||||
|
||||
def scenarios_updated(self, scenarios: typing.Sequence[ScenarioType], _):
|
||||
labels, models = zip(*scenarios)
|
||||
conc_models: typing.Tuple[models.ConcentrationModel] = tuple(
|
||||
model.concentration_model.dcs_instance() for model in models
|
||||
)
|
||||
self.update_plot(conc_models, labels)
|
||||
|
||||
def update_plot(self, conc_models: typing.Tuple[models.ConcentrationModel], labels: typing.Tuple[str]):
|
||||
self.ax.lines.clear()
|
||||
start, finish = models_start_end(conc_models)
|
||||
ts = np.linspace(start, finish, num=250)
|
||||
concentrations = [[conc_model.concentration(t) for t in ts] for conc_model in conc_models]
|
||||
for label, concentration in zip(labels, concentrations):
|
||||
self.ax.plot(ts, concentration, label=label)
|
||||
|
||||
top = max(3., max([max(conc) for conc in concentrations]))
|
||||
self.ax.set_ylim(bottom=0., top=top)
|
||||
|
||||
self.ax.legend()
|
||||
self.figure.canvas.draw()
|
||||
|
||||
|
||||
class ModelWidgets(View):
|
||||
def __init__(self, model_state: state.DataclassState):
|
||||
self.model_state = model_state
|
||||
self.model_state.dcs_observe(self.update)
|
||||
#: The widgets that this view produces (inputs and outputs together)
|
||||
self.widget = widgets.VBox([])
|
||||
self.widgets = {}
|
||||
self.out = widgets.Output()
|
||||
self.plots = []
|
||||
self.construct_widgets()
|
||||
# Trigger the first result.
|
||||
self.update()
|
||||
self.construct_widgets(model_state)
|
||||
|
||||
def construct_widgets(self):
|
||||
def construct_widgets(self, model_state: state.DataclassState):
|
||||
# Build the input widgets.
|
||||
self._build_widget(self.model_state)
|
||||
|
||||
# And the output widget figure.
|
||||
concentration = ConcentrationFigure()
|
||||
self.plots.append(concentration)
|
||||
ipympl_canvas(concentration.figure)
|
||||
|
||||
self.widgets['results'] = collapsible([
|
||||
widgets.HBox([
|
||||
concentration.figure.canvas,
|
||||
self.out,
|
||||
])
|
||||
], 'Results', start_collapsed=False)
|
||||
|
||||
# Join inputs and outputs together in a single widget for convenience.
|
||||
self.widget.children += (self.widgets['results'], )
|
||||
|
||||
def prepare_output(self):
|
||||
pass
|
||||
|
||||
def update(self):
|
||||
model: models.ExposureModel = self.model_state.dcs_instance()
|
||||
for plot in self.plots:
|
||||
plot.update(model.concentration_model)
|
||||
|
||||
self.out.clear_output()
|
||||
with self.out:
|
||||
P = model.infection_probability()
|
||||
print(f'Emission rate (quanta/hr): {model.concentration_model.infected.emission_rate_when_present()}')
|
||||
print(f'Probability of infection: {np.round(P, 0)}%')
|
||||
|
||||
print(f'Number of exposed: {model.exposed.number}')
|
||||
|
||||
new_cases = np.round(model.expected_new_cases(), 1)
|
||||
print(f'Number of expected new cases: {new_cases}')
|
||||
|
||||
R0 = np.round(model.reproduction_number(), 1)
|
||||
print(f'Reproduction number (R0): {R0}')
|
||||
self._build_widget(model_state)
|
||||
|
||||
def _build_widget(self, node):
|
||||
self.widget.children += (self._build_room(node.concentration_model.room),)
|
||||
|
|
@ -160,7 +218,7 @@ class WidgetView:
|
|||
[widget_group(
|
||||
[[widgets.Label('Room volume'), room_volume]]
|
||||
)],
|
||||
title='Specification of workplace', start_collapsed=False,
|
||||
title='Specification of workplace',
|
||||
)
|
||||
return widget
|
||||
|
||||
|
|
@ -347,7 +405,7 @@ class WidgetView:
|
|||
w = collapsible(
|
||||
[widget_group([[widgets.Label('Ventilation type'), ventilation_w]])]
|
||||
+ list(ventilation_widgets.values()),
|
||||
title='Ventilation scheme'
|
||||
title='Ventilation scheme',
|
||||
)
|
||||
return w
|
||||
|
||||
|
|
@ -404,75 +462,195 @@ class CARAStateBuilder(state.StateBuilder):
|
|||
return s
|
||||
|
||||
|
||||
class ExpertApplication:
|
||||
class ExpertApplication(Controller):
|
||||
def __init__(self):
|
||||
default_scenario = state.DataclassInstanceState(
|
||||
models.ExposureModel,
|
||||
state_builder=CARAStateBuilder(),
|
||||
)
|
||||
default_scenario.dcs_update_from(baseline_model)
|
||||
self._debug_output = widgets.Output()
|
||||
|
||||
#: A list of scenario name and ModelState instances. This is intended to be
|
||||
#: mutated. Any mutation should notify the appropriate Views for handling.
|
||||
self._model_scenarios: typing.List[ScenarioType] = []
|
||||
self._active_scenario = 0
|
||||
self.multi_model_view = MultiModelView(self)
|
||||
self.comparison_view = ExposureComparissonResult()
|
||||
self.current_scenario_figure = ExposureModelResult()
|
||||
self._results_tab = widgets.Tab(children=(
|
||||
self.current_scenario_figure.widget,
|
||||
self.comparison_view.widget,
|
||||
# self._debug_output,
|
||||
))
|
||||
for i, title in enumerate(['Current scenario', 'Scenario comparison', "Debug"]):
|
||||
self._results_tab.set_title(i, title)
|
||||
self.widget = widgets.HBox(
|
||||
children=(
|
||||
self.multi_model_view.widget,
|
||||
self._results_tab,
|
||||
),
|
||||
)
|
||||
self.add_scenario('Scenario 1')
|
||||
|
||||
def build_new_model(self):
|
||||
default_model = state.DataclassInstanceState(
|
||||
models.ExposureModel,
|
||||
state_builder=CARAStateBuilder(),
|
||||
)
|
||||
default_model.dcs_update_from(baseline_model)
|
||||
# For the time-being, we have to initialise the select states. Careful
|
||||
# as values might not correspond to what the baseline model says.
|
||||
default_scenario.concentration_model.infected.mask.dcs_select('No mask')
|
||||
self.scenarios = (default_scenario,)
|
||||
self.scenario_names = ('Scenario 1',)
|
||||
self.views = (WidgetView(default_scenario),)
|
||||
self.selected_tab = 0
|
||||
self.tabs = (widgets.VBox(children=(self.build_settings_menu(0), self.views[0].present())),)
|
||||
self.tab_widget = widgets.Tab(children=self.tabs)
|
||||
self.display_titles()
|
||||
default_model.concentration_model.infected.mask.dcs_select('No mask')
|
||||
return default_model
|
||||
|
||||
def display_titles(self):
|
||||
for i, name in enumerate(self.scenario_names):
|
||||
self.tab_widget.set_title(i, name)
|
||||
def add_scenario(self, name, copy_from_model: typing.Optional[state.DataclassInstanceState] = None):
|
||||
model = self.build_new_model()
|
||||
if copy_from_model is not None:
|
||||
model.dcs_update_from(copy_from_model.dcs_instance())
|
||||
self._model_scenarios.append((name, model))
|
||||
self._active_scenario = len(self._model_scenarios) - 1
|
||||
model.dcs_observe(self.notify_model_values_changed)
|
||||
self.notify_scenarios_changed()
|
||||
|
||||
def _find_model_id(self, model_id):
|
||||
for index, (name, model) in enumerate(list(self._model_scenarios)):
|
||||
if id(model) == model_id:
|
||||
return index, name, model
|
||||
else:
|
||||
raise ValueError("Model not found")
|
||||
|
||||
def rename_scenario(self, model_id, new_name):
|
||||
index, _, model = self._find_model_id(model_id)
|
||||
self._model_scenarios[index] = (new_name, model)
|
||||
self.notify_scenarios_changed()
|
||||
|
||||
def remove_scenario(self, model_id):
|
||||
index, _, model = self._find_model_id(model_id)
|
||||
self._model_scenarios.pop(index)
|
||||
if self._active_scenario >= index:
|
||||
self._active_scenario = max(self._active_scenario - 1, 0)
|
||||
self.notify_scenarios_changed()
|
||||
|
||||
def set_active_scenario(self, model_id):
|
||||
index, _, model = self._find_model_id(model_id)
|
||||
self._active_scenario = index
|
||||
self.notify_scenarios_changed()
|
||||
self.notify_model_values_changed()
|
||||
|
||||
def notify_scenarios_changed(self):
|
||||
"""
|
||||
Occurs when the set of scenarios has been modified, but not if the values of the scenario has changed.
|
||||
|
||||
"""
|
||||
self.multi_model_view.scenarios_updated(self._model_scenarios, self._active_scenario)
|
||||
self.comparison_view.scenarios_updated(self._model_scenarios, self._active_scenario)
|
||||
|
||||
def notify_model_values_changed(self):
|
||||
"""
|
||||
Occurs when *any* value in *any* of the scenarios has been modified.
|
||||
"""
|
||||
self.comparison_view.scenarios_updated(self._model_scenarios, self._active_scenario)
|
||||
self.current_scenario_figure.update(self._model_scenarios[self._active_scenario][1].dcs_instance())
|
||||
|
||||
|
||||
class MultiModelView(View):
|
||||
def __init__(self, controller: ExpertApplication):
|
||||
self._controller = controller
|
||||
self.widget = widgets.Tab()
|
||||
self.widget.observe(self._on_tab_change, 'selected_index')
|
||||
self._tab_model_ids: typing.List[int] = []
|
||||
self._tab_widgets: typing.List[widgets.Widget] = []
|
||||
self._tab_model_views: typing.List[ModelWidgets] = []
|
||||
|
||||
def scenarios_updated(
|
||||
self,
|
||||
model_scenarios: typing.Sequence[ScenarioType],
|
||||
active_scenario_index: int
|
||||
):
|
||||
"""
|
||||
Called when a scenario is added/removed/renamed etc.
|
||||
|
||||
Note: Not called when the model state is modified.
|
||||
|
||||
"""
|
||||
model_scenario_ids = []
|
||||
for i, (scenario_name, model) in enumerate(model_scenarios):
|
||||
if id(model) not in self._tab_model_ids:
|
||||
self.add_tab(scenario_name, model)
|
||||
model_scenario_ids.append(id(model))
|
||||
tab_index = self._tab_model_ids.index(id(model))
|
||||
self.widget.set_title(tab_index, scenario_name)
|
||||
|
||||
# Any remaining model_scenario_ids are no longer needed, so remove
|
||||
# their tabs.
|
||||
for tab_index, tab_scenario_id in enumerate(self._tab_model_ids[:]):
|
||||
if tab_scenario_id not in model_scenario_ids:
|
||||
self.remove_tab(tab_index)
|
||||
|
||||
assert self._tab_model_ids == model_scenario_ids
|
||||
|
||||
self.widget.selected_index = active_scenario_index
|
||||
|
||||
def add_tab(self, name, model):
|
||||
self._tab_model_views.append(ModelWidgets(model))
|
||||
self._tab_model_ids.append(id(model))
|
||||
tab_idx = len(self._tab_model_ids) - 1
|
||||
tab_widget = widgets.VBox(
|
||||
children=(
|
||||
self._build_settings_menu(name, model),
|
||||
self._tab_model_views[tab_idx].widget,
|
||||
)
|
||||
)
|
||||
self._tab_widgets.append(tab_widget)
|
||||
self.update_tab_widget()
|
||||
|
||||
def remove_tab(self, tab_index):
|
||||
assert 0 <= tab_index < len(self._tab_model_ids)
|
||||
assert len(self._tab_model_ids) > 1
|
||||
self._tab_model_ids.pop(tab_index)
|
||||
self._tab_widgets.pop(tab_index)
|
||||
self._tab_model_views.pop(tab_index)
|
||||
if self._active_tab_index >= tab_index:
|
||||
self._active_tab_index = max(0, self._active_tab_index - 1)
|
||||
self.update_tab_widget()
|
||||
|
||||
def update_tab_widget(self):
|
||||
self.tab_widget.children = self.tabs
|
||||
self.display_titles()
|
||||
self.widget.children = tuple(self._tab_widgets)
|
||||
|
||||
def build_settings_menu(self, tab_index):
|
||||
def _on_tab_change(self, change):
|
||||
self._controller.set_active_scenario(
|
||||
self._tab_model_ids[change['new']]
|
||||
)
|
||||
|
||||
def _build_settings_menu(self, name, model):
|
||||
delete_button = widgets.Button(description='Delete Scenario', button_style='danger')
|
||||
rename_text_field = widgets.Text(description='Rename Scenario:', value=self.scenario_names[tab_index],
|
||||
rename_text_field = widgets.Text(description='Rename Scenario:', value=name,
|
||||
style={'description_width': 'auto'})
|
||||
duplicate_button = widgets.Button(description='Duplicate Scenario', button_style='success')
|
||||
model_id = id(model)
|
||||
|
||||
def on_delete_click(b):
|
||||
self.scenario_names = tuple_without_index(self.scenario_names, tab_index)
|
||||
self.scenarios = tuple_without_index(self.scenarios, tab_index)
|
||||
self.views = tuple_without_index(self.views, tab_index)
|
||||
self.selected_tab = min(0, self.selected_tab - 1)
|
||||
self.tabs = tuple(widgets.VBox(children=(self.build_settings_menu(i), view.present()))
|
||||
for i, view in enumerate(self.views))
|
||||
self.update_tab_widget()
|
||||
self._controller.remove_scenario(model_id)
|
||||
|
||||
def on_rename_text_field(change):
|
||||
self.scenario_names = tuple(change['new'] if i == tab_index else value
|
||||
for i, value in enumerate(self.scenario_names))
|
||||
self.update_tab_widget()
|
||||
self._controller.rename_scenario(model_id, new_name=change['new'])
|
||||
|
||||
def on_duplicate_click(b):
|
||||
self.scenario_names += (self.scenario_names[tab_index] + " (copy)",)
|
||||
new_scenario = state.DataclassInstanceState(
|
||||
models.ExposureModel,
|
||||
state_builder=CARAStateBuilder(),
|
||||
)
|
||||
new_scenario.dcs_update_from(self.scenarios[tab_index].dcs_instance())
|
||||
self.scenarios += (new_scenario,)
|
||||
|
||||
self.views += (WidgetView(new_scenario),)
|
||||
self.tabs += (widgets.VBox(children=(self.build_settings_menu(len(self.scenario_names) - 1), self.views[-1].present())),)
|
||||
self.update_tab_widget()
|
||||
tab_index = self._tab_model_ids.index(model_id)
|
||||
name = self.widget.get_title(tab_index)
|
||||
self._controller.add_scenario(f'{name} (copy)', model)
|
||||
|
||||
delete_button.on_click(on_delete_click)
|
||||
duplicate_button.on_click(on_duplicate_click)
|
||||
rename_text_field.observe(on_rename_text_field, 'value')
|
||||
buttons = duplicate_button if tab_index == 0 else widgets.HBox(children=(duplicate_button, delete_button))
|
||||
# TODO: This should be dynamic - we don't want to be able to delete the
|
||||
# last scenario, so this should be controlled in the remove_tab method.
|
||||
buttons_w_delete = widgets.HBox(children=(duplicate_button, delete_button))
|
||||
buttons = duplicate_button if len(self._tab_model_ids) < 2 else buttons_w_delete
|
||||
return widgets.VBox(children=(buttons, rename_text_field))
|
||||
|
||||
@property
|
||||
def widget(self):
|
||||
return self.tab_widget
|
||||
|
||||
def models_start_end(models: typing.Sequence[models.ConcentrationModel]) -> typing.Tuple[float, float]:
|
||||
"""
|
||||
Returns the earliest start and latest end time of a collection of ConcentrationModel objects
|
||||
|
||||
def tuple_without_index(t: typing.Tuple, index: int) -> typing.Tuple:
|
||||
return t[:index] + t[index + 1:]
|
||||
"""
|
||||
infected_start = min(model.infected.presence.boundaries()[0][0] for model in models)
|
||||
infected_finish = min(model.infected.presence.boundaries()[-1][1] for model in models)
|
||||
return infected_start, infected_finish
|
||||
|
|
|
|||
|
|
@ -1,9 +1,22 @@
|
|||
import pytest
|
||||
|
||||
import cara.apps
|
||||
|
||||
|
||||
def test_app():
|
||||
@pytest.fixture
|
||||
def expert_app():
|
||||
return cara.apps.ExpertApplication()
|
||||
|
||||
|
||||
def test_app(expert_app):
|
||||
# To start with, let's just test that the application runs. We don't try to
|
||||
# do anything fancy to verify how it looks etc., we leave that for manual
|
||||
# testing.
|
||||
expert_app = cara.apps.ExpertApplication()
|
||||
assert expert_app.scenario_names[0] == "Scenario 1"
|
||||
assert expert_app._model_scenarios[0][0] == "Scenario 1"
|
||||
|
||||
|
||||
def test_new_scenario_changes_tab(expert_app):
|
||||
# Adding a new scenario should change the tab index of the multi-model view.
|
||||
assert expert_app.multi_model_view.widget.selected_index == 0
|
||||
expert_app.add_scenario("Another scenario")
|
||||
assert expert_app.multi_model_view.widget.selected_index == 1
|
||||
|
|
|
|||
Loading…
Reference in a new issue