Handle arrays in the cache of the model by using the memoization library. Speeds up the single model case too.
This commit is contained in:
parent
bb9b8657c6
commit
9b836a6507
5 changed files with 26 additions and 5 deletions
|
|
@ -31,10 +31,16 @@ the same for all parameters of a single model.
|
|||
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
import functools
|
||||
import numpy as np
|
||||
import typing
|
||||
|
||||
if not typing.TYPE_CHECKING:
|
||||
from memoization import cached
|
||||
else:
|
||||
# Workaround issue https://github.com/lonelyenvoy/python-memoization/issues/18
|
||||
# by providing a no-op cache decorator when type-checking.
|
||||
cached = lambda *cached_args, **cached_kwargs: lambda function: function # noqa
|
||||
|
||||
from .dataclass_utils import nested_replace
|
||||
|
||||
|
||||
|
|
@ -593,7 +599,7 @@ class InfectedPopulation(Population):
|
|||
|
||||
return self.emission_rate_when_present()
|
||||
|
||||
@functools.lru_cache()
|
||||
@cached()
|
||||
def emission_rate(self, time) -> float:
|
||||
"""
|
||||
The emission rate of the entire population.
|
||||
|
|
@ -622,7 +628,7 @@ class ConcentrationModel:
|
|||
|
||||
return k + self.virus.decay_constant + self.ventilation.air_exchange(self.room, time)
|
||||
|
||||
@functools.lru_cache()
|
||||
@cached()
|
||||
def state_change_times(self):
|
||||
"""
|
||||
All time dependent entities on this model must provide information about
|
||||
|
|
@ -645,8 +651,11 @@ class ConcentrationModel:
|
|||
return change_time
|
||||
return 0
|
||||
|
||||
@functools.lru_cache()
|
||||
@cached()
|
||||
def concentration(self, time: float) -> float:
|
||||
# Note that time is not vectorised. You can only pass a single float
|
||||
# to this method.
|
||||
|
||||
if time == 0:
|
||||
return 0.0
|
||||
IVRR = self.infectious_virus_removal_rate(time)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from cara.apps.calculator import model_generator
|
||||
|
|
@ -15,10 +17,18 @@ def baseline_form(baseline_form_data):
|
|||
|
||||
|
||||
def test_generate_report(baseline_form):
|
||||
model = baseline_form.build_model()
|
||||
# This is a simple test that confirms that given a model, we can actually
|
||||
# generate a report for it. Because this is what happens in the cara
|
||||
# calculator, we confirm that the generation happens within a reasonable
|
||||
# time threshold.
|
||||
time_limit: float = 1.5 # seconds
|
||||
|
||||
start = time.perf_counter()
|
||||
model = baseline_form.build_model()
|
||||
report = report_generator.build_report("", model, baseline_form)
|
||||
end = time.perf_counter()
|
||||
assert report != ""
|
||||
assert end - start < time_limit
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
0
cara/tests/models/__init__.py
Normal file
0
cara/tests/models/__init__.py
Normal file
|
|
@ -37,6 +37,7 @@ jupyterlab-widgets==1.0.0
|
|||
kiwisolver==1.3.1
|
||||
MarkupSafe==1.1.1
|
||||
matplotlib==3.3.4
|
||||
memoization==0.3.2
|
||||
mistune==0.8.4
|
||||
nbclient==0.5.2
|
||||
nbconvert==6.0.7
|
||||
|
|
|
|||
1
setup.py
1
setup.py
|
|
@ -22,6 +22,7 @@ REQUIREMENTS: dict = {
|
|||
'ipywidgets',
|
||||
'Jinja2',
|
||||
'matplotlib',
|
||||
'memoization',
|
||||
'mistune',
|
||||
'numpy',
|
||||
'qrcode[pil]',
|
||||
|
|
|
|||
Loading…
Reference in a new issue