diff --git a/cara/apps/expert.py b/cara/apps/expert.py index e01c1099..fcb58070 100644 --- a/cara/apps/expert.py +++ b/cara/apps/expert.py @@ -28,12 +28,36 @@ def widget_group(label_widget_pairs): return widgets.HBox([labels_w, widgets_w]) -class ConcentrationFigure: +#: A scenario is a name and a (mutable) model. +ScenarioType = typing.Tuple[str, state.DataclassState] + + +class View: + pass + +def ipympl_canvas(figure): + matplotlib.interactive(False) + ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure) + figure.canvas.toolbar_visible = True + figure.canvas.toolbar.collapsed = True + figure.canvas.footer_visible = False + figure.canvas.header_visible = False + return figure.canvas + + +class ConcentrationFigure(View): def __init__(self): self.figure = matplotlib.figure.Figure(figsize=(9, 6)) + ipympl_canvas(self.figure) self.ax = self.figure.add_subplot(1, 1, 1) self.line = None + @property + def widget(self): + # Workaround to a bug with ipymlp, which doesn't work well with tabs + # unless the widget is wrapped in a container (it is seen on all tabs otherwise!). + return widgets.HBox([self.figure.canvas]) + def update(self, model: models.ConcentrationModel): resolution = 600 ts = np.linspace(sorted(model.infected.presence.transition_times())[0], @@ -61,17 +85,54 @@ class ConcentrationFigure: self.figure.canvas.draw() -def ipympl_canvas(figure: matplotlib.figure.Figure): - # Make a plain matplotlib figure render as a Jupyter widget. - matplotlib.interactive(False) - ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure) - figure.canvas.toolbar_visible = True - figure.canvas.toolbar.collapsed = True - figure.canvas.footer_visible = False - figure.canvas.header_visible = False +class ComparisonFigure(View): + def __init__(self): + self.figure = matplotlib.figure.Figure(figsize=(9, 6)) + ipympl_canvas(self.figure) + self.ax = self.initialize_axes() + + @property + def widget(self): + # Workaround to a bug with ipymlp, which doesn't work well with tabs + # unless the widget is wrapped in a container (it is seen on all tabs otherwise!). + return widgets.HBox([self.figure.canvas]) + + def initialize_axes(self) -> matplotlib.figure.Axes: + ax = self.figure.add_subplot(1, 1, 1) + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + ax.set_xlabel('Time (hours)') + ax.set_ylabel('Concentration ($q/m^3$)') + ax.set_title('Concentration of infectious quanta aerosols') + return ax + + def scenarios_updated( + self, + scenarios: typing.Sequence[ScenarioType], + active_scenario_index: int + ): + labels, models = zip(*scenarios) + conc_models: typing.Tuple[models.ConcentrationModel] = tuple( + model.concentration_model.dcs_instance() for model in models + ) + self.update_plot(conc_models, labels) + + def update_plot(self, conc_models: typing.Tuple[models.ConcentrationModel], labels: typing.Tuple[str]): + self.ax.lines.clear() + start, finish = models_start_end(conc_models) + ts = np.linspace(start, finish, num=250) + concentrations = [[conc_model.concentration(t) for t in ts] for conc_model in conc_models] + for label, concentration in zip(labels, concentrations): + self.ax.plot(ts, concentration, label=label) + + top = max(3., max([max(conc) for conc in concentrations])) + self.ax.set_ylim(bottom=0., top=top) + + self.ax.legend() + self.figure.canvas.draw() -class WidgetView: +class ModelWidgets(View): def __init__(self, model_state: state.DataclassState): self.model_state = model_state self.model_state.dcs_observe(self.update) @@ -91,17 +152,16 @@ class WidgetView: # And the output widget figure. concentration = ConcentrationFigure() self.plots.append(concentration) - ipympl_canvas(concentration.figure) - self.widgets['results'] = collapsible([ - widgets.HBox([ - concentration.figure.canvas, - self.out, - ]) - ], 'Results', start_collapsed=False) + # self.widgets['results'] = collapsible([ + # widgets.HBox([ + # concentration.widget, + # self.out, + # ]) + # ], 'Results', start_collapsed=False) # Join inputs and outputs together in a single widget for convenience. - self.widget.children += (self.widgets['results'], ) + # self.widget.children += (self.widgets['results'], ) def prepare_output(self): pass @@ -404,10 +464,6 @@ class CARAStateBuilder(state.StateBuilder): return s -#: A scenario is a name and a (mutable) model. -ScenarioType = typing.Tuple[str, state.DataclassState] - - class ExpertApplication: def __init__(self): self._debug_output = widgets.Output() @@ -415,13 +471,21 @@ class ExpertApplication: #: A list of scenario name and ModelState instances. This is intended to be #: mutated. Any mutation should notify the appropriate Views for handling. self._model_scenarios: typing.List[ScenarioType] = [] + self._active_scenario = 0 self.multi_model_view = MultiModelView(self) - self.comparison_view = ComparisonView() - self.widget = widgets.VBox( + self.comparison_view = ComparisonFigure() + self.current_scenario_figure = ConcentrationFigure() + self._results_tab = widgets.Tab(children=( + self.current_scenario_figure.widget, + self.comparison_view.widget, + # self._debug_output, + )) + for i, title in enumerate(['Current scenario', 'Scenario comparison', "Debug"]): + self._results_tab.set_title(i, title) + self.widget = widgets.HBox( children=( self.multi_model_view.widget, - self.comparison_view.widget, - self._debug_output, + self._results_tab, ), ) self.add_scenario('Scenario 1') @@ -442,6 +506,7 @@ class ExpertApplication: if copy_from_model is not None: model.dcs_update_from(copy_from_model.dcs_instance()) self._model_scenarios.append((name, model)) + self._active_scenario = len(self._model_scenarios) - 1 model.dcs_observe(self.notify_model_values_changed) self.notify_model_scenario_changed() @@ -460,26 +525,46 @@ class ExpertApplication: def remove_scenario(self, model_id): index, _, model = self._find_model_id(model_id) self._model_scenarios.pop(index) + if self._active_scenario >= index: + self._active_scenario = max(self._active_scenario - 1, 0) self.notify_model_scenario_changed() + def set_active_scenario(self, model_id): + index, _, model = self._find_model_id(model_id) + self._active_scenario = index + self.notify_model_scenario_changed() + self.notify_model_values_changed() + def notify_model_scenario_changed(self): - self.multi_model_view.scenarios_updated(self._model_scenarios) - self.comparison_view.scenarios_updated(self._model_scenarios) + """ + Occurs when the set of scenarios has been modified, but not if the values of the scenario has changed. + + """ + self.multi_model_view.scenarios_updated(self._model_scenarios, self._active_scenario) + self.comparison_view.scenarios_updated(self._model_scenarios, self._active_scenario) def notify_model_values_changed(self): - self.comparison_view.scenarios_updated(self._model_scenarios) + """ + Occurs when *any* value in *any* of the scenarios has been modified. + """ + self.comparison_view.scenarios_updated(self._model_scenarios, self._active_scenario) + self.current_scenario_figure.update(self._model_scenarios[self._active_scenario][1].concentration_model.dcs_instance()) class MultiModelView: def __init__(self, controller: ExpertApplication): self._controller = controller self.widget = widgets.Tab() + self.widget.observe(self._on_tab_change, 'selected_index') self._tab_model_ids: typing.List[int] = [] self._tab_widgets: typing.List[widgets.Widget] = [] - self._tab_model_views: typing.List[WidgetView] = [] - self._active_tab_index = 0 + self._tab_model_views: typing.List[ModelWidgets] = [] - def scenarios_updated(self, model_scenarios: typing.Sequence[ScenarioType]): + def scenarios_updated( + self, + model_scenarios: typing.Sequence[ScenarioType], + active_scenario_index: int + ): """ Called when a scenario is added/removed/renamed etc. @@ -500,8 +585,12 @@ class MultiModelView: if tab_scenario_id not in model_scenario_ids: self.remove_tab(tab_index) + assert self._tab_model_ids == model_scenario_ids + + self.widget.selected_index = active_scenario_index + def add_tab(self, name, model): - self._tab_model_views.append(WidgetView(model)) + self._tab_model_views.append(ModelWidgets(model)) self._tab_model_ids.append(id(model)) tab_idx = len(self._tab_model_ids) - 1 tab_widget = widgets.VBox( @@ -526,6 +615,11 @@ class MultiModelView: def update_tab_widget(self): self.widget.children = tuple(self._tab_widgets) + def _on_tab_change(self, change): + self._controller.set_active_scenario( + self._tab_model_ids[change['new']] + ) + def _build_settings_menu(self, name, model): delete_button = widgets.Button(description='Delete Scenario', button_style='danger') rename_text_field = widgets.Text(description='Rename Scenario:', value=name, @@ -554,57 +648,6 @@ class MultiModelView: return widgets.VBox(children=(buttons, rename_text_field)) -class ComparisonView: - def __init__(self): - self.figure = self.initialize_figure() - self.ax = self.initialize_axes() - - @property - def widget(self): - return self.figure.canvas - - @staticmethod - def initialize_figure() -> matplotlib.figure.Figure: - figure = matplotlib.figure.Figure(figsize=(9, 6)) - matplotlib.interactive(False) - ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure) - figure.canvas.toolbar_visible = True - figure.canvas.toolbar.collapsed = True - figure.canvas.footer_visible = False - figure.canvas.header_visible = False - return figure - - def initialize_axes(self) -> matplotlib.figure.Axes: - ax = self.figure.add_subplot(1, 1, 1) - ax.spines['right'].set_visible(False) - ax.spines['top'].set_visible(False) - ax.set_xlabel('Time (hours)') - ax.set_ylabel('Concentration ($q/m^3$)') - ax.set_title('Concentration of infectious quanta aerosols') - return ax - - def scenarios_updated(self, scenarios: typing.Sequence[ScenarioType]): - labels, models = zip(*scenarios) - conc_models: typing.Tuple[models.ConcentrationModel] = tuple( - model.concentration_model.dcs_instance() for model in models - ) - self.update_plot(conc_models, labels) - - def update_plot(self, conc_models: typing.Tuple[models.ConcentrationModel], labels: typing.Tuple[str]): - self.ax.cla() - start, finish = models_start_end(conc_models) - ts = np.linspace(start, finish, num=250) - concentrations = [[conc_model.concentration(t) for t in ts] for conc_model in conc_models] - for label, concentration in zip(labels, concentrations): - self.ax.plot(ts, concentration, label=label) - - top = max(3., max([max(conc) for conc in concentrations])) - self.ax.set_ylim(bottom=0., top=top) - - self.ax.legend() - self.figure.canvas.draw() - - def models_start_end(models: typing.Sequence[models.ConcentrationModel]) -> typing.Tuple[float, float]: """ Returns the earliest start and latest end time of a collection of ConcentrationModel objects