diff --git a/cara/montecarlo.py b/cara/montecarlo.py index c0a792f9..d24cf251 100644 --- a/cara/montecarlo.py +++ b/cara/montecarlo.py @@ -5,6 +5,7 @@ import numpy as np import scipy.stats as sct import typing import matplotlib.pyplot as plt +import matplotlib.patches as patches from sklearn.neighbors import KernelDensity USE_SCOEH = False @@ -327,6 +328,48 @@ def print_qr_info(qr_values: np.ndarray) -> None: print(f"qR_{quantile} = {np.quantile(qr_values, quantile)}") +def present_model(model: MCConcentrationModel, bins: int = 30) -> None: + fig, axs = plt.subplots(2, 2, sharex=False, sharey=False) + fig.set_figheight(8) + fig.set_figwidth(10) + fig.suptitle('Summary of model parameters') + plt.tight_layout() + plt.subplots_adjust(hspace=0.2) + fig.set_figheight(10) + + for x, y in ((0, 0), (0, 1), (1, 0), (1, 1)): + axs[x, y].set_yticklabels([]) + axs[x, y].set_yticks([]) + + for data, (x, y) in zip((model.infected._generate_viral_loads(), + model.infected._generate_breathing_rates(), + np.log10(model.infected.emission_rate_when_present())), + ((0, 0), (1, 0), (1, 1))): + axs[x, y].hist(data, bins=bins) + top = axs[x, y].get_ylim()[1] + mean, median, std = np.mean(data), np.median(data), np.std(data) + axs[x, y].vlines(x=(mean, median, mean - std, mean + std), ymin=0, ymax=top, + colors=('red', 'green', 'pink', 'pink')) + + axs[0, 0].set_title('Viral load in sputum') + axs[0, 0].set_xlabel('Viral load [log10(RNA copies / mL)]') + + categories = ("seated", "standing", "light exercise", "moderate exercise", "heavy exercise") + axs[1, 0].set_title(f'Breathing rate - ' + f'{categories[model.infected.breathing_category - 1]}') + axs[1, 0].set_xlabel('Breathing rate [m^3 / h]') + + axs[1, 1].set_title('qR') + axs[1, 1].set_xlabel('qR [log10(RNA copies / h)]') + + mean_patch = patches.Patch(color='red', label='Mean') + median_patch = patches.Patch(color='green', label='Median') + std_patch = patches.Patch(color='pink', label='Standard deviations') + fig.legend(handles=(mean_patch, median_patch, std_patch)) + + plt.show() + + def buaonanno_exposure_model(): return MCExposureModel( concentration_model=MCConcentrationModel(