Merge branch 'comparison-tab' into 'master'

Expert app scenario comparisons

See merge request cara/cara!89
This commit is contained in:
Philip James Elson 2020-11-19 13:38:43 +00:00
commit d05708ea37
2 changed files with 303 additions and 112 deletions

View file

@ -13,7 +13,7 @@ from cara import state
from cara import data
def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=True):
def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=False):
collapsed = widgets.Accordion([widgets.VBox(widgets_to_collapse)])
collapsed.set_title(0, title)
if start_collapsed:
@ -28,13 +28,66 @@ def widget_group(label_widget_pairs):
return widgets.HBox([labels_w, widgets_w])
class ConcentrationFigure:
#: A scenario is a name and a (mutable) model.
ScenarioType = typing.Tuple[str, state.DataclassState]
class View:
"""
A thing which exposes a ``.widget`` attribute which is a view on some
data. This view is essentially a complex combination of widgets, along with
some event handling capabilities, which may or may not be sent back up to
the underlying controller.
We strive hard to keep "Model" data out of the View (and try to avoid
storing it at all on the View itself), instead relying on being able
to notify, and receive notifications, of important events from the Controller.
"""
pass
class Controller:
"""
The singleton thing which is the top-level Application.
It is responsible for owning the Model data and the Views, and
orchestrating event messages to each if the Model/View change.
"""
pass
def ipympl_canvas(figure):
matplotlib.interactive(False)
ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure)
figure.canvas.toolbar_visible = True
figure.canvas.toolbar.collapsed = True
figure.canvas.footer_visible = False
figure.canvas.header_visible = False
return figure.canvas
class ExposureModelResult(View):
def __init__(self):
self.figure = matplotlib.figure.Figure(figsize=(9, 6))
ipympl_canvas(self.figure)
self.html_output = widgets.HTML()
self.ax = self.figure.add_subplot(1, 1, 1)
self.line = None
def update(self, model: models.ConcentrationModel):
@property
def widget(self):
return widgets.VBox([
self.html_output,
self.figure.canvas,
])
def update(self, model: models.ExposureModel):
self.update_plot(model.concentration_model)
self.update_textual_result(model)
def update_plot(self, model: models.ConcentrationModel):
resolution = 600
ts = np.linspace(sorted(model.infected.presence.transition_times())[0],
sorted(model.infected.presence.transition_times())[-1], resolution)
@ -60,70 +113,75 @@ class ConcentrationFigure:
self.ax.set_ylim(bottom=0., top=top)
self.figure.canvas.draw()
def update_textual_result(self, model: models.ExposureModel):
lines = []
P = model.infection_probability()
lines.append(f'Emission rate (quanta/hr): {model.concentration_model.infected.emission_rate_when_present()}')
lines.append(f'Probability of infection: {np.round(P, 0)}%')
def ipympl_canvas(figure: matplotlib.figure.Figure):
# Make a plain matplotlib figure render as a Jupyter widget.
matplotlib.interactive(False)
ipympl.backend_nbagg.new_figure_manager_given_figure(uuid.uuid1(), figure)
figure.canvas.toolbar_visible = True
figure.canvas.toolbar.collapsed = True
figure.canvas.footer_visible = False
figure.canvas.header_visible = False
lines.append(f'Number of exposed: {model.exposed.number}')
new_cases = np.round(model.expected_new_cases(), 1)
lines.append(f'Number of expected new cases: {new_cases}')
R0 = np.round(model.reproduction_number(), 1)
lines.append(f'Reproduction number (R0): {R0}')
self.html_output.value = '<br>\n'.join(lines)
class WidgetView:
class ExposureComparissonResult(View):
def __init__(self):
self.figure = matplotlib.figure.Figure(figsize=(9, 6))
ipympl_canvas(self.figure)
self.ax = self.initialize_axes()
@property
def widget(self):
# Workaround to a bug with ipymlp, which doesn't work well with tabs
# unless the widget is wrapped in a container (it is seen on all tabs otherwise!).
return widgets.HBox([self.figure.canvas])
def initialize_axes(self) -> matplotlib.figure.Axes:
ax = self.figure.add_subplot(1, 1, 1)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_xlabel('Time (hours)')
ax.set_ylabel('Concentration ($q/m^3$)')
ax.set_title('Concentration of infectious quanta aerosols')
return ax
def scenarios_updated(self, scenarios: typing.Sequence[ScenarioType], _):
labels, models = zip(*scenarios)
conc_models: typing.Tuple[models.ConcentrationModel] = tuple(
model.concentration_model.dcs_instance() for model in models
)
self.update_plot(conc_models, labels)
def update_plot(self, conc_models: typing.Tuple[models.ConcentrationModel], labels: typing.Tuple[str]):
self.ax.lines.clear()
start, finish = models_start_end(conc_models)
ts = np.linspace(start, finish, num=250)
concentrations = [[conc_model.concentration(t) for t in ts] for conc_model in conc_models]
for label, concentration in zip(labels, concentrations):
self.ax.plot(ts, concentration, label=label)
top = max(3., max([max(conc) for conc in concentrations]))
self.ax.set_ylim(bottom=0., top=top)
self.ax.legend()
self.figure.canvas.draw()
class ModelWidgets(View):
def __init__(self, model_state: state.DataclassState):
self.model_state = model_state
self.model_state.dcs_observe(self.update)
#: The widgets that this view produces (inputs and outputs together)
self.widget = widgets.VBox([])
self.widgets = {}
self.out = widgets.Output()
self.plots = []
self.construct_widgets()
# Trigger the first result.
self.update()
self.construct_widgets(model_state)
def construct_widgets(self):
def construct_widgets(self, model_state: state.DataclassState):
# Build the input widgets.
self._build_widget(self.model_state)
# And the output widget figure.
concentration = ConcentrationFigure()
self.plots.append(concentration)
ipympl_canvas(concentration.figure)
self.widgets['results'] = collapsible([
widgets.HBox([
concentration.figure.canvas,
self.out,
])
], 'Results', start_collapsed=False)
# Join inputs and outputs together in a single widget for convenience.
self.widget.children += (self.widgets['results'], )
def prepare_output(self):
pass
def update(self):
model: models.ExposureModel = self.model_state.dcs_instance()
for plot in self.plots:
plot.update(model.concentration_model)
self.out.clear_output()
with self.out:
P = model.infection_probability()
print(f'Emission rate (quanta/hr): {model.concentration_model.infected.emission_rate_when_present()}')
print(f'Probability of infection: {np.round(P, 0)}%')
print(f'Number of exposed: {model.exposed.number}')
new_cases = np.round(model.expected_new_cases(), 1)
print(f'Number of expected new cases: {new_cases}')
R0 = np.round(model.reproduction_number(), 1)
print(f'Reproduction number (R0): {R0}')
self._build_widget(model_state)
def _build_widget(self, node):
self.widget.children += (self._build_room(node.concentration_model.room),)
@ -160,7 +218,7 @@ class WidgetView:
[widget_group(
[[widgets.Label('Room volume'), room_volume]]
)],
title='Specification of workplace', start_collapsed=False,
title='Specification of workplace',
)
return widget
@ -347,7 +405,7 @@ class WidgetView:
w = collapsible(
[widget_group([[widgets.Label('Ventilation type'), ventilation_w]])]
+ list(ventilation_widgets.values()),
title='Ventilation scheme'
title='Ventilation scheme',
)
return w
@ -404,75 +462,195 @@ class CARAStateBuilder(state.StateBuilder):
return s
class ExpertApplication:
class ExpertApplication(Controller):
def __init__(self):
default_scenario = state.DataclassInstanceState(
models.ExposureModel,
state_builder=CARAStateBuilder(),
)
default_scenario.dcs_update_from(baseline_model)
self._debug_output = widgets.Output()
#: A list of scenario name and ModelState instances. This is intended to be
#: mutated. Any mutation should notify the appropriate Views for handling.
self._model_scenarios: typing.List[ScenarioType] = []
self._active_scenario = 0
self.multi_model_view = MultiModelView(self)
self.comparison_view = ExposureComparissonResult()
self.current_scenario_figure = ExposureModelResult()
self._results_tab = widgets.Tab(children=(
self.current_scenario_figure.widget,
self.comparison_view.widget,
# self._debug_output,
))
for i, title in enumerate(['Current scenario', 'Scenario comparison', "Debug"]):
self._results_tab.set_title(i, title)
self.widget = widgets.HBox(
children=(
self.multi_model_view.widget,
self._results_tab,
),
)
self.add_scenario('Scenario 1')
def build_new_model(self):
default_model = state.DataclassInstanceState(
models.ExposureModel,
state_builder=CARAStateBuilder(),
)
default_model.dcs_update_from(baseline_model)
# For the time-being, we have to initialise the select states. Careful
# as values might not correspond to what the baseline model says.
default_scenario.concentration_model.infected.mask.dcs_select('No mask')
self.scenarios = (default_scenario,)
self.scenario_names = ('Scenario 1',)
self.views = (WidgetView(default_scenario),)
self.selected_tab = 0
self.tabs = (widgets.VBox(children=(self.build_settings_menu(0), self.views[0].present())),)
self.tab_widget = widgets.Tab(children=self.tabs)
self.display_titles()
default_model.concentration_model.infected.mask.dcs_select('No mask')
return default_model
def display_titles(self):
for i, name in enumerate(self.scenario_names):
self.tab_widget.set_title(i, name)
def add_scenario(self, name, copy_from_model: typing.Optional[state.DataclassInstanceState] = None):
model = self.build_new_model()
if copy_from_model is not None:
model.dcs_update_from(copy_from_model.dcs_instance())
self._model_scenarios.append((name, model))
self._active_scenario = len(self._model_scenarios) - 1
model.dcs_observe(self.notify_model_values_changed)
self.notify_scenarios_changed()
def _find_model_id(self, model_id):
for index, (name, model) in enumerate(list(self._model_scenarios)):
if id(model) == model_id:
return index, name, model
else:
raise ValueError("Model not found")
def rename_scenario(self, model_id, new_name):
index, _, model = self._find_model_id(model_id)
self._model_scenarios[index] = (new_name, model)
self.notify_scenarios_changed()
def remove_scenario(self, model_id):
index, _, model = self._find_model_id(model_id)
self._model_scenarios.pop(index)
if self._active_scenario >= index:
self._active_scenario = max(self._active_scenario - 1, 0)
self.notify_scenarios_changed()
def set_active_scenario(self, model_id):
index, _, model = self._find_model_id(model_id)
self._active_scenario = index
self.notify_scenarios_changed()
self.notify_model_values_changed()
def notify_scenarios_changed(self):
"""
Occurs when the set of scenarios has been modified, but not if the values of the scenario has changed.
"""
self.multi_model_view.scenarios_updated(self._model_scenarios, self._active_scenario)
self.comparison_view.scenarios_updated(self._model_scenarios, self._active_scenario)
def notify_model_values_changed(self):
"""
Occurs when *any* value in *any* of the scenarios has been modified.
"""
self.comparison_view.scenarios_updated(self._model_scenarios, self._active_scenario)
self.current_scenario_figure.update(self._model_scenarios[self._active_scenario][1].dcs_instance())
class MultiModelView(View):
def __init__(self, controller: ExpertApplication):
self._controller = controller
self.widget = widgets.Tab()
self.widget.observe(self._on_tab_change, 'selected_index')
self._tab_model_ids: typing.List[int] = []
self._tab_widgets: typing.List[widgets.Widget] = []
self._tab_model_views: typing.List[ModelWidgets] = []
def scenarios_updated(
self,
model_scenarios: typing.Sequence[ScenarioType],
active_scenario_index: int
):
"""
Called when a scenario is added/removed/renamed etc.
Note: Not called when the model state is modified.
"""
model_scenario_ids = []
for i, (scenario_name, model) in enumerate(model_scenarios):
if id(model) not in self._tab_model_ids:
self.add_tab(scenario_name, model)
model_scenario_ids.append(id(model))
tab_index = self._tab_model_ids.index(id(model))
self.widget.set_title(tab_index, scenario_name)
# Any remaining model_scenario_ids are no longer needed, so remove
# their tabs.
for tab_index, tab_scenario_id in enumerate(self._tab_model_ids[:]):
if tab_scenario_id not in model_scenario_ids:
self.remove_tab(tab_index)
assert self._tab_model_ids == model_scenario_ids
self.widget.selected_index = active_scenario_index
def add_tab(self, name, model):
self._tab_model_views.append(ModelWidgets(model))
self._tab_model_ids.append(id(model))
tab_idx = len(self._tab_model_ids) - 1
tab_widget = widgets.VBox(
children=(
self._build_settings_menu(name, model),
self._tab_model_views[tab_idx].widget,
)
)
self._tab_widgets.append(tab_widget)
self.update_tab_widget()
def remove_tab(self, tab_index):
assert 0 <= tab_index < len(self._tab_model_ids)
assert len(self._tab_model_ids) > 1
self._tab_model_ids.pop(tab_index)
self._tab_widgets.pop(tab_index)
self._tab_model_views.pop(tab_index)
if self._active_tab_index >= tab_index:
self._active_tab_index = max(0, self._active_tab_index - 1)
self.update_tab_widget()
def update_tab_widget(self):
self.tab_widget.children = self.tabs
self.display_titles()
self.widget.children = tuple(self._tab_widgets)
def build_settings_menu(self, tab_index):
def _on_tab_change(self, change):
self._controller.set_active_scenario(
self._tab_model_ids[change['new']]
)
def _build_settings_menu(self, name, model):
delete_button = widgets.Button(description='Delete Scenario', button_style='danger')
rename_text_field = widgets.Text(description='Rename Scenario:', value=self.scenario_names[tab_index],
rename_text_field = widgets.Text(description='Rename Scenario:', value=name,
style={'description_width': 'auto'})
duplicate_button = widgets.Button(description='Duplicate Scenario', button_style='success')
model_id = id(model)
def on_delete_click(b):
self.scenario_names = tuple_without_index(self.scenario_names, tab_index)
self.scenarios = tuple_without_index(self.scenarios, tab_index)
self.views = tuple_without_index(self.views, tab_index)
self.selected_tab = min(0, self.selected_tab - 1)
self.tabs = tuple(widgets.VBox(children=(self.build_settings_menu(i), view.present()))
for i, view in enumerate(self.views))
self.update_tab_widget()
self._controller.remove_scenario(model_id)
def on_rename_text_field(change):
self.scenario_names = tuple(change['new'] if i == tab_index else value
for i, value in enumerate(self.scenario_names))
self.update_tab_widget()
self._controller.rename_scenario(model_id, new_name=change['new'])
def on_duplicate_click(b):
self.scenario_names += (self.scenario_names[tab_index] + " (copy)",)
new_scenario = state.DataclassInstanceState(
models.ExposureModel,
state_builder=CARAStateBuilder(),
)
new_scenario.dcs_update_from(self.scenarios[tab_index].dcs_instance())
self.scenarios += (new_scenario,)
self.views += (WidgetView(new_scenario),)
self.tabs += (widgets.VBox(children=(self.build_settings_menu(len(self.scenario_names) - 1), self.views[-1].present())),)
self.update_tab_widget()
tab_index = self._tab_model_ids.index(model_id)
name = self.widget.get_title(tab_index)
self._controller.add_scenario(f'{name} (copy)', model)
delete_button.on_click(on_delete_click)
duplicate_button.on_click(on_duplicate_click)
rename_text_field.observe(on_rename_text_field, 'value')
buttons = duplicate_button if tab_index == 0 else widgets.HBox(children=(duplicate_button, delete_button))
# TODO: This should be dynamic - we don't want to be able to delete the
# last scenario, so this should be controlled in the remove_tab method.
buttons_w_delete = widgets.HBox(children=(duplicate_button, delete_button))
buttons = duplicate_button if len(self._tab_model_ids) < 2 else buttons_w_delete
return widgets.VBox(children=(buttons, rename_text_field))
@property
def widget(self):
return self.tab_widget
def models_start_end(models: typing.Sequence[models.ConcentrationModel]) -> typing.Tuple[float, float]:
"""
Returns the earliest start and latest end time of a collection of ConcentrationModel objects
def tuple_without_index(t: typing.Tuple, index: int) -> typing.Tuple:
return t[:index] + t[index + 1:]
"""
infected_start = min(model.infected.presence.boundaries()[0][0] for model in models)
infected_finish = min(model.infected.presence.boundaries()[-1][1] for model in models)
return infected_start, infected_finish

View file

@ -1,9 +1,22 @@
import pytest
import cara.apps
def test_app():
@pytest.fixture
def expert_app():
return cara.apps.ExpertApplication()
def test_app(expert_app):
# To start with, let's just test that the application runs. We don't try to
# do anything fancy to verify how it looks etc., we leave that for manual
# testing.
expert_app = cara.apps.ExpertApplication()
assert expert_app.scenario_names[0] == "Scenario 1"
assert expert_app._model_scenarios[0][0] == "Scenario 1"
def test_new_scenario_changes_tab(expert_app):
# Adding a new scenario should change the tab index of the multi-model view.
assert expert_app.multi_model_view.widget.selected_index == 0
expert_app.add_scenario("Another scenario")
assert expert_app.multi_model_view.widget.selected_index == 1