From 39fc9d8e96f326e56d48ecc3e033af871084c1c5 Mon Sep 17 00:00:00 2001 From: Luis Aleixo Date: Fri, 8 Mar 2024 10:43:48 +0100 Subject: [PATCH] handled way to visualize custom value types (namely constant values); handled display of conditional probability plot --- caimira/apps/calculator/report_generator.py | 45 +++++++---- .../templates/base/calculator.report.html.j2 | 11 +-- caimira/monte_carlo/data.py | 79 +++++++++---------- caimira/store/data_registry.py | 7 -- 4 files changed, 76 insertions(+), 66 deletions(-) diff --git a/caimira/apps/calculator/report_generator.py b/caimira/apps/calculator/report_generator.py index 53f67249..d0288b03 100644 --- a/caimira/apps/calculator/report_generator.py +++ b/caimira/apps/calculator/report_generator.py @@ -19,6 +19,7 @@ from caimira.store.data_registry import DataRegistry from ... import monte_carlo as mc from .model_generator import VirusFormData from ... import dataclass_utils +from caimira.enums import ViralLoads def model_start_end(model: models.ExposureModel): @@ -168,12 +169,27 @@ def calculate_report_data(form: VirusFormData, model: models.ExposureModel, exec prob_dist_count, prob_dist_bins = np.histogram(prob/100, bins=100, density=True) prob_probabilistic_exposure = np.array(model.total_probability_rule()).mean() expected_new_cases = np.array(model.expected_new_cases()).mean() - uncertainties_plot_src = img2base64(_figure2bytes(uncertainties_plot(model, prob))) if form.conditional_probability_plot else None exposed_presence_intervals = [list(interval) for interval in model.exposed.presence_interval().boundaries()] - conditional_probability_data = {key: value for key, value in - zip(('viral_loads', 'pi_means', 'lower_percentiles', 'upper_percentiles'), - manufacture_conditional_probability_data(model, prob))} + if (model.data_registry.virological_data['virus_distributions'][form.virus_type]['viral_load_in_sputum'] == ViralLoads.COVID_OVERALL.value # type: ignore + and form.conditional_probability_plot): # Only generate this data if covid_overall_vl_data is selected. + + viral_load_in_sputum: models._VectorisedFloat = model.concentration_model.infected.virus.viral_load_in_sputum + viral_loads, pi_means, lower_percentiles, upper_percentiles = manufacture_conditional_probability_data(model, prob) + + uncertainties_plot_src = img2base64(_figure2bytes(uncertainties_plot(prob, viral_load_in_sputum, viral_loads, + pi_means, lower_percentiles, upper_percentiles))) + conditional_probability_data = {key: value for key, value in + zip(('viral_loads', 'pi_means', 'lower_percentiles', 'upper_percentiles'), + (viral_loads, pi_means, lower_percentiles, upper_percentiles))} + vl_dist = list(np.log10(viral_load_in_sputum)) + + else: + uncertainties_plot_src = None + conditional_probability_data = None + vl = model.concentration_model.virus.viral_load_in_sputum + if isinstance(vl, np.ndarray): vl_dist = list(np.log10(model.concentration_model.virus.viral_load_in_sputum)) + else: vl_dist = np.log10(model.concentration_model.virus.viral_load_in_sputum) return { "model_repr": repr(model), @@ -194,7 +210,7 @@ def calculate_report_data(form: VirusFormData, model: models.ExposureModel, exec "expected_new_cases": expected_new_cases, "uncertainties_plot_src": uncertainties_plot_src, "CO2_concentrations": CO2_concentrations, - "vl_dist": list(np.log10(model.concentration_model.virus.viral_load_in_sputum)), + "vl_dist": vl_dist, "conditional_probability_data": conditional_probability_data, } @@ -256,11 +272,12 @@ def manufacture_conditional_probability_data( return list(viral_loads), list(pi_means), list(lower_percentiles), list(upper_percentiles) -def uncertainties_plot(exposure_model: models.ExposureModel, prob: models._VectorisedFloat): - fig = plt.figure(figsize=(4, 7), dpi=110) - - infection_probability = prob / 100 - viral_loads, pi_means, lower_percentiles, upper_percentiles = manufacture_conditional_probability_data(exposure_model, infection_probability) +def uncertainties_plot(infection_probability: models._VectorisedFloat, + viral_load_in_sputum: models._VectorisedFloat, + viral_loads: models._VectorisedFloat, + pi_means: models._VectorisedFloat, + lower_percentiles: models._VectorisedFloat, + upper_percentiles: models._VectorisedFloat): fig, axs = plt.subplots(2, 3, gridspec_kw={'width_ratios': [5, 0.5] + [1], @@ -273,8 +290,8 @@ def uncertainties_plot(exposure_model: models.ExposureModel, prob: models._Vecto axs[0, 1].set_visible(False) - axs[0, 0].plot(viral_loads, pi_means, label='Predictive total probability') - axs[0, 0].fill_between(viral_loads, lower_percentiles, upper_percentiles, alpha=0.1, label='5ᵗʰ and 95ᵗʰ percentile') + axs[0, 0].plot(viral_loads, np.array(pi_means)/100, label='Predictive total probability') + axs[0, 0].fill_between(viral_loads, np.array(lower_percentiles)/100, np.array(upper_percentiles)/100, alpha=0.1, label='5ᵗʰ and 95ᵗʰ percentile') axs[0, 2].hist(infection_probability, bins=30, orientation='horizontal') axs[0, 2].set_xticks([]) @@ -285,8 +302,8 @@ def uncertainties_plot(exposure_model: models.ExposureModel, prob: models._Vecto axs[0, 2].set_xlim(0, highest_bar) axs[0, 2].text(highest_bar * 0.5, 0.5, - rf"$\bf{np.round(np.mean(infection_probability) * 100, 1)}$%", ha='center', va='center') - axs[1, 0].hist(np.log10(exposure_model.concentration_model.infected.virus.viral_load_in_sputum), + rf"$\bf{np.round(np.mean(infection_probability), 1)}$%", ha='center', va='center') + axs[1, 0].hist(np.log10(viral_load_in_sputum), bins=150, range=(2, 10), color='grey') axs[1, 0].set_facecolor("lightgrey") axs[1, 0].set_yticks([]) diff --git a/caimira/apps/templates/base/calculator.report.html.j2 b/caimira/apps/templates/base/calculator.report.html.j2 index 76c3d3f9..041cf3ef 100644 --- a/caimira/apps/templates/base/calculator.report.html.j2 +++ b/caimira/apps/templates/base/calculator.report.html.j2 @@ -214,11 +214,12 @@ draw_histogram("prob_inf_hist", {{ prob_inf }}, {{ prob_inf_sd }});
- -
- - -
+ {% if model.data_registry.virological_data['virus_distributions'][form.virus_type]['viral_load_in_sputum'] == 'Ref: Viral load - covid_overal_vl_data' %} +
+ + +
+ {% endif %} {% if form.conditional_probability_plot %}
diff --git a/caimira/monte_carlo/data.py b/caimira/monte_carlo/data.py index 8958fe5e..d6e523b0 100644 --- a/caimira/monte_carlo/data.py +++ b/caimira/monte_carlo/data.py @@ -14,25 +14,31 @@ from caimira.monte_carlo.sampleable import LogCustom, LogNormal, Normal, LogCust from caimira.store.data_registry import DataRegistry -def evaluate_vl(value, data_registry: DataRegistry): - if value == ViralLoads.COVID_OVERALL.value: +def evaluate_vl(root: typing.Dict, value: str, data_registry: DataRegistry): + if root[value] == ViralLoads.COVID_OVERALL.value: return covid_overal_vl_data(data_registry) - elif value == ViralLoads.SYMPTOMATIC_FREQUENCIES.value: + elif root[value] == ViralLoads.SYMPTOMATIC_FREQUENCIES.value: return symptomatic_vl_frequencies + elif root[value] == 'Custom': + return param_evaluation(root, 'Viral load custom') else: raise ValueError(f"Invalid ViralLoads value {value}") -def evaluate_infectd(value, data_registry: DataRegistry): - if value == InfectiousDoses.DISTRIBUTION.value: +def evaluate_infectd(root: typing.Dict, value: str, data_registry: DataRegistry): + if root[value] == InfectiousDoses.DISTRIBUTION.value: return infectious_dose_distribution(data_registry) + elif root[value] == "Custom": + return param_evaluation(root, 'Infectious dose custom') else: raise ValueError(f"Invalid InfectiousDoses value {value}") -def evaluate_vtrr(value, data_registry: DataRegistry): - if value == ViableToRNARatios.DISTRIBUTION.value: +def evaluate_vtrr(root: typing.Dict, value: str, data_registry: DataRegistry): + if root[value] == ViableToRNARatios.DISTRIBUTION.value: return viable_to_RNA_ratio_distribution(data_registry) + elif root[value] == "Custom": + return param_evaluation(root, 'Viable to RNA ratio custom') else: raise ValueError(f"Invalid ViableToRNARatios value {value}") @@ -60,7 +66,7 @@ def custom_value_type_lookup(dict: dict, key_part: str) -> typing.Any: return f"Key '{key_part}' not found." -def evaluate_custom_value_type(dist: str, params: typing.Dict) -> typing.Any: +def evaluate_custom_value_type(value_type: str, params: typing.Dict) -> typing.Any: """ Evaluate a custom value type. @@ -75,13 +81,13 @@ def evaluate_custom_value_type(dist: str, params: typing.Dict) -> typing.Any: ValueError: If the value type is not recognized. """ - if dist == 'Constant': + if value_type == 'Constant value': return params - elif dist == 'Normal distribution': + elif value_type == 'Normal distribution': return Normal(params['normal_mean_gaussian'], params['normal_standard_deviation_gaussian']) - elif dist == 'Log-normal distribution': + elif value_type == 'Log-normal distribution': return LogNormal(params['lognormal_mean_gaussian'], params['lognormal_standard_deviation_gaussian']) - elif dist == 'Uniform distribution': + elif value_type == 'Uniform distribution': return Uniform(params['low'], params['high']) else: raise ValueError('Bad request - value type not found.') @@ -104,17 +110,10 @@ def param_evaluation(root: typing.Dict, param: typing.Union[str, typing.Any]) -> """ value = root.get(param) - if isinstance(value, str): - if value == 'Custom': - custom_value_type: typing.Dict = custom_value_type_lookup( - root, 'custom distribution') - for d, p in custom_value_type.items(): - return evaluate_custom_value_type(d, p) - - elif isinstance(value, dict): - dist: str = root[param]['associated_value'] + if isinstance(value, dict): + value_type: str = root[param]['associated_value'] params: typing.Dict = root[param]['parameters'] - return evaluate_custom_value_type(dist, params) + return evaluate_custom_value_type(value_type, params) elif isinstance(value, float) or isinstance(value, int): return value @@ -290,39 +289,39 @@ def virus_distributions(data_registry): vd = data_registry.virological_data['virus_distributions'] return { 'SARS_CoV_2': mc.SARSCoV2( - viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2']['viral_load_in_sputum'], data_registry), - infectious_dose=evaluate_infectd(vd['SARS_CoV_2']['infectious_dose'], data_registry), - viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2']['viable_to_RNA_ratio'], data_registry), + viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2'], 'viral_load_in_sputum', data_registry), + infectious_dose=evaluate_infectd(vd['SARS_CoV_2'], 'infectious_dose', data_registry), + viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2'], 'viable_to_RNA_ratio', data_registry), transmissibility_factor=vd['SARS_CoV_2']['transmissibility_factor'], ), 'SARS_CoV_2_ALPHA': mc.SARSCoV2( - viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2_ALPHA']['viral_load_in_sputum'], data_registry), - infectious_dose=evaluate_infectd(vd['SARS_CoV_2_ALPHA']['infectious_dose'], data_registry), - viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2_ALPHA']['viable_to_RNA_ratio'], data_registry), + viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2_ALPHA'], 'viral_load_in_sputum', data_registry), + infectious_dose=evaluate_infectd(vd['SARS_CoV_2_ALPHA'], 'infectious_dose', data_registry), + viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2_ALPHA'], 'viable_to_RNA_ratio', data_registry), transmissibility_factor=vd['SARS_CoV_2_ALPHA']['transmissibility_factor'], ), 'SARS_CoV_2_BETA': mc.SARSCoV2( - viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2_BETA']['viral_load_in_sputum'], data_registry), - infectious_dose=evaluate_infectd(vd['SARS_CoV_2_BETA']['infectious_dose'], data_registry), - viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2_BETA']['viable_to_RNA_ratio'], data_registry), + viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2_BETA'], 'viral_load_in_sputum', data_registry), + infectious_dose=evaluate_infectd(vd['SARS_CoV_2_BETA'], 'infectious_dose', data_registry), + viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2_BETA'], 'viable_to_RNA_ratio', data_registry), transmissibility_factor=vd['SARS_CoV_2_BETA']['transmissibility_factor'], ), 'SARS_CoV_2_GAMMA': mc.SARSCoV2( - viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2_GAMMA']['viral_load_in_sputum'], data_registry), - infectious_dose=evaluate_infectd(vd['SARS_CoV_2_GAMMA']['infectious_dose'], data_registry), - viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2_GAMMA']['viable_to_RNA_ratio'], data_registry), + viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2_GAMMA'], 'viral_load_in_sputum', data_registry), + infectious_dose=evaluate_infectd(vd['SARS_CoV_2_GAMMA'], 'infectious_dose', data_registry), + viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2_GAMMA'], 'viable_to_RNA_ratio', data_registry), transmissibility_factor=vd['SARS_CoV_2_GAMMA']['transmissibility_factor'], ), 'SARS_CoV_2_DELTA': mc.SARSCoV2( - viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2_DELTA']['viral_load_in_sputum'], data_registry), - infectious_dose=evaluate_infectd(vd['SARS_CoV_2_DELTA']['infectious_dose'], data_registry), - viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2_DELTA']['viable_to_RNA_ratio'], data_registry), + viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2_DELTA'], 'viral_load_in_sputum', data_registry), + infectious_dose=evaluate_infectd(vd['SARS_CoV_2_DELTA'], 'infectious_dose', data_registry), + viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2_DELTA'], 'viable_to_RNA_ratio', data_registry), transmissibility_factor=vd['SARS_CoV_2_DELTA']['transmissibility_factor'], ), 'SARS_CoV_2_OMICRON': mc.SARSCoV2( - viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2_OMICRON']['viral_load_in_sputum'], data_registry), - infectious_dose=evaluate_infectd(vd['SARS_CoV_2_OMICRON']['infectious_dose'], data_registry), - viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2_OMICRON']['viable_to_RNA_ratio'], data_registry), + viral_load_in_sputum=evaluate_vl(vd['SARS_CoV_2_OMICRON'], 'viral_load_in_sputum', data_registry), + infectious_dose=evaluate_infectd(vd['SARS_CoV_2_OMICRON'], 'infectious_dose', data_registry), + viable_to_RNA_ratio=evaluate_vtrr(vd['SARS_CoV_2_OMICRON'], 'viable_to_RNA_ratio', data_registry), transmissibility_factor=vd['SARS_CoV_2_OMICRON']['transmissibility_factor'], ), } diff --git a/caimira/store/data_registry.py b/caimira/store/data_registry.py index da3da011..59e4026b 100644 --- a/caimira/store/data_registry.py +++ b/caimira/store/data_registry.py @@ -261,13 +261,6 @@ class DataRegistry: "transmissibility_factor": 0.2, "infectiousness_days": 14, }, - "SARS_CoV_2_Other": { - "viral_load_in_sputum": ViralLoads.COVID_OVERALL.value, - "infectious_dose": InfectiousDoses.DISTRIBUTION.value, - "viable_to_RNA_ratio": ViableToRNARatios.DISTRIBUTION.value, - "transmissibility_factor": 0.1, - "infectiousness_days": 14, - }, }, }