diff --git a/cara/apps/expert.py b/cara/apps/expert.py index 410e9127..9dc9289d 100644 --- a/cara/apps/expert.py +++ b/cara/apps/expert.py @@ -13,7 +13,7 @@ from cara import state from cara import data -def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=True): +def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=False): collapsed = widgets.Accordion([widgets.VBox(widgets_to_collapse)]) collapsed.set_title(0, title) if start_collapsed: @@ -28,13 +28,66 @@ def widget_group(label_widget_pairs): return widgets.HBox([labels_w, widgets_w]) -class ConcentrationFigure: +#: A scenario is a name and a (mutable) model. +ScenarioType = typing.Tuple[str, state.DataclassState] + + +class View: + """ + A thing which exposes a ``.widget`` attribute which is a view on some + data. This view is essentially a complex combination of widgets, along with + some event handling capabilities, which may or may not be sent back up to + the underlying controller. + + We strive hard to keep "Model" data out of the View (and try to avoid + storing it at all on the View itself), instead relying on being able + to notify, and receive notifications, of important events from the Controller. + + """ + pass + + +class Controller: + """ + The singleton thing which is the top-level Application. + + It is responsible for owning the Model data and the Views, and + orchestrating event messages to each if the Model/View change. + + """ + pass + + +def ipympl_canvas(figure): + matplotlib.interactive(False) + ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure) + figure.canvas.toolbar_visible = True + figure.canvas.toolbar.collapsed = True + figure.canvas.footer_visible = False + figure.canvas.header_visible = False + return figure.canvas + + +class ExposureModelResult(View): def __init__(self): self.figure = matplotlib.figure.Figure(figsize=(9, 6)) + ipympl_canvas(self.figure) + self.html_output = widgets.HTML() self.ax = self.figure.add_subplot(1, 1, 1) self.line = None - def update(self, model: models.ConcentrationModel): + @property + def widget(self): + return widgets.VBox([ + self.html_output, + self.figure.canvas, + ]) + + def update(self, model: models.ExposureModel): + self.update_plot(model.concentration_model) + self.update_textual_result(model) + + def update_plot(self, model: models.ConcentrationModel): resolution = 600 ts = np.linspace(sorted(model.infected.presence.transition_times())[0], sorted(model.infected.presence.transition_times())[-1], resolution) @@ -60,70 +113,75 @@ class ConcentrationFigure: self.ax.set_ylim(bottom=0., top=top) self.figure.canvas.draw() + def update_textual_result(self, model: models.ExposureModel): + lines = [] + P = model.infection_probability() + lines.append(f'Emission rate (quanta/hr): {model.concentration_model.infected.emission_rate_when_present()}') + lines.append(f'Probability of infection: {np.round(P, 0)}%') -def ipympl_canvas(figure: matplotlib.figure.Figure): - # Make a plain matplotlib figure render as a Jupyter widget. - matplotlib.interactive(False) - ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure) - figure.canvas.toolbar_visible = True - figure.canvas.toolbar.collapsed = True - figure.canvas.footer_visible = False - figure.canvas.header_visible = False + lines.append(f'Number of exposed: {model.exposed.number}') + + new_cases = np.round(model.expected_new_cases(), 1) + lines.append(f'Number of expected new cases: {new_cases}') + + R0 = np.round(model.reproduction_number(), 1) + lines.append(f'Reproduction number (R0): {R0}') + + self.html_output.value = '
\n'.join(lines) -class WidgetView: +class ExposureComparissonResult(View): + def __init__(self): + self.figure = matplotlib.figure.Figure(figsize=(9, 6)) + ipympl_canvas(self.figure) + self.ax = self.initialize_axes() + + @property + def widget(self): + # Workaround to a bug with ipymlp, which doesn't work well with tabs + # unless the widget is wrapped in a container (it is seen on all tabs otherwise!). + return widgets.HBox([self.figure.canvas]) + + def initialize_axes(self) -> matplotlib.figure.Axes: + ax = self.figure.add_subplot(1, 1, 1) + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + ax.set_xlabel('Time (hours)') + ax.set_ylabel('Concentration ($q/m^3$)') + ax.set_title('Concentration of infectious quanta aerosols') + return ax + + def scenarios_updated(self, scenarios: typing.Sequence[ScenarioType], _): + labels, models = zip(*scenarios) + conc_models: typing.Tuple[models.ConcentrationModel] = tuple( + model.concentration_model.dcs_instance() for model in models + ) + self.update_plot(conc_models, labels) + + def update_plot(self, conc_models: typing.Tuple[models.ConcentrationModel], labels: typing.Tuple[str]): + self.ax.lines.clear() + start, finish = models_start_end(conc_models) + ts = np.linspace(start, finish, num=250) + concentrations = [[conc_model.concentration(t) for t in ts] for conc_model in conc_models] + for label, concentration in zip(labels, concentrations): + self.ax.plot(ts, concentration, label=label) + + top = max(3., max([max(conc) for conc in concentrations])) + self.ax.set_ylim(bottom=0., top=top) + + self.ax.legend() + self.figure.canvas.draw() + + +class ModelWidgets(View): def __init__(self, model_state: state.DataclassState): - self.model_state = model_state - self.model_state.dcs_observe(self.update) #: The widgets that this view produces (inputs and outputs together) self.widget = widgets.VBox([]) - self.widgets = {} - self.out = widgets.Output() - self.plots = [] - self.construct_widgets() - # Trigger the first result. - self.update() + self.construct_widgets(model_state) - def construct_widgets(self): + def construct_widgets(self, model_state: state.DataclassState): # Build the input widgets. - self._build_widget(self.model_state) - - # And the output widget figure. - concentration = ConcentrationFigure() - self.plots.append(concentration) - ipympl_canvas(concentration.figure) - - self.widgets['results'] = collapsible([ - widgets.HBox([ - concentration.figure.canvas, - self.out, - ]) - ], 'Results', start_collapsed=False) - - # Join inputs and outputs together in a single widget for convenience. - self.widget.children += (self.widgets['results'], ) - - def prepare_output(self): - pass - - def update(self): - model: models.ExposureModel = self.model_state.dcs_instance() - for plot in self.plots: - plot.update(model.concentration_model) - - self.out.clear_output() - with self.out: - P = model.infection_probability() - print(f'Emission rate (quanta/hr): {model.concentration_model.infected.emission_rate_when_present()}') - print(f'Probability of infection: {np.round(P, 0)}%') - - print(f'Number of exposed: {model.exposed.number}') - - new_cases = np.round(model.expected_new_cases(), 1) - print(f'Number of expected new cases: {new_cases}') - - R0 = np.round(model.reproduction_number(), 1) - print(f'Reproduction number (R0): {R0}') + self._build_widget(model_state) def _build_widget(self, node): self.widget.children += (self._build_room(node.concentration_model.room),) @@ -160,7 +218,7 @@ class WidgetView: [widget_group( [[widgets.Label('Room volume'), room_volume]] )], - title='Specification of workplace', start_collapsed=False, + title='Specification of workplace', ) return widget @@ -347,7 +405,7 @@ class WidgetView: w = collapsible( [widget_group([[widgets.Label('Ventilation type'), ventilation_w]])] + list(ventilation_widgets.values()), - title='Ventilation scheme' + title='Ventilation scheme', ) return w @@ -404,75 +462,195 @@ class CARAStateBuilder(state.StateBuilder): return s -class ExpertApplication: +class ExpertApplication(Controller): def __init__(self): - default_scenario = state.DataclassInstanceState( - models.ExposureModel, - state_builder=CARAStateBuilder(), - ) - default_scenario.dcs_update_from(baseline_model) + self._debug_output = widgets.Output() + + #: A list of scenario name and ModelState instances. This is intended to be + #: mutated. Any mutation should notify the appropriate Views for handling. + self._model_scenarios: typing.List[ScenarioType] = [] + self._active_scenario = 0 + self.multi_model_view = MultiModelView(self) + self.comparison_view = ExposureComparissonResult() + self.current_scenario_figure = ExposureModelResult() + self._results_tab = widgets.Tab(children=( + self.current_scenario_figure.widget, + self.comparison_view.widget, + # self._debug_output, + )) + for i, title in enumerate(['Current scenario', 'Scenario comparison', "Debug"]): + self._results_tab.set_title(i, title) + self.widget = widgets.HBox( + children=( + self.multi_model_view.widget, + self._results_tab, + ), + ) + self.add_scenario('Scenario 1') + + def build_new_model(self): + default_model = state.DataclassInstanceState( + models.ExposureModel, + state_builder=CARAStateBuilder(), + ) + default_model.dcs_update_from(baseline_model) # For the time-being, we have to initialise the select states. Careful # as values might not correspond to what the baseline model says. - default_scenario.concentration_model.infected.mask.dcs_select('No mask') - self.scenarios = (default_scenario,) - self.scenario_names = ('Scenario 1',) - self.views = (WidgetView(default_scenario),) - self.selected_tab = 0 - self.tabs = (widgets.VBox(children=(self.build_settings_menu(0), self.views[0].present())),) - self.tab_widget = widgets.Tab(children=self.tabs) - self.display_titles() + default_model.concentration_model.infected.mask.dcs_select('No mask') + return default_model - def display_titles(self): - for i, name in enumerate(self.scenario_names): - self.tab_widget.set_title(i, name) + def add_scenario(self, name, copy_from_model: typing.Optional[state.DataclassInstanceState] = None): + model = self.build_new_model() + if copy_from_model is not None: + model.dcs_update_from(copy_from_model.dcs_instance()) + self._model_scenarios.append((name, model)) + self._active_scenario = len(self._model_scenarios) - 1 + model.dcs_observe(self.notify_model_values_changed) + self.notify_scenarios_changed() + + def _find_model_id(self, model_id): + for index, (name, model) in enumerate(list(self._model_scenarios)): + if id(model) == model_id: + return index, name, model + else: + raise ValueError("Model not found") + + def rename_scenario(self, model_id, new_name): + index, _, model = self._find_model_id(model_id) + self._model_scenarios[index] = (new_name, model) + self.notify_scenarios_changed() + + def remove_scenario(self, model_id): + index, _, model = self._find_model_id(model_id) + self._model_scenarios.pop(index) + if self._active_scenario >= index: + self._active_scenario = max(self._active_scenario - 1, 0) + self.notify_scenarios_changed() + + def set_active_scenario(self, model_id): + index, _, model = self._find_model_id(model_id) + self._active_scenario = index + self.notify_scenarios_changed() + self.notify_model_values_changed() + + def notify_scenarios_changed(self): + """ + Occurs when the set of scenarios has been modified, but not if the values of the scenario has changed. + + """ + self.multi_model_view.scenarios_updated(self._model_scenarios, self._active_scenario) + self.comparison_view.scenarios_updated(self._model_scenarios, self._active_scenario) + + def notify_model_values_changed(self): + """ + Occurs when *any* value in *any* of the scenarios has been modified. + """ + self.comparison_view.scenarios_updated(self._model_scenarios, self._active_scenario) + self.current_scenario_figure.update(self._model_scenarios[self._active_scenario][1].dcs_instance()) + + +class MultiModelView(View): + def __init__(self, controller: ExpertApplication): + self._controller = controller + self.widget = widgets.Tab() + self.widget.observe(self._on_tab_change, 'selected_index') + self._tab_model_ids: typing.List[int] = [] + self._tab_widgets: typing.List[widgets.Widget] = [] + self._tab_model_views: typing.List[ModelWidgets] = [] + + def scenarios_updated( + self, + model_scenarios: typing.Sequence[ScenarioType], + active_scenario_index: int + ): + """ + Called when a scenario is added/removed/renamed etc. + + Note: Not called when the model state is modified. + + """ + model_scenario_ids = [] + for i, (scenario_name, model) in enumerate(model_scenarios): + if id(model) not in self._tab_model_ids: + self.add_tab(scenario_name, model) + model_scenario_ids.append(id(model)) + tab_index = self._tab_model_ids.index(id(model)) + self.widget.set_title(tab_index, scenario_name) + + # Any remaining model_scenario_ids are no longer needed, so remove + # their tabs. + for tab_index, tab_scenario_id in enumerate(self._tab_model_ids[:]): + if tab_scenario_id not in model_scenario_ids: + self.remove_tab(tab_index) + + assert self._tab_model_ids == model_scenario_ids + + self.widget.selected_index = active_scenario_index + + def add_tab(self, name, model): + self._tab_model_views.append(ModelWidgets(model)) + self._tab_model_ids.append(id(model)) + tab_idx = len(self._tab_model_ids) - 1 + tab_widget = widgets.VBox( + children=( + self._build_settings_menu(name, model), + self._tab_model_views[tab_idx].widget, + ) + ) + self._tab_widgets.append(tab_widget) + self.update_tab_widget() + + def remove_tab(self, tab_index): + assert 0 <= tab_index < len(self._tab_model_ids) + assert len(self._tab_model_ids) > 1 + self._tab_model_ids.pop(tab_index) + self._tab_widgets.pop(tab_index) + self._tab_model_views.pop(tab_index) + if self._active_tab_index >= tab_index: + self._active_tab_index = max(0, self._active_tab_index - 1) + self.update_tab_widget() def update_tab_widget(self): - self.tab_widget.children = self.tabs - self.display_titles() + self.widget.children = tuple(self._tab_widgets) - def build_settings_menu(self, tab_index): + def _on_tab_change(self, change): + self._controller.set_active_scenario( + self._tab_model_ids[change['new']] + ) + + def _build_settings_menu(self, name, model): delete_button = widgets.Button(description='Delete Scenario', button_style='danger') - rename_text_field = widgets.Text(description='Rename Scenario:', value=self.scenario_names[tab_index], + rename_text_field = widgets.Text(description='Rename Scenario:', value=name, style={'description_width': 'auto'}) duplicate_button = widgets.Button(description='Duplicate Scenario', button_style='success') + model_id = id(model) def on_delete_click(b): - self.scenario_names = tuple_without_index(self.scenario_names, tab_index) - self.scenarios = tuple_without_index(self.scenarios, tab_index) - self.views = tuple_without_index(self.views, tab_index) - self.selected_tab = min(0, self.selected_tab - 1) - self.tabs = tuple(widgets.VBox(children=(self.build_settings_menu(i), view.present())) - for i, view in enumerate(self.views)) - self.update_tab_widget() + self._controller.remove_scenario(model_id) def on_rename_text_field(change): - self.scenario_names = tuple(change['new'] if i == tab_index else value - for i, value in enumerate(self.scenario_names)) - self.update_tab_widget() + self._controller.rename_scenario(model_id, new_name=change['new']) def on_duplicate_click(b): - self.scenario_names += (self.scenario_names[tab_index] + " (copy)",) - new_scenario = state.DataclassInstanceState( - models.ExposureModel, - state_builder=CARAStateBuilder(), - ) - new_scenario.dcs_update_from(self.scenarios[tab_index].dcs_instance()) - self.scenarios += (new_scenario,) - - self.views += (WidgetView(new_scenario),) - self.tabs += (widgets.VBox(children=(self.build_settings_menu(len(self.scenario_names) - 1), self.views[-1].present())),) - self.update_tab_widget() + tab_index = self._tab_model_ids.index(model_id) + name = self.widget.get_title(tab_index) + self._controller.add_scenario(f'{name} (copy)', model) delete_button.on_click(on_delete_click) duplicate_button.on_click(on_duplicate_click) rename_text_field.observe(on_rename_text_field, 'value') - buttons = duplicate_button if tab_index == 0 else widgets.HBox(children=(duplicate_button, delete_button)) + # TODO: This should be dynamic - we don't want to be able to delete the + # last scenario, so this should be controlled in the remove_tab method. + buttons_w_delete = widgets.HBox(children=(duplicate_button, delete_button)) + buttons = duplicate_button if len(self._tab_model_ids) < 2 else buttons_w_delete return widgets.VBox(children=(buttons, rename_text_field)) - @property - def widget(self): - return self.tab_widget +def models_start_end(models: typing.Sequence[models.ConcentrationModel]) -> typing.Tuple[float, float]: + """ + Returns the earliest start and latest end time of a collection of ConcentrationModel objects -def tuple_without_index(t: typing.Tuple, index: int) -> typing.Tuple: - return t[:index] + t[index + 1:] + """ + infected_start = min(model.infected.presence.boundaries()[0][0] for model in models) + infected_finish = min(model.infected.presence.boundaries()[-1][1] for model in models) + return infected_start, infected_finish diff --git a/cara/tests/apps/test_expert_app.py b/cara/tests/apps/test_expert_app.py index 1cb55d3b..3e4d2362 100644 --- a/cara/tests/apps/test_expert_app.py +++ b/cara/tests/apps/test_expert_app.py @@ -1,9 +1,22 @@ +import pytest + import cara.apps -def test_app(): +@pytest.fixture +def expert_app(): + return cara.apps.ExpertApplication() + + +def test_app(expert_app): # To start with, let's just test that the application runs. We don't try to # do anything fancy to verify how it looks etc., we leave that for manual # testing. - expert_app = cara.apps.ExpertApplication() - assert expert_app.scenario_names[0] == "Scenario 1" + assert expert_app._model_scenarios[0][0] == "Scenario 1" + + +def test_new_scenario_changes_tab(expert_app): + # Adding a new scenario should change the tab index of the multi-model view. + assert expert_app.multi_model_view.widget.selected_index == 0 + expert_app.add_scenario("Another scenario") + assert expert_app.multi_model_view.widget.selected_index == 1