From 3defcc63f3d7f5328911239b98337139ad3d02cd Mon Sep 17 00:00:00 2001 From: markus Date: Thu, 11 Feb 2021 17:15:07 +0100 Subject: [PATCH] allow overlayed plots --- cara/montecarlo.py | 74 ++++++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 32 deletions(-) diff --git a/cara/montecarlo.py b/cara/montecarlo.py index 6dc3739a..81a26f6b 100644 --- a/cara/montecarlo.py +++ b/cara/montecarlo.py @@ -500,45 +500,55 @@ def present_model(model: MCConcentrationModel, bins: int = 200, plt.show() -def plot_pi_vs_viral_load(baseline: MCExposureModel, samples_per_vl: int = 20000, - title: str = 'Probability of infection vs viral load') -> None: - infected = baseline.concentration_model.infected +def plot_pi_vs_viral_load(baselines: typing.Union[MCExposureModel, typing.List[MCExposureModel]], + samples_per_vl: int = 20000, title: str = 'Probability of infection vs viral load', + labels: typing.List[str] = None) -> None: + if isinstance(baselines, MCExposureModel): + baselines = [baselines] + viral_loads = np.linspace(3, 12, 200) - pi_means = [] - pi_medians = [] - lower_percentiles = [] - upper_percentiles = [] - for viral_load in tqdm(viral_loads): - model = MCExposureModel(concentration_model=MCConcentrationModel( - room=baseline.concentration_model.room, - ventilation=baseline.concentration_model.ventilation, - infected=MCInfectedPopulation( - number=infected.number, - presence=infected.presence, - masked=infected.masked, - expiratory_activity=infected.expiratory_activity, - breathing_category=infected.breathing_category, - virus=infected.virus, - samples=samples_per_vl, - qid=infected.qid, - english_variant=infected.english_variant, - viral_load=viral_load - ) - ), - exposed=baseline.exposed) - infection_probabilities = model.infection_probability() - pi_means.append(np.mean(infection_probabilities)) - pi_medians.append(np.median(infection_probabilities)) - lower_percentiles.append(np.quantile(infection_probabilities, 0.01)) - upper_percentiles.append(np.quantile(infection_probabilities, 0.99)) + for baseline in baselines: + infected = baseline.concentration_model.infected + pi_means = [] + pi_medians = [] + lower_percentiles = [] + upper_percentiles = [] + for viral_load in tqdm(viral_loads): + model = MCExposureModel(concentration_model=MCConcentrationModel( + room=baseline.concentration_model.room, + ventilation=baseline.concentration_model.ventilation, + infected=MCInfectedPopulation( + number=infected.number, + presence=infected.presence, + masked=infected.masked, + expiratory_activity=infected.expiratory_activity, + breathing_category=infected.breathing_category, + virus=infected.virus, + samples=samples_per_vl, + qid=infected.qid, + english_variant=infected.english_variant, + viral_load=viral_load + ) + ), + exposed=baseline.exposed) + + infection_probabilities = model.infection_probability() + pi_means.append(np.mean(infection_probabilities)) + pi_medians.append(np.median(infection_probabilities)) + lower_percentiles.append(np.quantile(infection_probabilities, 0.01)) + upper_percentiles.append(np.quantile(infection_probabilities, 0.99)) + + plt.plot(viral_loads, pi_means) + plt.fill_between(viral_loads, lower_percentiles, upper_percentiles, alpha=0.2) - plt.plot(viral_loads, pi_medians) - plt.fill_between(viral_loads, lower_percentiles, upper_percentiles, alpha=0.2) plt.title(title) plt.ylabel('Percentage probability of infection') plt.xticks(ticks=[i for i in range(3, 13)], labels=['$10^{' + str(i) + '}$' for i in range(3, 13)]) plt.xlabel('Viral load in sputum') + if labels is not None: + plt.legend(labels) + plt.show()