diff --git a/cara/models.py b/cara/models.py index 961ac6b7..041c208b 100644 --- a/cara/models.py +++ b/cara/models.py @@ -38,7 +38,7 @@ import typing import numpy as np from scipy.interpolate import interp1d import scipy.stats as sct -from sklearn.neighbors import KernelDensity +from sklearn.neighbors import KernelDensity # type: ignore if not typing.TYPE_CHECKING: @@ -476,8 +476,9 @@ Virus.types = { ), } -#@cached -def _generate_virus_distribution(samples: int, qID: float=100) -> Virus: +@cached +def _generate_virus_distribution(params: typing.Tuple[int, float]) -> Virus: + samples , qID = params log_symptomatic_vl_frequencies = ((2.46032, 2.67431, 2.85434, 3.06155, 3.25856, 3.47256, 3.66957, 3.85979, 4.09927, 4.27081, 4.47631, 4.66653, 4.87204, 5.10302, 5.27456, 5.46478, 5.6533, 5.88428, 6.07281, 6.30549, 6.48552, 6.64856, 6.85407, 7.10373, 7.30075, 7.47229, 7.66081, 7.85782, 8.05653, 8.27053, @@ -499,9 +500,9 @@ def _generate_virus_distribution(samples: int, qID: float=100) -> Virus: ) Virus.distributions = { - 'SARS_CoV_2': lambda n: _generate_virus_distribution(n, qID=100), - 'SARS_CoV_2_B117': lambda n: _generate_virus_distribution(n, qID=60), - 'SARS_CoV_2_P1': lambda n: _generate_virus_distribution(n, qID=100/2.25), + 'SARS_CoV_2': lambda n: _generate_virus_distribution((n, 100)), + 'SARS_CoV_2_B117': lambda n: _generate_virus_distribution((n, 60)), + 'SARS_CoV_2_P1': lambda n: _generate_virus_distribution((n, 100/2.25)), }