From c13d57972a5795b7fac5ece12ddda9c351f9dbcf Mon Sep 17 00:00:00 2001 From: markus Date: Tue, 2 Mar 2021 16:22:09 +0100 Subject: [PATCH] add violin plots to compare_viruses_qr --- cara/mc-output.py | 14 ++++++------ cara/montecarlo.py | 57 +++++++++++++++++++++++++++++++++++----------- 2 files changed, 51 insertions(+), 20 deletions(-) diff --git a/cara/mc-output.py b/cara/mc-output.py index 30b1e8d2..42e5d1c6 100644 --- a/cara/mc-output.py +++ b/cara/mc-output.py @@ -12,7 +12,7 @@ from cara.model_scenarios import * #print(np.quantile(chorale_model.infection_probability(),0.1)) -compare_viruses_qr() +compare_viruses_qr(violins=True) # print_qd_info(large_population_baselines[0]) @@ -29,11 +29,11 @@ compare_viruses_qr() # title='Classroom scenario', # vl_points=200) -composite_plot_pi_vs_viral_load([ski_cabin_model_60[1], ski_cabin_model_30[1], ski_cabin_model_20[1], ski_cabin_model_10[1]], - labels=['60 min', '30 min', 'Baseline: 20 min', '10 min'], - colors=['tomato', 'lightsalmon', '#1f77b4', 'limegreen'], - title='Ski cabin scenario', - vl_points=200) +# composite_plot_pi_vs_viral_load([ski_cabin_model_60[1], ski_cabin_model_30[1], ski_cabin_model_20[1], ski_cabin_model_10[1]], +# labels=['60 min', '30 min', 'Baseline: 20 min', '10 min'], +# colors=['tomato', 'lightsalmon', '#1f77b4', 'limegreen'], +# title='Ski cabin scenario', +# vl_points=200) #compare_concentration_curves([classroom_model_no_vent[1], classroom_model[1], classroom_model_with_hepa[1], classroom_model_full_open_multi[1]], # labels=['Windows closed', 'Baseline:(windows 10min/2h)', 'Baseline:(windows 10min/2h) + HEPA', 'Multiple windows open'], @@ -61,7 +61,7 @@ composite_plot_pi_vs_viral_load([ski_cabin_model_60[1], ski_cabin_model_30[1], s # compare_infection_probabilities_vs_viral_loads(*exposure_models) # # -#present_model(exposure_models[0].concentration_model) +# present_model(exposure_models[0].concentration_model) # plot_pi_vs_qid(fixed_vl_exposure_models, labels=['Viral load = $10^{' + str(i) + '}$' for i in range(6, 11)], # qid_min=5, qid_max=2000, qid_samples=200) # diff --git a/cara/montecarlo.py b/cara/montecarlo.py index 03d6092d..fb646ce2 100644 --- a/cara/montecarlo.py +++ b/cara/montecarlo.py @@ -1114,42 +1114,73 @@ def print_qd_info(model: MCExposureModel) -> None: f"99th:\t{np.percentile(qds, 99)}\n") -def compare_viruses_qr() -> None: +def compare_viruses_qr(violins: bool = True) -> None: # A list of 7 colors corresponding to each of the boxes # Can be represented as hex-strings (e.g. '#FF0000') or tuples of numbers on the interval [0, 1] (e.g. (1, 0, 0)) colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] + [(x, x, x) for x in np.linspace(0.1, 0.9, 4)] + pastels = [x + (0.3, ) for x in colors[:3]] + + # The colors of the borders surrounding the violin plots + border_colors = [(0, 0, 0), (0, 0, 0), (0, 0, 0)] + line_color = (0.8, 0.8, 0.8) whisker_width = 0.8 positions = [1, 2, 3, 5, 7, 9, 11] line_positions = [4, 6, 8, 10] - ranges = [(1.5, 9.8), (114, 740), (574, 3678), (28, 28), (15, 128), (480, 5580), (1, 30000)] - data = [(x, x, y, y) for x, y in ranges] + ranges = [(28, 28), (15, 128), (480, 5580), (1, 30000)] + log_ranges = [(np.log10(x), np.log10(y)) for x, y in ranges] + data = [(x, x, y, y) for x, y in log_ranges] + + infected_populations = [MCInfectedPopulation( + number=1, + presence=models.SpecificInterval(((0, 5), )), + masked=False, + virus=MCVirus(halflife=1.1, qID=100), + expiratory_activity=e, + samples=2000000, + breathing_category=3, + ) for e in range(1, 4)] + qrs = [np.log10(pop.emission_rate_when_present()) for pop in infected_populations] fig, ax = plt.subplots() ax.set_xlim((0, 12)) - bp = ax.boxplot(data[3:], patch_artist=True, medianprops={'linewidth': 0}, whiskerprops={'linewidth': 0}, + bp = ax.boxplot(data, patch_artist=True, medianprops={'linewidth': 0}, whiskerprops={'linewidth': 0}, positions=positions[3:], widths=[0.8]*4) for patch, color in zip(bp['boxes'], colors[3:]): patch.set(facecolor=color) - ax.vlines(x=positions[:3], ymin=[x[0] for x in ranges[:3]], ymax=[x[1] for x in ranges[:3]], colors=colors[:3]) - ax.hlines(y=[x for r in ranges[:3] for x in r], - xmin=[pos - whisker_width / 2 for pos in positions[:3] for _ in range(2)], - xmax=[pos + whisker_width / 2 for pos in positions[:3] for _ in range(2)], - colors=[c for c in colors[:3] for _ in range(2)]) + if violins: + parts = ax.violinplot(qrs, quantiles=[(0.05, 0.95) for _ in qrs], showextrema=False) + means = [np.log10(np.mean(10 ** qr)) for qr in qrs] + ax.hlines(y=means, + xmin=[pos - whisker_width / 2 for pos in positions[:3]], + xmax=[pos + whisker_width / 2 for pos in positions[:3]], + colors=colors[:3]) + for pc, color, bc in zip(parts['bodies'], pastels, border_colors): + pc.set_facecolor(color) + pc.set_edgecolor(bc) + parts['cquantiles'].set_color([c for c in colors[:3] for _ in range(2)]) + else: + tops, bottoms = [np.quantile(x, 0.95) for x in qrs], [np.quantile(x, 0.05) for x in qrs] + ax.vlines(x=positions[:3], ymin=bottoms, ymax=tops, colors=colors[:3]) + ax.hlines(y=list(zip(tops, bottoms)), + xmin=[pos - whisker_width / 2 for pos in positions[:3] for _ in range(2)], + xmax=[pos + whisker_width / 2 for pos in positions[:3] for _ in range(2)], + colors=[c for c in colors[:3] for _ in range(2)]) ax.vlines(x=line_positions, ymin=ax.get_ylim()[0], ymax=ax.get_ylim()[1], colors=[line_color for _ in line_positions]) ax.set_xticks([2, 5, 7, 9, 11]) ax.set_xticklabels(['SARS-CoV-2', 'SARS-CoV', 'Influenza', 'Measles', 'Tuberculosis']) - ax.hlines(y=[970], linestyles=['dashed'], colors=['red'], xmin=0, xmax=4) + ax.hlines(y=[np.log10(970)], linestyles=['dashed'], colors=['red'], xmin=0, xmax=4) - plt.yscale('log') - handles = [patches.Patch(color=c, label=l) for c, l in zip(colors[:3], ('Breathing', 'Speaking', 'Shouting'))] + handles = [patches.Patch(color=c, label=l) for c, l in zip(pastels, ('Breathing', 'Speaking', 'Shouting'))] handles += [mlines.Line2D([], [], linestyle='dashed', color='red', label='Chorale')] - plt.legend(handles=handles, loc='upper left') + plt.legend(handles=handles, loc='lower left', bbox_to_anchor=(0.2, 0.03)) + ax.set_yticks([i for i in range(-6, 7, 2)]) + ax.set_yticklabels(['$10^{' + str(i) + '}$' for i in range(-6, 7, 2)]) plt.suptitle('SUPTITLE HERE') ax.set_xlabel('XLABEL HERE')