From 386b0892731ba80a5b0d0a752c8754c7e2af41f8 Mon Sep 17 00:00:00 2001 From: markus Date: Mon, 22 Feb 2021 15:56:36 +0100 Subject: [PATCH] make vline positions dynamic --- cara/montecarlo.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/cara/montecarlo.py b/cara/montecarlo.py index b4f9fb39..bdb4d2b3 100644 --- a/cara/montecarlo.py +++ b/cara/montecarlo.py @@ -522,7 +522,8 @@ def plot_pi_vs_viral_load(baselines: typing.Union[MCExposureModel, typing.List[M if isinstance(baselines, MCExposureModel): baselines = [baselines] - viral_loads = np.linspace(3, 12, 200) + points = 200 + viral_loads = np.linspace(3, 12, points) for baseline in baselines: infected = baseline.concentration_model.infected @@ -562,16 +563,30 @@ def plot_pi_vs_viral_load(baselines: typing.Union[MCExposureModel, typing.List[M plt.xticks(ticks=[i for i in range(3, 13)], labels=['$10^{' + str(i) + '}$' for i in range(3, 13)]) plt.xlabel('Viral load') # add vertical lines for the critical viral loads for which pi= 5 or 95 - # TODO Insert viral_load(Pi = 5) and viral_load(Pi = 95) instead of hard coded values - # 7.8 and 9.5 - #plt.vlines(x=(7.8, 9.5), ymin=0, ymax=1, - # colors=('grey', 'grey'), linestyles='dotted') - #plt.text(6.7, 0.80, '$vl_{crit1}$', fontsize=12,color='black') - #plt.text(9.6, 0.80, '$vl_{crit2}$', fontsize=12,color='black') - # add 3 shaded areas - #plt.axvspan(3, 7.8, alpha=0.1, color='limegreen') - #plt.axvspan(7.8, 9.5, alpha=0.1, color='orange') - #plt.axvspan(9.5, 12, alpha=0.1, color='tomato') + + if len(baselines) == 1: + left_index, right_index = 0, 0 + for i, pi in enumerate(pi_means): + if pi > 0.05: + left_index = i + break + + for i, pi in enumerate(pi_means[::-1]): + if pi < 0.95: + right_index = points - i + break + + left, right = viral_loads[left_index], viral_loads[right_index] + + plt.vlines(x=(left, right), ymin=0, ymax=1, + colors=('grey', 'grey'), linestyles='dotted') + plt.text(left - 1.1, 0.80, '$vl_{crit1}$', fontsize=12,color='black') + plt.text(right + 0.1, 0.80, '$vl_{crit2}$', fontsize=12,color='black') + # add 3 shaded areas + plt.axvspan(3, left, alpha=0.1, color='limegreen') + plt.axvspan(left, right, alpha=0.1, color='orange') + plt.axvspan(right, 12, alpha=0.1, color='tomato') + if labels is not None: plt.legend(labels) # this is an inset plot inside the main plot