From 68c1eaee29515f98944eae000cd414ad7f345fd8 Mon Sep 17 00:00:00 2001 From: markus Date: Wed, 3 Mar 2021 13:49:42 +0100 Subject: [PATCH] add plot_pi_vs_exposure_time --- cara/montecarlo.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/cara/montecarlo.py b/cara/montecarlo.py index ebf6c317..e5252769 100644 --- a/cara/montecarlo.py +++ b/cara/montecarlo.py @@ -1188,3 +1188,55 @@ def compare_viruses_qr(violins: bool = True) -> None: plt.tight_layout() plt.show() + + +def plot_pi_vs_exposure_time(exp_models: typing.List[MCExposureModel], labels: typing.List[str], + colors: typing.Optional[typing.List] = None, + linestyles: typing.Optional[typing.List[str]] = None, + points: int = 50, time_in_minutes: bool = False) -> None: + conc_models = [m.concentration_model for m in exp_models] + if colors is None: + colors = [None for _ in exp_models] + if linestyles is None: + linestyles = ['solid' for _ in exp_models] + + presence_intervals = [m.exposed.presence.boundaries() for m in exp_models] + all_equal = True + first_interval = presence_intervals[0] + for interval in presence_intervals[1:]: + if interval != first_interval or len(interval) > 1: + all_equal = False + break + + assert all_equal, \ + "The presence intervals of the exposed populations must match and be single intervals of the form ((start, stop),)" + + pis = [[] for _ in exp_models] + + start, final = first_interval[0] + times = np.linspace(start, final, points) + for finish in tqdm(times): + current_models = [MCExposureModel( + concentration_model=cm, + exposed=models.Population( + number=em.exposed.number, + presence=models.SpecificInterval(((start, finish), )), + activity=em.exposed.activity, + mask=em.exposed.mask + ) + ) for cm, em in zip(conc_models, exp_models)] + + for i, m in enumerate(current_models): + pis[i].append(np.mean(m.infection_probability()) / 100) + + if time_in_minutes: + times = [time * 60 for time in times] + + for i, pi in enumerate(pis): + plt.plot(times, pi, color=colors[i], linestyle=linestyles[i], label=labels[i]) + + plt.title('TITLE HERE') + plt.xlabel(f'XLABEL HERE ({"min" if time_in_minutes else "h"})') + plt.ylabel('YLABEL HERE') + plt.legend() + plt.show() \ No newline at end of file