added method to generate P(I|vl) uncertainties graphs

This commit is contained in:
Luis Aleixo 2023-02-15 16:56:40 +01:00
parent 188ed7a531
commit d7039a958a
2 changed files with 114 additions and 0 deletions

View file

@ -10,6 +10,7 @@ import zlib
import jinja2
import numpy as np
import matplotlib.pyplot as plt
from caimira import models
from caimira.apps.calculator import markdown_tools
@ -157,6 +158,7 @@ def calculate_report_data(form: FormData, model: models.ExposureModel) -> typing
"emission_rate": er,
"exposed_occupants": exposed_occupants,
"expected_new_cases": expected_new_cases,
"uncertainties_plot_scr": img2base64(_figure2bytes(uncertainties_plot([model])))
}
@ -179,6 +181,109 @@ def generate_permalink(base_url, get_root_url, get_root_calculator_url, form: F
}
def uncertainties_plot(exposure_models):
from tqdm import tqdm
fig = plt.figure(figsize=(7, 10))
viral_loads = np.linspace(2, 10, 600)
lines, lowers, uppers = [], [], []
for exposure_mc in exposure_models:
concentration_model = exposure_mc.concentration_model
pi_means = []
lower_percentiles = []
upper_percentiles = []
for vl in tqdm(viral_loads):
model_vl = dataclass_utils.replace(exposure_mc,
concentration_model = models.ConcentrationModel(
room=concentration_model.room,
ventilation=concentration_model.ventilation,
infected=models.InfectedPopulation(
number=concentration_model.infected.number,
presence=concentration_model.infected.presence,
virus = models.SARSCoV2(
viral_load_in_sputum=10**vl,
infectious_dose=concentration_model.infected.virus.infectious_dose,
viable_to_RNA_ratio=concentration_model.infected.virus.viable_to_RNA_ratio,
transmissibility_factor=0.2,
),
mask=concentration_model.infected.mask,
activity=concentration_model.infected.activity,
expiration=concentration_model.infected.expiration,
host_immunity=concentration_model.infected.host_immunity,
)
),
)
pi = model_vl.infection_probability()/100
pi_means.append(np.mean(pi))
lower_percentiles.append(np.quantile(pi, 0.05))
upper_percentiles.append(np.quantile(pi, 0.95))
lines.append(pi_means)
uppers.append(upper_percentiles)
lowers.append(lower_percentiles)
# print(model.concentration_model.infected.virus)
histogram_data = [model.infection_probability() / 100 for model in exposure_models]
fig, axs = plt.subplots(2, 2 + len(exposure_models), gridspec_kw={'width_ratios': [5, 0.5] + [1] * len(exposure_models),
'height_ratios': [3, 1], 'wspace': 0},
sharey='row', sharex='col')
for y, x in [(0, 1)] + [(1, i + 1) for i in range(len(exposure_models) + 1)]:
axs[y, x].axis('off')
for x in range(len(exposure_models) - 1):
axs[0, x + 3].tick_params(axis='y', which='both', left='off')
axs[0, 1].set_visible(False)
for line, upper, lower in zip(lines, uppers, lowers):
axs[0, 0].plot(viral_loads, line, label='Predictive total probability')
axs[0, 0].fill_between(viral_loads, lower, upper, alpha=0.1, label='5ᵗʰ and 95ᵗʰ percentile')
for i, data in enumerate(histogram_data):
axs[0, i + 2].hist(data, bins=30, orientation='horizontal')
axs[0, i + 2].set_xticks([])
axs[0, i + 2].set_xticklabels([])
# axs[0, i + 2].set_xlabel(f"{np.round(np.mean(data) * 100, 1)}%")
axs[0, i + 2].set_facecolor("lightgrey")
highest_bar = max(axs[0, i + 2].get_xlim()[1] for i in range(len(histogram_data)))
for i in range(len(histogram_data)):
axs[0, i + 2].set_xlim(0, highest_bar)
axs[0, i + 2].text(highest_bar * 0.5, 0.5,
rf"$\bf{np.round(np.mean(histogram_data[i]) * 100, 1)}$%", ha='center', va='center')
axs[1, 0].hist([np.log10(vl) for vl in exposure_models[0].concentration_model.infected.virus.viral_load_in_sputum],
bins=150, range=(2, 10), color='grey')
axs[1, 0].set_facecolor("lightgrey")
axs[1, 0].set_yticks([])
axs[1, 0].set_yticklabels([])
axs[1, 0].set_xticks([i for i in range(2, 13, 2)])
axs[1, 0].set_xticklabels(['$10^{' + str(i) + '}$' for i in range(2, 13, 2)])
axs[1, 0].set_xlim(2, 10)
axs[1, 0].set_xlabel('Viral load\n(RNA copies)', fontsize=12)
axs[0, 0].set_ylabel('Probability of infection\nfor a given viral load', fontsize=12)
axs[0, 0].text(9.5, -0.01, '$(i)$')
axs[1, 0].text(9.5, axs[1, 0].get_ylim()[1] * 0.8, '$(ii)$')
#axs[0, 2].text(axs[0, 2].get_xlim()[1] * 0.1, -0.05, '$(iii)$')
axs[0, 2].set_title('$(iii)$', fontsize=10)
crits = []
for line in lines:
for i, point in enumerate(line):
if point >= 0.05:
crits.append(viral_loads[i])
break
axs[0, 0].legend()
return fig
def _img2bytes(figure):
# Draw the image
img_data = io.BytesIO()
@ -186,6 +291,13 @@ def _img2bytes(figure):
return img_data
def _figure2bytes(figure):
# Draw the image
img_data = io.BytesIO()
figure.savefig(img_data, format='png', bbox_inches="tight", transparent=True)
return img_data
def img2base64(img_data) -> str:
img_data.seek(0)
pic_hash = base64.b64encode(img_data.read()).decode('ascii')

View file

@ -193,6 +193,8 @@
</div>
</div>
<img src= "{{ uncertainties_plot_scr }}" />
{% if form.short_range_option == "short_range_no" %}
<div class="card bg-light mb-3">
<div class="card-header"><strong>Alternative scenarios</strong>