diff --git a/cara/apps/expert.py b/cara/apps/expert.py index 9955efaa..95f3ae99 100644 --- a/cara/apps/expert.py +++ b/cara/apps/expert.py @@ -9,8 +9,9 @@ import matplotlib.figure import numpy as np import mplcursors from matplotlib import pyplot as plt -from numpy import object_ from cara import data, models, state +import matplotlib.lines as mlines +import matplotlib.patches as patches def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=False): collapsed = widgets.Accordion([widgets.VBox(widgets_to_collapse)]) @@ -107,11 +108,10 @@ def ipympl_canvas(figure): class ExposureModelResult(View): def __init__(self): - self.figure = matplotlib.figure.Figure(figsize=(9, 5)) + self.figure = matplotlib.figure.Figure(figsize=(9, 6)) ipympl_canvas(self.figure) self.html_output = widgets.HTML() self.ax = self.figure.add_subplot(1, 1, 1) - self.figure.subplots_adjust(left=8, right=9) self.ax2 = self.ax.twinx() self.concentration_line = None self.concentration_area = None @@ -124,6 +124,15 @@ class ExposureModelResult(View): self.figure.canvas, ]) + def initialize_axes(self) -> matplotlib.figure.Axes: + ax = self.figure.add_subplot(1, 1, 1) + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + ax.set_xlabel('Time (hours)') + ax.set_ylabel('Concentration ($virions/m^{3}$)') + ax.set_title('Concentration of virions') + return ax + def update(self, model: models.ExposureModel): self.update_plot(model) self.update_textual_result(model) @@ -150,7 +159,7 @@ class ExposureModelResult(View): ax.set_xlabel('Time (hours)') ax.set_ylabel('Mean concentration ($virions/m^{3}$)') - ax.set_title('Concentration of virions and Cumulative dose') + ax.set_title('Concentration of virions \nand Cumulative dose') #cursor = SnaptoCursor(self.ax, ts, concentration) @@ -191,15 +200,14 @@ class ExposureModelResult(View): cumulative_top = max([1e-5, max(cumulative_doses)]) self.ax2.set_ylim(bottom=0., top=cumulative_top) - self.ax.set_xlim(left = min(model.concentration_model.infected.presence.present_times[0]), right = max(model.concentration_model.infected.presence.present_times[1])) + self.ax.set_xlim(left = min(min(model.concentration_model.infected.presence.present_times[0]), min(model.exposed.presence.present_times[0])), right = max(max(model.concentration_model.infected.presence.present_times[1]), max(model.exposed.presence.present_times[1]))) + + figure_legends = [mlines.Line2D([], [], color='#3530fe', markersize=15, label='Mean concentration'), + mlines.Line2D([], [], color='#0000c8', markersize=15, ls="dotted", label='Cumulative dose'), + patches.Patch(edgecolor="#96cbff", facecolor='#96cbff', label='Presence of exposed person(s)')] + self.figure.legend(handles=figure_legends) - legend = self.ax.legend(bbox_to_anchor=(1.15, 1), frameon=False) - self.ax2.legend(bbox_to_anchor=(1.15, 0.89), frameon=False) - #sself.marker=plt.connect('motion_notify_event', mouse_move) - self.figure.canvas.draw() - self.figure.tight_layout() - return legend def update_textual_result(self, model: models.ExposureModel): lines = [] @@ -215,7 +223,7 @@ class ExposureModelResult(View): R0 = np.round(np.array(model.reproduction_number()).mean(), 1) lines.append(f'Reproduction number (R0): {R0}') - self.html_output.value = '
\n'.join(lines) + self.html_output.value = '
\n'.join(lines) class ExposureComparissonResult(View): @@ -223,7 +231,7 @@ class ExposureComparissonResult(View): self.figure = matplotlib.figure.Figure(figsize=(9, 6)) ipympl_canvas(self.figure) self.html_output = widgets.HTML() - self.ax = self.figure.add_subplot(2, 1, 2) + self.ax = self.figure.add_subplot(1, 1, 1) self.ax2 = self.ax.twinx() self.concentration_line = None self.cumulative_line = None @@ -672,16 +680,17 @@ class ModelWidgets(View): def on_expiration_change(change): expiration = models.Expiration.types[change['new']] node.dcs_update_from(expiration) + expiration_choice.observe(on_expiration_change, names=['value']) return widgets.HBox([widgets.Label("Expiration"), expiration_choice], layout=widgets.Layout(justify_content='space-between')) def _build_viral_load(self, node): - - viral_load_in_sputum = widgets.IntText(value=node.viral_load_in_sputum, PlaceHolder='1e9') + viral_load_in_sputum = widgets.Text(continuous_update=False, value=("{:.2e}".format(node.viral_load_in_sputum))) def viral_load_change(change): - node.viral_load_in_sputum = change['new'] + viral_load_in_sputum.value = "{:.2e}".format(float(change['new'])) + node.viral_load_in_sputum = float(viral_load_in_sputum.value) viral_load_in_sputum.observe(viral_load_change, names=['value'])