added fix to mpl subplots after 3.9.1 release

This commit is contained in:
lrdossan 2024-07-12 14:30:28 +02:00
parent f5e1ea728a
commit 22b4abe464

View file

@ -275,49 +275,47 @@ def uncertainties_plot(infection_probability: models._VectorisedFloat,
upper_percentiles: list = conditional_probability_data['upper_percentiles']
log10_vl_in_sputum: list = conditional_probability_data['log10_vl_in_sputum']
fig, axes = plt.subplots(2, 3,
fig, ((axs00, axs01, axs02), (axs10, axs11, axs12)) = plt.subplots(nrows=2, ncols=3, # type: ignore
gridspec_kw={'width_ratios': [5, 0.5] + [1],
'height_ratios': [3, 1], 'wspace': 0},
sharey='row',
sharex='col')
# Type hint for axs
axs: np.ndarray = np.array(axes)
axs01.axis('off')
axs11.axis('off')
axs12.axis('off')
for y, x in [(0, 1)] + [(1, i + 1) for i in range(2)]:
axs[y, x].axis('off')
axs01.set_visible(False)
axs[0, 1].set_visible(False)
axs00.plot(viral_loads, np.array(pi_means), label='Predictive total probability')
axs00.fill_between(viral_loads, np.array(lower_percentiles), np.array(upper_percentiles), alpha=0.1, label='5ᵗʰ and 95ᵗʰ percentile')
axs[0, 0].plot(viral_loads, np.array(pi_means), label='Predictive total probability')
axs[0, 0].fill_between(viral_loads, np.array(lower_percentiles), np.array(upper_percentiles), alpha=0.1, label='5ᵗʰ and 95ᵗʰ percentile')
axs02.hist(infection_probability, bins=30, orientation='horizontal')
axs02.set_xticks([])
axs02.set_xticklabels([])
axs02.set_facecolor("lightgrey")
axs[0, 2].hist(infection_probability, bins=30, orientation='horizontal')
axs[0, 2].set_xticks([])
axs[0, 2].set_xticklabels([])
axs[0, 2].set_facecolor("lightgrey")
highest_bar = axs02.get_xlim()[1]
axs02.set_xlim(0, highest_bar)
highest_bar = axs[0, 2].get_xlim()[1]
axs[0, 2].set_xlim(0, highest_bar)
axs[0, 2].text(highest_bar * 0.5, 50,
axs02.text(highest_bar * 0.5, 50,
"$P(I)=$\n" + rf"$\bf{np.round(np.mean(infection_probability), 1)}$%", ha='center', va='center')
axs[1, 0].hist(log10_vl_in_sputum,
axs10.hist(log10_vl_in_sputum,
bins=150, range=(2, 10), color='grey')
axs[1, 0].set_facecolor("lightgrey")
axs[1, 0].set_yticks([])
axs[1, 0].set_yticklabels([])
axs[1, 0].set_xticks([i for i in range(2, 13, 2)])
axs[1, 0].set_xticklabels(['$10^{' + str(i) + '}$' for i in range(2, 13, 2)])
axs[1, 0].set_xlim(2, 10)
axs[1, 0].set_xlabel('Viral load\n(RNA copies)', fontsize=12)
axs[0, 0].set_ylabel('Conditional Probability\nof Infection', fontsize=12)
axs10.set_facecolor("lightgrey")
axs10.set_yticks([])
axs10.set_yticklabels([])
axs10.set_xticks([i for i in range(2, 13, 2)])
axs10.set_xticklabels(['$10^{' + str(i) + '}$' for i in range(2, 13, 2)])
axs10.set_xlim(2, 10)
axs10.set_xlabel('Viral load\n(RNA copies)', fontsize=12)
axs00.set_ylabel('Conditional Probability\nof Infection', fontsize=12)
axs[0, 0].text(9.5, -0.01, '$(i)$')
axs[1, 0].text(9.5, axs[1, 0].get_ylim()[1] * 0.8, '$(ii)$')
axs[0, 2].set_title('$(iii)$', fontsize=10)
axs00.text(9.5, -0.01, '$(i)$')
axs10.text(9.5, axs10.get_ylim()[1] * 0.8, '$(ii)$')
axs02.set_title('$(iii)$', fontsize=10)
axs[0, 0].legend()
axs00.legend()
return fig