From d7039a958a3320b62150d1ecc7ad26e9ed11224f Mon Sep 17 00:00:00 2001 From: Luis Aleixo Date: Wed, 15 Feb 2023 16:56:40 +0100 Subject: [PATCH] added method to generate P(I|vl) uncertainties graphs --- caimira/apps/calculator/report_generator.py | 112 ++++++++++++++++++ .../templates/base/calculator.report.html.j2 | 2 + 2 files changed, 114 insertions(+) diff --git a/caimira/apps/calculator/report_generator.py b/caimira/apps/calculator/report_generator.py index f1471519..5795e1b1 100644 --- a/caimira/apps/calculator/report_generator.py +++ b/caimira/apps/calculator/report_generator.py @@ -10,6 +10,7 @@ import zlib import jinja2 import numpy as np +import matplotlib.pyplot as plt from caimira import models from caimira.apps.calculator import markdown_tools @@ -157,6 +158,7 @@ def calculate_report_data(form: FormData, model: models.ExposureModel) -> typing "emission_rate": er, "exposed_occupants": exposed_occupants, "expected_new_cases": expected_new_cases, + "uncertainties_plot_scr": img2base64(_figure2bytes(uncertainties_plot([model]))) } @@ -179,6 +181,109 @@ def generate_permalink(base_url, get_root_url, get_root_calculator_url, form: F } +def uncertainties_plot(exposure_models): + from tqdm import tqdm + fig = plt.figure(figsize=(7, 10)) + viral_loads = np.linspace(2, 10, 600) + + lines, lowers, uppers = [], [], [] + for exposure_mc in exposure_models: + concentration_model = exposure_mc.concentration_model + pi_means = [] + lower_percentiles = [] + upper_percentiles = [] + + for vl in tqdm(viral_loads): + model_vl = dataclass_utils.replace(exposure_mc, + concentration_model = models.ConcentrationModel( + room=concentration_model.room, + ventilation=concentration_model.ventilation, + infected=models.InfectedPopulation( + number=concentration_model.infected.number, + presence=concentration_model.infected.presence, + virus = models.SARSCoV2( + viral_load_in_sputum=10**vl, + infectious_dose=concentration_model.infected.virus.infectious_dose, + viable_to_RNA_ratio=concentration_model.infected.virus.viable_to_RNA_ratio, + transmissibility_factor=0.2, + ), + mask=concentration_model.infected.mask, + activity=concentration_model.infected.activity, + expiration=concentration_model.infected.expiration, + host_immunity=concentration_model.infected.host_immunity, + ) + ), + ) + + pi = model_vl.infection_probability()/100 + pi_means.append(np.mean(pi)) + lower_percentiles.append(np.quantile(pi, 0.05)) + upper_percentiles.append(np.quantile(pi, 0.95)) + + lines.append(pi_means) + uppers.append(upper_percentiles) + lowers.append(lower_percentiles) + + # print(model.concentration_model.infected.virus) + histogram_data = [model.infection_probability() / 100 for model in exposure_models] + + fig, axs = plt.subplots(2, 2 + len(exposure_models), gridspec_kw={'width_ratios': [5, 0.5] + [1] * len(exposure_models), + 'height_ratios': [3, 1], 'wspace': 0}, + sharey='row', sharex='col') + + for y, x in [(0, 1)] + [(1, i + 1) for i in range(len(exposure_models) + 1)]: + axs[y, x].axis('off') + + for x in range(len(exposure_models) - 1): + axs[0, x + 3].tick_params(axis='y', which='both', left='off') + + axs[0, 1].set_visible(False) + + for line, upper, lower in zip(lines, uppers, lowers): + axs[0, 0].plot(viral_loads, line, label='Predictive total probability') + axs[0, 0].fill_between(viral_loads, lower, upper, alpha=0.1, label='5ᵗʰ and 95ᵗʰ percentile') + + for i, data in enumerate(histogram_data): + axs[0, i + 2].hist(data, bins=30, orientation='horizontal') + axs[0, i + 2].set_xticks([]) + axs[0, i + 2].set_xticklabels([]) + # axs[0, i + 2].set_xlabel(f"{np.round(np.mean(data) * 100, 1)}%") + axs[0, i + 2].set_facecolor("lightgrey") + + highest_bar = max(axs[0, i + 2].get_xlim()[1] for i in range(len(histogram_data))) + for i in range(len(histogram_data)): + axs[0, i + 2].set_xlim(0, highest_bar) + + axs[0, i + 2].text(highest_bar * 0.5, 0.5, + rf"$\bf{np.round(np.mean(histogram_data[i]) * 100, 1)}$%", ha='center', va='center') + + axs[1, 0].hist([np.log10(vl) for vl in exposure_models[0].concentration_model.infected.virus.viral_load_in_sputum], + bins=150, range=(2, 10), color='grey') + axs[1, 0].set_facecolor("lightgrey") + axs[1, 0].set_yticks([]) + axs[1, 0].set_yticklabels([]) + axs[1, 0].set_xticks([i for i in range(2, 13, 2)]) + axs[1, 0].set_xticklabels(['$10^{' + str(i) + '}$' for i in range(2, 13, 2)]) + axs[1, 0].set_xlim(2, 10) + axs[1, 0].set_xlabel('Viral load\n(RNA copies)', fontsize=12) + axs[0, 0].set_ylabel('Probability of infection\nfor a given viral load', fontsize=12) + + axs[0, 0].text(9.5, -0.01, '$(i)$') + axs[1, 0].text(9.5, axs[1, 0].get_ylim()[1] * 0.8, '$(ii)$') + #axs[0, 2].text(axs[0, 2].get_xlim()[1] * 0.1, -0.05, '$(iii)$') + axs[0, 2].set_title('$(iii)$', fontsize=10) + + crits = [] + for line in lines: + for i, point in enumerate(line): + if point >= 0.05: + crits.append(viral_loads[i]) + break + + axs[0, 0].legend() + return fig + + def _img2bytes(figure): # Draw the image img_data = io.BytesIO() @@ -186,6 +291,13 @@ def _img2bytes(figure): return img_data +def _figure2bytes(figure): + # Draw the image + img_data = io.BytesIO() + figure.savefig(img_data, format='png', bbox_inches="tight", transparent=True) + return img_data + + def img2base64(img_data) -> str: img_data.seek(0) pic_hash = base64.b64encode(img_data.read()).decode('ascii') diff --git a/caimira/apps/templates/base/calculator.report.html.j2 b/caimira/apps/templates/base/calculator.report.html.j2 index 1bd08286..1ba6f720 100644 --- a/caimira/apps/templates/base/calculator.report.html.j2 +++ b/caimira/apps/templates/base/calculator.report.html.j2 @@ -193,6 +193,8 @@ + + {% if form.short_range_option == "short_range_no" %}
Alternative scenarios