diff --git a/cara/montecarlo.py b/cara/montecarlo.py index c331086f..1afe56eb 100644 --- a/cara/montecarlo.py +++ b/cara/montecarlo.py @@ -1138,8 +1138,8 @@ def compare_viruses_qr(violins: bool = True) -> None: line_color = (0.8, 0.8, 0.8) whisker_width = 0.8 - positions = [1, 2, 3, 5, 7, 9, 11] - line_positions = [4, 6, 8, 10] + positions = [1, 2, 3, 6, 8, 10, 12] + line_positions = [5, 7, 9, 11] ranges = [(28, 28), (15, 128), (480, 5580), (1, 30000)] log_ranges = [(np.log10(x), np.log10(y)) for x, y in ranges] data = [(x, x, y, y) for x, y in log_ranges] @@ -1156,9 +1156,9 @@ def compare_viruses_qr(violins: bool = True) -> None: qrs = [np.log10(pop.emission_rate_when_present()) for pop in infected_populations] fig, ax = plt.subplots() - ax.set_xlim((0, 12)) + ax.set_xlim((0, 13)) bp = ax.boxplot(data, patch_artist=True, medianprops={'linewidth': 0}, whiskerprops={'linewidth': 0}, - positions=positions[3:], widths=[0.8]*4) + positions=positions[3:], widths=[0.6]*4) for patch, color in zip(bp['boxes'], colors[3:]): patch.set(facecolor=color) @@ -1166,10 +1166,16 @@ def compare_viruses_qr(violins: bool = True) -> None: if violins: parts = ax.violinplot(qrs, quantiles=[(0.05, 0.95) for _ in qrs], showextrema=False) means = [np.log10(np.mean(10 ** qr)) for qr in qrs] + perc80 = [np.log10(np.quantile(10 ** qr, 0.80)) for qr in qrs] ax.hlines(y=means, xmin=[pos - whisker_width / 2 for pos in positions[:3]], xmax=[pos + whisker_width / 2 for pos in positions[:3]], colors=colors[:3]) + ax.hlines(y=perc80, + xmin=[pos - whisker_width / 3 for pos in positions[:3]], + xmax=[pos + whisker_width / 3 for pos in positions[:3]], + colors=colors[:3], + linestyles=(0, (3, 1, 1, 1))) for pc, color, bc in zip(parts['bodies'], pastels, border_colors): pc.set_facecolor(color) pc.set_edgecolor(bc) @@ -1184,13 +1190,14 @@ def compare_viruses_qr(violins: bool = True) -> None: ax.vlines(x=line_positions, ymin=ax.get_ylim()[0], ymax=ax.get_ylim()[1], colors=[line_color for _ in line_positions]) - ax.set_xticks([2, 5, 7, 9, 11]) + ax.set_xticks([2.5, 6, 8, 10, 12]) ax.set_xticklabels(['SARS-CoV-2', 'SARS-CoV', 'Influenza', 'Measles', 'Tuberculosis']) - ax.hlines(y=[np.log10(970)], linestyles=['dashed'], colors=['red'], xmin=0, xmax=4) + ax.hlines(y=[np.log10(970), np.log10(45)], linestyles=[(0,(1, 1))], colors=['dodgerblue', 'crimson'], xmin=0, xmax=5) handles = [patches.Patch(color=c, label=l) for c, l in zip([p + (0.3,) for p in pastels], ('Breathing', 'Speaking', 'Shouting'))] - handles += [mlines.Line2D([], [], linestyle='dashed', color='red', label='S V Chorale\n(qR=970)')] - plt.legend(handles=handles, loc='lower left', bbox_to_anchor=(0.12, 0.01)) + handles += [mlines.Line2D([], [], linestyle=(0, (1, 1)), color='dodgerblue', label='S V Chorale (qR=970)')] + handles += [mlines.Line2D([], [], linestyle=(0, (1, 1)), color='crimson', label='Bus ride (qR=45)')] + plt.legend(handles=handles, loc='lower left', bbox_to_anchor=(0.10, 0.01)) ax.set_yticks([i for i in range(-6, 7, 2)]) ax.set_yticklabels(['$10^{' + str(i) + '}$' for i in range(-6, 7, 2)])