diff --git a/cara/apps/expert.py b/cara/apps/expert.py
index 9955efaa..95f3ae99 100644
--- a/cara/apps/expert.py
+++ b/cara/apps/expert.py
@@ -9,8 +9,9 @@ import matplotlib.figure
import numpy as np
import mplcursors
from matplotlib import pyplot as plt
-from numpy import object_
from cara import data, models, state
+import matplotlib.lines as mlines
+import matplotlib.patches as patches
def collapsible(widgets_to_collapse: typing.List, title: str, start_collapsed=False):
collapsed = widgets.Accordion([widgets.VBox(widgets_to_collapse)])
@@ -107,11 +108,10 @@ def ipympl_canvas(figure):
class ExposureModelResult(View):
def __init__(self):
- self.figure = matplotlib.figure.Figure(figsize=(9, 5))
+ self.figure = matplotlib.figure.Figure(figsize=(9, 6))
ipympl_canvas(self.figure)
self.html_output = widgets.HTML()
self.ax = self.figure.add_subplot(1, 1, 1)
- self.figure.subplots_adjust(left=8, right=9)
self.ax2 = self.ax.twinx()
self.concentration_line = None
self.concentration_area = None
@@ -124,6 +124,15 @@ class ExposureModelResult(View):
self.figure.canvas,
])
+ def initialize_axes(self) -> matplotlib.figure.Axes:
+ ax = self.figure.add_subplot(1, 1, 1)
+ ax.spines['right'].set_visible(False)
+ ax.spines['top'].set_visible(False)
+ ax.set_xlabel('Time (hours)')
+ ax.set_ylabel('Concentration ($virions/m^{3}$)')
+ ax.set_title('Concentration of virions')
+ return ax
+
def update(self, model: models.ExposureModel):
self.update_plot(model)
self.update_textual_result(model)
@@ -150,7 +159,7 @@ class ExposureModelResult(View):
ax.set_xlabel('Time (hours)')
ax.set_ylabel('Mean concentration ($virions/m^{3}$)')
- ax.set_title('Concentration of virions and Cumulative dose')
+ ax.set_title('Concentration of virions \nand Cumulative dose')
#cursor = SnaptoCursor(self.ax, ts, concentration)
@@ -191,15 +200,14 @@ class ExposureModelResult(View):
cumulative_top = max([1e-5, max(cumulative_doses)])
self.ax2.set_ylim(bottom=0., top=cumulative_top)
- self.ax.set_xlim(left = min(model.concentration_model.infected.presence.present_times[0]), right = max(model.concentration_model.infected.presence.present_times[1]))
+ self.ax.set_xlim(left = min(min(model.concentration_model.infected.presence.present_times[0]), min(model.exposed.presence.present_times[0])), right = max(max(model.concentration_model.infected.presence.present_times[1]), max(model.exposed.presence.present_times[1])))
+
+ figure_legends = [mlines.Line2D([], [], color='#3530fe', markersize=15, label='Mean concentration'),
+ mlines.Line2D([], [], color='#0000c8', markersize=15, ls="dotted", label='Cumulative dose'),
+ patches.Patch(edgecolor="#96cbff", facecolor='#96cbff', label='Presence of exposed person(s)')]
+ self.figure.legend(handles=figure_legends)
- legend = self.ax.legend(bbox_to_anchor=(1.15, 1), frameon=False)
- self.ax2.legend(bbox_to_anchor=(1.15, 0.89), frameon=False)
- #sself.marker=plt.connect('motion_notify_event', mouse_move)
-
self.figure.canvas.draw()
- self.figure.tight_layout()
- return legend
def update_textual_result(self, model: models.ExposureModel):
lines = []
@@ -215,7 +223,7 @@ class ExposureModelResult(View):
R0 = np.round(np.array(model.reproduction_number()).mean(), 1)
lines.append(f'Reproduction number (R0): {R0}')
- self.html_output.value = '
\n'.join(lines)
+ self.html_output.value = '
\n'.join(lines)
class ExposureComparissonResult(View):
@@ -223,7 +231,7 @@ class ExposureComparissonResult(View):
self.figure = matplotlib.figure.Figure(figsize=(9, 6))
ipympl_canvas(self.figure)
self.html_output = widgets.HTML()
- self.ax = self.figure.add_subplot(2, 1, 2)
+ self.ax = self.figure.add_subplot(1, 1, 1)
self.ax2 = self.ax.twinx()
self.concentration_line = None
self.cumulative_line = None
@@ -672,16 +680,17 @@ class ModelWidgets(View):
def on_expiration_change(change):
expiration = models.Expiration.types[change['new']]
node.dcs_update_from(expiration)
+
expiration_choice.observe(on_expiration_change, names=['value'])
return widgets.HBox([widgets.Label("Expiration"), expiration_choice], layout=widgets.Layout(justify_content='space-between'))
def _build_viral_load(self, node):
-
- viral_load_in_sputum = widgets.IntText(value=node.viral_load_in_sputum, PlaceHolder='1e9')
+ viral_load_in_sputum = widgets.Text(continuous_update=False, value=("{:.2e}".format(node.viral_load_in_sputum)))
def viral_load_change(change):
- node.viral_load_in_sputum = change['new']
+ viral_load_in_sputum.value = "{:.2e}".format(float(change['new']))
+ node.viral_load_in_sputum = float(viral_load_in_sputum.value)
viral_load_in_sputum.observe(viral_load_change, names=['value'])