From ccc07cbb26a65ad58f4b20a0e3bf05ba922c7388 Mon Sep 17 00:00:00 2001 From: markus Date: Tue, 2 Feb 2021 14:51:12 +0100 Subject: [PATCH] use dedicated functions for sampling --- cara/montecarlo.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/cara/montecarlo.py b/cara/montecarlo.py index 256efd61..08b09c4b 100644 --- a/cara/montecarlo.py +++ b/cara/montecarlo.py @@ -94,12 +94,21 @@ class MCInfectedPopulation(models.Population): viral_load: typing.Optional[float] = None + @functools.lru_cache() def _generate_viral_loads(self) -> np.ndarray: - kde_model = KernelDensity(kernel='gaussian', bandwidth=0.1) - kde_model.fit(np.asarray(log_viral_load_frequencies)[0, :][:, np.newaxis], - sample_weight=np.asarray(log_viral_load_frequencies)[1, :]) + if self.viral_load is None: + kde_model = KernelDensity(kernel='gaussian', bandwidth=0.1) + kde_model.fit(np.asarray(log_viral_load_frequencies)[0, :][:, np.newaxis], + sample_weight=np.asarray(log_viral_load_frequencies)[1, :]) - return kde_model.sample(n_samples=self.samples)[:, 0] + return kde_model.sample(n_samples=self.samples)[:, 0] + else: + return np.full(self.samples, self.viral_load) + + @functools.lru_cache() + def _generate_breathing_rates(self) -> np.ndarray: + br_params = lognormal_parameters[self.breathing_category - 1] + (self.samples,) + return lognormal(*br_params) def emission_rate_when_present(self) -> np.ndarray: """ @@ -110,18 +119,14 @@ class MCInfectedPopulation(models.Population): # Extracting only the needed information from the pre-existing Mask class masked = self.mask.exhale_efficiency != 0 - if self.viral_load is None: - viral_loads = self._generate_viral_loads() - else: - viral_loads = np.full(self.samples, self.viral_load) + viral_loads = self._generate_viral_loads() emission_concentration = emission_concentrations[self.expiratory_activity - 1] mask_efficiency = [0.75, 0.81, 0.81][self.expiratory_activity - 1] if masked else 0 qr_func = np.vectorize(self._calculate_qr) - br_params = lognormal_parameters[self.breathing_category - 1] + (self.samples,) - breathing_rates = lognormal(*br_params) + breathing_rates = self._generate_breathing_rates() return qr_func(viral_loads, emission_concentration, mask_efficiency, self.qid, breathing_rates)