Split the MultiModelView so that its data is being handled by the Controller (ExpertApp).

This commit is contained in:
Phil Elson 2020-11-19 11:30:43 +01:00
parent 26f621d5af
commit 12cb3d1427

View file

@ -404,92 +404,164 @@ class CARAStateBuilder(state.StateBuilder):
return s
#: A scenario is a name and a (mutable) model.
ScenarioType = typing.Tuple[str, state.DataclassState]
class ExpertApplication:
def __init__(self):
self.multi_model_view = MultiModelView()
self._debug_output = widgets.Output()
#: A list of scenario name and ModelState instances. This is intended to be
#: mutated. Any mutation should notify the appropriate Views for handling.
self._model_scenarios: typing.List[ScenarioType] = []
self.multi_model_view = MultiModelView(self)
self.comparison_view = ComparisonView()
self.app = widgets.VBox(children=(self.multi_model_view.widget, self.comparison_view.widget))
self.widget = widgets.VBox(
children=(
self.multi_model_view.widget,
self.comparison_view.widget,
self._debug_output,
),
)
self.add_scenario('Scenario 1')
@property
def widget(self):
return self.app
class MultiModelView:
def __init__(self):
default_scenario = state.DataclassInstanceState(
def build_new_model(self):
default_model = state.DataclassInstanceState(
models.ExposureModel,
state_builder=CARAStateBuilder(),
)
default_scenario.dcs_update_from(baseline_model)
default_model.dcs_update_from(baseline_model)
# For the time-being, we have to initialise the select states. Careful
# as values might not correspond to what the baseline model says.
default_scenario.concentration_model.infected.mask.dcs_select('No mask')
self.scenarios = (default_scenario,)
self.scenario_names = ('Scenario 1',)
self.tab_views = (WidgetView(default_scenario),)
self.selected_tab = 0
self.tabs = (widgets.VBox(children=(self.build_settings_menu(0), self.tab_views[0].present())),)
self.tab_widget = widgets.Tab()
default_model.concentration_model.infected.mask.dcs_select('No mask')
return default_model
def add_scenario(self, name, copy_from_model: typing.Optional[state.DataclassInstanceState] = None):
model = self.build_new_model()
if copy_from_model is not None:
model.dcs_update_from(copy_from_model.dcs_instance())
self._model_scenarios.append((name, model))
model.dcs_observe(self.notify_model_values_changed)
self.notify_model_scenario_changed()
def _find_model_id(self, model_id):
for index, (name, model) in enumerate(list(self._model_scenarios)):
if id(model) == model_id:
return index, name, model
else:
raise ValueError("Model not found")
def rename_scenario(self, model_id, new_name):
index, _, model = self._find_model_id(model_id)
self._model_scenarios[index] = (new_name, model)
self.notify_model_scenario_changed()
def remove_scenario(self, model_id):
index, _, model = self._find_model_id(model_id)
self._model_scenarios.pop(index)
self.notify_model_scenario_changed()
def notify_model_scenario_changed(self):
self.multi_model_view.scenarios_updated(self._model_scenarios)
self.comparison_view.scenarios_updated(self._model_scenarios)
def notify_model_values_changed(self):
self.comparison_view.scenarios_updated(self._model_scenarios)
class MultiModelView:
def __init__(self, controller: ExpertApplication):
self._controller = controller
self.widget = widgets.Tab()
self._tab_model_ids: typing.List[int] = []
self._tab_widgets: typing.List[widgets.Widget] = []
self._tab_model_views: typing.List[WidgetView] = []
self._active_tab_index = 0
def scenarios_updated(self, model_scenarios: typing.Sequence[ScenarioType]):
"""
Called when a scenario is added/removed/renamed etc.
Note: Not called when the model state is modified.
"""
model_scenario_ids = []
for i, (scenario_name, model) in enumerate(model_scenarios):
if id(model) not in self._tab_model_ids:
self.add_tab(scenario_name, model)
model_scenario_ids.append(id(model))
tab_index = self._tab_model_ids.index(id(model))
self.widget.set_title(tab_index, scenario_name)
# Any remaining model_scenario_ids are no longer needed, so remove
# their tabs.
for tab_index, tab_scenario_id in enumerate(self._tab_model_ids[:]):
if tab_scenario_id not in model_scenario_ids:
self.remove_tab(tab_index)
def add_tab(self, name, model):
self._tab_model_views.append(WidgetView(model))
self._tab_model_ids.append(id(model))
tab_idx = len(self._tab_model_ids) - 1
tab_widget = widgets.VBox(
children=(
self._build_settings_menu(name, model),
self._tab_model_views[tab_idx].widget,
)
)
self._tab_widgets.append(tab_widget)
self.update_tab_widget()
def display_titles(self):
for i, name in enumerate(self.scenario_names):
self.tab_widget.set_title(i, name)
def remove_tab(self, tab_index):
assert 0 <= tab_index < len(self._tab_model_ids)
assert len(self._tab_model_ids) > 1
self._tab_model_ids.pop(tab_index)
self._tab_widgets.pop(tab_index)
self._tab_model_views.pop(tab_index)
if self._active_tab_index >= tab_index:
self._active_tab_index = max(0, self._active_tab_index - 1)
self.update_tab_widget()
def update_tab_widget(self):
self.tab_widget.children = self.tabs
self.display_titles()
self.widget.children = tuple(self._tab_widgets)
def build_settings_menu(self, tab_index):
def _build_settings_menu(self, name, model):
delete_button = widgets.Button(description='Delete Scenario', button_style='danger')
rename_text_field = widgets.Text(description='Rename Scenario:', value=self.scenario_names[tab_index],
rename_text_field = widgets.Text(description='Rename Scenario:', value=name,
style={'description_width': 'auto'})
duplicate_button = widgets.Button(description='Duplicate Scenario', button_style='success')
model_id = id(model)
def on_delete_click(b):
self.scenario_names = tuple_without_index(self.scenario_names, tab_index)
self.scenarios = tuple_without_index(self.scenarios, tab_index)
self.tab_views = tuple_without_index(self.tab_views, tab_index)
self.selected_tab = min(0, self.selected_tab - 1)
self.tabs = tuple(widgets.VBox(children=(self.build_settings_menu(i), view.present()))
for i, view in enumerate(self.tab_views))
self.update_tab_widget()
self._controller.remove_scenario(model_id)
def on_rename_text_field(change):
self.scenario_names = tuple(change['new'] if i == tab_index else value
for i, value in enumerate(self.scenario_names))
self.update_tab_widget()
self._controller.rename_scenario(model_id, new_name=change['new'])
def on_duplicate_click(b):
self.scenario_names += (self.scenario_names[tab_index] + " (copy)",)
new_scenario = state.DataclassInstanceState(
models.ExposureModel,
state_builder=CARAStateBuilder(),
)
new_scenario.dcs_update_from(self.scenarios[tab_index].dcs_instance())
self.scenarios += (new_scenario,)
self.tab_views += (WidgetView(new_scenario),)
self.tabs += (widgets.VBox(children=(self.build_settings_menu(len(self.scenario_names) - 1),
self.tab_views[-1].present())),)
self.update_tab_widget()
tab_index = self._tab_model_ids.index(model_id)
name = self.widget.get_title(tab_index)
self._controller.add_scenario(f'{name} (copy)', model)
delete_button.on_click(on_delete_click)
duplicate_button.on_click(on_duplicate_click)
rename_text_field.observe(on_rename_text_field, 'value')
buttons = duplicate_button if tab_index == 0 else widgets.HBox(children=(duplicate_button, delete_button))
# TODO: This should be dynamic - we don't want to be able to delete the
# last scenario, so this should be controlled in the remove_tab method.
buttons_w_delete = widgets.HBox(children=(duplicate_button, delete_button))
buttons = duplicate_button if len(self._tab_model_ids) < 2 else buttons_w_delete
return widgets.VBox(children=(buttons, rename_text_field))
@property
def widget(self):
return self.tab_widget
class ComparisonView:
def __init__(self):
self.figure = self.initialize_figure()
self.ax = self.initialize_axis()
self.ax = self.initialize_axes()
@property
def widget(self):
return self.figure.canvas
@staticmethod
def initialize_figure() -> matplotlib.figure.Figure:
@ -502,7 +574,7 @@ class ComparisonView:
figure.canvas.header_visible = False
return figure
def initialize_axis(self) -> matplotlib.figure.Axes:
def initialize_axes(self) -> matplotlib.figure.Axes:
ax = self.figure.add_subplot(1, 1, 1)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
@ -511,36 +583,33 @@ class ComparisonView:
ax.set_title('Concentration of infectious quanta aerosols')
return ax
def scenarios_updated(self, scenarios: typing.Sequence[ScenarioType]):
labels, models = zip(*scenarios)
conc_models: typing.Tuple[models.ConcentrationModel] = tuple(
model.concentration_model.dcs_instance() for model in models
)
self.update_plot(conc_models, labels)
def update_plot(self, conc_models: typing.Tuple[models.ConcentrationModel], labels: typing.Tuple[str]):
self.figure.clf()
self.ax.cla()
start, finish = models_start_end(conc_models)
ts = np.linspace(start, finish)
ts = np.linspace(start, finish, num=250)
concentrations = [[conc_model.concentration(t) for t in ts] for conc_model in conc_models]
for concentration in concentrations:
self.ax.plot(ts, concentration)
for label, concentration in zip(labels, concentrations):
self.ax.plot(ts, concentration, label=label)
top = max(3., max([max(conc) for conc in concentrations]))
self.ax.set_ylim(bottom=0., top=top)
self.figure.legend(labels)
self.ax.legend()
self.figure.canvas.draw()
@property
def widget(self):
return self.figure.canvas
def tuple_without_index(t: typing.Tuple, index: int) -> typing.Tuple:
return t[:index] + t[index + 1:]
def models_start_end(models: typing.Iterable[models.ConcentrationModel]) -> typing.Tuple[float, float]:
def models_start_end(models: typing.Sequence[models.ConcentrationModel]) -> typing.Tuple[float, float]:
"""
Returns the union of the presence intervals of a collection of ConcentrationModel objects
:param models: An iterable (e.g. list or tuple) of ConcentrationModel objects
:return: A tuple (start, finish) corresponding to the union of the presence intervals
Returns the earliest start and latest end time of a collection of ConcentrationModel objects
"""
start = min(model.infected.presence.boundaries()[0][0] for model in models)
finish = min(model.infected.presence.boundaries()[-1][1] for model in models)
return start, finish
infected_start = min(model.infected.presence.boundaries()[0][0] for model in models)
infected_finish = min(model.infected.presence.boundaries()[-1][1] for model in models)
return infected_start, infected_finish