From 2d8d03230e6c97f7bd250e4be9d04acb16487dd2 Mon Sep 17 00:00:00 2001 From: markus Date: Thu, 11 Feb 2021 12:50:10 +0100 Subject: [PATCH] add customizable plot titles --- cara/montecarlo.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/cara/montecarlo.py b/cara/montecarlo.py index a742448c..1502ec1d 100644 --- a/cara/montecarlo.py +++ b/cara/montecarlo.py @@ -418,17 +418,19 @@ def print_qr_info(log_qr: np.ndarray) -> None: print(f"qR_{quantile} = {np.quantile(qr_values, quantile)}") -def present_model(model: MCConcentrationModel, bins: int = 200) -> None: +def present_model(model: MCConcentrationModel, bins: int = 200, + title: str = 'Summary of model parameters') -> None: """ Displays a number of plots and prints a handful of key parameters and results of a given MCConcentrationModel :param model: The MCConcentrationModel representing the scenario to be presented :param bins: The number of bins (bars) to use for the histograms + :param title: A string giving the title at the top of the generated plot :return: Nothing, graphs are displayed and parameters are printed """ fig, axs = plt.subplots(2, 2, sharex=False, sharey=False) fig.set_figheight(8) fig.set_figwidth(10) - fig.suptitle('Summary of model parameters') + fig.suptitle(title) plt.tight_layout() plt.subplots_adjust(hspace=0.4) plt.subplots_adjust(wspace=0.2) @@ -498,7 +500,8 @@ def present_model(model: MCConcentrationModel, bins: int = 200) -> None: plt.show() -def plot_pi_vs_viral_load(baseline: MCExposureModel, samples_per_vl: int = 20000) -> None: +def plot_pi_vs_viral_load(baseline: MCExposureModel, samples_per_vl: int = 20000, + title: str = 'Probability of infection vs viral load') -> None: infected = baseline.concentration_model.infected viral_loads = np.linspace(3, 12, 200) pi_means = [] @@ -532,7 +535,7 @@ def plot_pi_vs_viral_load(baseline: MCExposureModel, samples_per_vl: int = 20000 plt.plot(viral_loads, pi_medians) plt.fill_between(viral_loads, lower_percentiles, upper_percentiles, alpha=0.2) - plt.title('Probability of infection vs viral load') + plt.title(title) plt.ylabel('Percentage probability of infection') plt.xticks(ticks=[i for i in range(3, 13)], labels=['$10^{' + str(i) + '}$' for i in range(3, 13)]) plt.xlabel('Viral load in sputum')