From b5419f6f3fd8407f5344a9a5cc1a56c06abc3ca6 Mon Sep 17 00:00:00 2001 From: markus Date: Tue, 23 Feb 2021 17:43:27 +0100 Subject: [PATCH] add composite_plot_pi_vs_viral_load --- cara/montecarlo.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/cara/montecarlo.py b/cara/montecarlo.py index 333b9b33..e659cbc2 100644 --- a/cara/montecarlo.py +++ b/cara/montecarlo.py @@ -608,6 +608,92 @@ def plot_pi_vs_viral_load(baselines: typing.Union[MCExposureModel, typing.List[M plt.show() +def composite_plot_pi_vs_viral_load(baselines: typing.List[MCExposureModel], labels: typing.List[str], + colors: typing.List[str], samples_per_vl: int = 2000, vl_points: int = 200, + title: str = 'Probability of infection vs viral load', show_lines: bool = True) -> None: + viral_loads = np.linspace(1, 12, vl_points) + lines, lowers, uppers = [], [], [] + for baseline in baselines: + infected = baseline.concentration_model.infected + pi_means = [] + lower_percentiles = [] + upper_percentiles = [] + for viral_load in tqdm(viral_loads): + model = MCExposureModel(concentration_model=MCConcentrationModel( + room=baseline.concentration_model.room, + ventilation=baseline.concentration_model.ventilation, + infected=MCInfectedPopulation( + number=infected.number, + presence=infected.presence, + masked=infected.masked, + expiratory_activity=infected.expiratory_activity, + breathing_category=infected.breathing_category, + virus=infected.virus, + samples=samples_per_vl, + viral_load=viral_load, + expiratory_activity_weights=infected.expiratory_activity_weights + ) + ), + exposed=baseline.exposed) + + infection_probabilities = model.infection_probability()/100 + pi_means.append(np.mean(infection_probabilities)) + lower_percentiles.append(np.quantile(infection_probabilities, 0.01)) + upper_percentiles.append(np.quantile(infection_probabilities, 0.99)) + + lines.append(pi_means) + uppers.append(upper_percentiles) + lowers.append(lower_percentiles) + + histogram_data = [model.infection_probability() / 100 for model in baselines] + + fig, axs = plt.subplots(2, 2 + len(baselines), gridspec_kw={'width_ratios': [5, 0.5] + [1] * len(baselines), + 'height_ratios': [3, 1], 'wspace': 0}, + sharey='row', sharex='col') + + for y, x in [(0, 1)] + [(1, i + 1) for i in range(len(baselines) + 1)]: + axs[y, x].axis('off') + + for x in range(len(baselines) - 1): + axs[0, x + 3].tick_params(axis='y', which='both', left='off') + + axs[0, 1].set_visible(False) + + for line, upper, lower, label, color in zip(lines, uppers, lowers, labels, colors): + axs[0, 0].plot(viral_loads, line, label=label, color=color) + axs[0, 0].fill_between(viral_loads, lower, upper, alpha=0.2, color=color) + + for i, (data, color) in enumerate(zip(histogram_data, colors)): + axs[0, i + 2].hist(data, bins=30, orientation='horizontal', color=color) + axs[0, i + 2].set_xticks([]) + axs[0, i + 2].set_xticklabels([]) + + axs[1, 0].hist(baselines[0].concentration_model.infected._generate_viral_loads(), bins=30, range=(1, 12)) + axs[1, 0].set_yticks([]) + axs[1, 0].set_yticklabels([]) + axs[1, 0].set_xticks([i for i in range(1, 13, 2)]) + axs[1, 0].set_xticklabels(['$10^{' + str(i) + '}$' for i in range(1, 13, 2)]) + axs[1, 0].set_xlabel('Viral load') + axs[0, 0].set_ylabel('Probability of infection') + plt.suptitle(title) + + if show_lines: + middle_positions = [] + for line in lines: + for i, point in enumerate(line): + if point > 0.5: + middle_positions.append(viral_loads[i]) + break + + axs[0, 0].vlines(middle_positions, colors=colors, linestyles=['dashed']*2, ymin=axs[0, 0].get_ylim()[0], + ymax=axs[0, 0].get_ylim()[1]) + axs[1, 0].vlines(middle_positions, colors=colors, linestyles=['dashed']*2, ymin=0, ymax=axs[1, 0].get_ylim()[1]) + + axs[0, 0].legend() + plt.show() + + + def plot_pi_vs_qid(baselines: typing.Union[MCExposureModel, typing.List[MCExposureModel]], samples_per_qid: int = 20000, title: str = 'Probability of infection vs qID', labels: typing.List[str] = None, qid_min: float = 5, qid_max: float = 2000, qid_samples: int = 200) -> None: