caching _generate_virus_distribution and ignoring mypy type checking for KernelDensity

This commit is contained in:
Nicolas Mounet 2021-05-14 12:49:17 +02:00
parent 64565264ea
commit 2e659da45d

View file

@ -38,7 +38,7 @@ import typing
import numpy as np
from scipy.interpolate import interp1d
import scipy.stats as sct
from sklearn.neighbors import KernelDensity
from sklearn.neighbors import KernelDensity # type: ignore
if not typing.TYPE_CHECKING:
@ -476,8 +476,9 @@ Virus.types = {
),
}
#@cached
def _generate_virus_distribution(samples: int, qID: float=100) -> Virus:
@cached
def _generate_virus_distribution(params: typing.Tuple[int, float]) -> Virus:
samples , qID = params
log_symptomatic_vl_frequencies = ((2.46032, 2.67431, 2.85434, 3.06155, 3.25856, 3.47256, 3.66957, 3.85979, 4.09927, 4.27081,
4.47631, 4.66653, 4.87204, 5.10302, 5.27456, 5.46478, 5.6533, 5.88428, 6.07281, 6.30549,
6.48552, 6.64856, 6.85407, 7.10373, 7.30075, 7.47229, 7.66081, 7.85782, 8.05653, 8.27053,
@ -499,9 +500,9 @@ def _generate_virus_distribution(samples: int, qID: float=100) -> Virus:
)
Virus.distributions = {
'SARS_CoV_2': lambda n: _generate_virus_distribution(n, qID=100),
'SARS_CoV_2_B117': lambda n: _generate_virus_distribution(n, qID=60),
'SARS_CoV_2_P1': lambda n: _generate_virus_distribution(n, qID=100/2.25),
'SARS_CoV_2': lambda n: _generate_virus_distribution((n, 100)),
'SARS_CoV_2_B117': lambda n: _generate_virus_distribution((n, 60)),
'SARS_CoV_2_P1': lambda n: _generate_virus_distribution((n, 100/2.25)),
}