From 12cb3d1427264477b549551c8b7c53e3e062ebdf Mon Sep 17 00:00:00 2001 From: Phil Elson Date: Thu, 19 Nov 2020 11:30:43 +0100 Subject: [PATCH] Split the MultiModelView so that its data is being handled by the Controller (ExpertApp). --- cara/apps/expert.py | 219 +++++++++++++++++++++++++++++--------------- 1 file changed, 144 insertions(+), 75 deletions(-) diff --git a/cara/apps/expert.py b/cara/apps/expert.py index e019ad6b..e01c1099 100644 --- a/cara/apps/expert.py +++ b/cara/apps/expert.py @@ -404,92 +404,164 @@ class CARAStateBuilder(state.StateBuilder): return s +#: A scenario is a name and a (mutable) model. +ScenarioType = typing.Tuple[str, state.DataclassState] + + class ExpertApplication: def __init__(self): - self.multi_model_view = MultiModelView() + self._debug_output = widgets.Output() + + #: A list of scenario name and ModelState instances. This is intended to be + #: mutated. Any mutation should notify the appropriate Views for handling. + self._model_scenarios: typing.List[ScenarioType] = [] + self.multi_model_view = MultiModelView(self) self.comparison_view = ComparisonView() - self.app = widgets.VBox(children=(self.multi_model_view.widget, self.comparison_view.widget)) + self.widget = widgets.VBox( + children=( + self.multi_model_view.widget, + self.comparison_view.widget, + self._debug_output, + ), + ) + self.add_scenario('Scenario 1') - @property - def widget(self): - return self.app - - -class MultiModelView: - def __init__(self): - default_scenario = state.DataclassInstanceState( + def build_new_model(self): + default_model = state.DataclassInstanceState( models.ExposureModel, state_builder=CARAStateBuilder(), ) - default_scenario.dcs_update_from(baseline_model) + default_model.dcs_update_from(baseline_model) # For the time-being, we have to initialise the select states. Careful # as values might not correspond to what the baseline model says. - default_scenario.concentration_model.infected.mask.dcs_select('No mask') - self.scenarios = (default_scenario,) - self.scenario_names = ('Scenario 1',) - self.tab_views = (WidgetView(default_scenario),) - self.selected_tab = 0 - self.tabs = (widgets.VBox(children=(self.build_settings_menu(0), self.tab_views[0].present())),) - self.tab_widget = widgets.Tab() + default_model.concentration_model.infected.mask.dcs_select('No mask') + return default_model + + def add_scenario(self, name, copy_from_model: typing.Optional[state.DataclassInstanceState] = None): + model = self.build_new_model() + if copy_from_model is not None: + model.dcs_update_from(copy_from_model.dcs_instance()) + self._model_scenarios.append((name, model)) + model.dcs_observe(self.notify_model_values_changed) + self.notify_model_scenario_changed() + + def _find_model_id(self, model_id): + for index, (name, model) in enumerate(list(self._model_scenarios)): + if id(model) == model_id: + return index, name, model + else: + raise ValueError("Model not found") + + def rename_scenario(self, model_id, new_name): + index, _, model = self._find_model_id(model_id) + self._model_scenarios[index] = (new_name, model) + self.notify_model_scenario_changed() + + def remove_scenario(self, model_id): + index, _, model = self._find_model_id(model_id) + self._model_scenarios.pop(index) + self.notify_model_scenario_changed() + + def notify_model_scenario_changed(self): + self.multi_model_view.scenarios_updated(self._model_scenarios) + self.comparison_view.scenarios_updated(self._model_scenarios) + + def notify_model_values_changed(self): + self.comparison_view.scenarios_updated(self._model_scenarios) + + +class MultiModelView: + def __init__(self, controller: ExpertApplication): + self._controller = controller + self.widget = widgets.Tab() + self._tab_model_ids: typing.List[int] = [] + self._tab_widgets: typing.List[widgets.Widget] = [] + self._tab_model_views: typing.List[WidgetView] = [] + self._active_tab_index = 0 + + def scenarios_updated(self, model_scenarios: typing.Sequence[ScenarioType]): + """ + Called when a scenario is added/removed/renamed etc. + + Note: Not called when the model state is modified. + + """ + model_scenario_ids = [] + for i, (scenario_name, model) in enumerate(model_scenarios): + if id(model) not in self._tab_model_ids: + self.add_tab(scenario_name, model) + model_scenario_ids.append(id(model)) + tab_index = self._tab_model_ids.index(id(model)) + self.widget.set_title(tab_index, scenario_name) + + # Any remaining model_scenario_ids are no longer needed, so remove + # their tabs. + for tab_index, tab_scenario_id in enumerate(self._tab_model_ids[:]): + if tab_scenario_id not in model_scenario_ids: + self.remove_tab(tab_index) + + def add_tab(self, name, model): + self._tab_model_views.append(WidgetView(model)) + self._tab_model_ids.append(id(model)) + tab_idx = len(self._tab_model_ids) - 1 + tab_widget = widgets.VBox( + children=( + self._build_settings_menu(name, model), + self._tab_model_views[tab_idx].widget, + ) + ) + self._tab_widgets.append(tab_widget) self.update_tab_widget() - def display_titles(self): - for i, name in enumerate(self.scenario_names): - self.tab_widget.set_title(i, name) + def remove_tab(self, tab_index): + assert 0 <= tab_index < len(self._tab_model_ids) + assert len(self._tab_model_ids) > 1 + self._tab_model_ids.pop(tab_index) + self._tab_widgets.pop(tab_index) + self._tab_model_views.pop(tab_index) + if self._active_tab_index >= tab_index: + self._active_tab_index = max(0, self._active_tab_index - 1) + self.update_tab_widget() def update_tab_widget(self): - self.tab_widget.children = self.tabs - self.display_titles() + self.widget.children = tuple(self._tab_widgets) - def build_settings_menu(self, tab_index): + def _build_settings_menu(self, name, model): delete_button = widgets.Button(description='Delete Scenario', button_style='danger') - rename_text_field = widgets.Text(description='Rename Scenario:', value=self.scenario_names[tab_index], + rename_text_field = widgets.Text(description='Rename Scenario:', value=name, style={'description_width': 'auto'}) duplicate_button = widgets.Button(description='Duplicate Scenario', button_style='success') + model_id = id(model) def on_delete_click(b): - self.scenario_names = tuple_without_index(self.scenario_names, tab_index) - self.scenarios = tuple_without_index(self.scenarios, tab_index) - self.tab_views = tuple_without_index(self.tab_views, tab_index) - self.selected_tab = min(0, self.selected_tab - 1) - self.tabs = tuple(widgets.VBox(children=(self.build_settings_menu(i), view.present())) - for i, view in enumerate(self.tab_views)) - self.update_tab_widget() + self._controller.remove_scenario(model_id) def on_rename_text_field(change): - self.scenario_names = tuple(change['new'] if i == tab_index else value - for i, value in enumerate(self.scenario_names)) - self.update_tab_widget() + self._controller.rename_scenario(model_id, new_name=change['new']) def on_duplicate_click(b): - self.scenario_names += (self.scenario_names[tab_index] + " (copy)",) - new_scenario = state.DataclassInstanceState( - models.ExposureModel, - state_builder=CARAStateBuilder(), - ) - new_scenario.dcs_update_from(self.scenarios[tab_index].dcs_instance()) - self.scenarios += (new_scenario,) - - self.tab_views += (WidgetView(new_scenario),) - self.tabs += (widgets.VBox(children=(self.build_settings_menu(len(self.scenario_names) - 1), - self.tab_views[-1].present())),) - self.update_tab_widget() + tab_index = self._tab_model_ids.index(model_id) + name = self.widget.get_title(tab_index) + self._controller.add_scenario(f'{name} (copy)', model) delete_button.on_click(on_delete_click) duplicate_button.on_click(on_duplicate_click) rename_text_field.observe(on_rename_text_field, 'value') - buttons = duplicate_button if tab_index == 0 else widgets.HBox(children=(duplicate_button, delete_button)) + # TODO: This should be dynamic - we don't want to be able to delete the + # last scenario, so this should be controlled in the remove_tab method. + buttons_w_delete = widgets.HBox(children=(duplicate_button, delete_button)) + buttons = duplicate_button if len(self._tab_model_ids) < 2 else buttons_w_delete return widgets.VBox(children=(buttons, rename_text_field)) - @property - def widget(self): - return self.tab_widget - class ComparisonView: def __init__(self): self.figure = self.initialize_figure() - self.ax = self.initialize_axis() + self.ax = self.initialize_axes() + + @property + def widget(self): + return self.figure.canvas @staticmethod def initialize_figure() -> matplotlib.figure.Figure: @@ -502,7 +574,7 @@ class ComparisonView: figure.canvas.header_visible = False return figure - def initialize_axis(self) -> matplotlib.figure.Axes: + def initialize_axes(self) -> matplotlib.figure.Axes: ax = self.figure.add_subplot(1, 1, 1) ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) @@ -511,36 +583,33 @@ class ComparisonView: ax.set_title('Concentration of infectious quanta aerosols') return ax + def scenarios_updated(self, scenarios: typing.Sequence[ScenarioType]): + labels, models = zip(*scenarios) + conc_models: typing.Tuple[models.ConcentrationModel] = tuple( + model.concentration_model.dcs_instance() for model in models + ) + self.update_plot(conc_models, labels) + def update_plot(self, conc_models: typing.Tuple[models.ConcentrationModel], labels: typing.Tuple[str]): - self.figure.clf() + self.ax.cla() start, finish = models_start_end(conc_models) - ts = np.linspace(start, finish) + ts = np.linspace(start, finish, num=250) concentrations = [[conc_model.concentration(t) for t in ts] for conc_model in conc_models] - for concentration in concentrations: - self.ax.plot(ts, concentration) + for label, concentration in zip(labels, concentrations): + self.ax.plot(ts, concentration, label=label) top = max(3., max([max(conc) for conc in concentrations])) self.ax.set_ylim(bottom=0., top=top) - self.figure.legend(labels) - + self.ax.legend() self.figure.canvas.draw() - @property - def widget(self): - return self.figure.canvas - -def tuple_without_index(t: typing.Tuple, index: int) -> typing.Tuple: - return t[:index] + t[index + 1:] - - -def models_start_end(models: typing.Iterable[models.ConcentrationModel]) -> typing.Tuple[float, float]: +def models_start_end(models: typing.Sequence[models.ConcentrationModel]) -> typing.Tuple[float, float]: """ - Returns the union of the presence intervals of a collection of ConcentrationModel objects - :param models: An iterable (e.g. list or tuple) of ConcentrationModel objects - :return: A tuple (start, finish) corresponding to the union of the presence intervals + Returns the earliest start and latest end time of a collection of ConcentrationModel objects + """ - start = min(model.infected.presence.boundaries()[0][0] for model in models) - finish = min(model.infected.presence.boundaries()[-1][1] for model in models) - return start, finish + infected_start = min(model.infected.presence.boundaries()[0][0] for model in models) + infected_finish = min(model.infected.presence.boundaries()[-1][1] for model in models) + return infected_start, infected_finish