Handle arrays in the cache of the model by using the memoization library. Speeds up the single model case too.

This commit is contained in:
Phil Elson 2021-03-28 06:46:34 +02:00
parent bb9b8657c6
commit 9b836a6507
5 changed files with 26 additions and 5 deletions

View file

@ -31,10 +31,16 @@ the same for all parameters of a single model.
"""
from dataclasses import dataclass
import functools
import numpy as np
import typing
if not typing.TYPE_CHECKING:
from memoization import cached
else:
# Workaround issue https://github.com/lonelyenvoy/python-memoization/issues/18
# by providing a no-op cache decorator when type-checking.
cached = lambda *cached_args, **cached_kwargs: lambda function: function # noqa
from .dataclass_utils import nested_replace
@ -593,7 +599,7 @@ class InfectedPopulation(Population):
return self.emission_rate_when_present()
@functools.lru_cache()
@cached()
def emission_rate(self, time) -> float:
"""
The emission rate of the entire population.
@ -622,7 +628,7 @@ class ConcentrationModel:
return k + self.virus.decay_constant + self.ventilation.air_exchange(self.room, time)
@functools.lru_cache()
@cached()
def state_change_times(self):
"""
All time dependent entities on this model must provide information about
@ -645,8 +651,11 @@ class ConcentrationModel:
return change_time
return 0
@functools.lru_cache()
@cached()
def concentration(self, time: float) -> float:
# Note that time is not vectorised. You can only pass a single float
# to this method.
if time == 0:
return 0.0
IVRR = self.infectious_virus_removal_rate(time)

View file

@ -1,3 +1,5 @@
import time
import pytest
from cara.apps.calculator import model_generator
@ -15,10 +17,18 @@ def baseline_form(baseline_form_data):
def test_generate_report(baseline_form):
model = baseline_form.build_model()
# This is a simple test that confirms that given a model, we can actually
# generate a report for it. Because this is what happens in the cara
# calculator, we confirm that the generation happens within a reasonable
# time threshold.
time_limit: float = 1.5 # seconds
start = time.perf_counter()
model = baseline_form.build_model()
report = report_generator.build_report("", model, baseline_form)
end = time.perf_counter()
assert report != ""
assert end - start < time_limit
@pytest.mark.parametrize(

View file

View file

@ -37,6 +37,7 @@ jupyterlab-widgets==1.0.0
kiwisolver==1.3.1
MarkupSafe==1.1.1
matplotlib==3.3.4
memoization==0.3.2
mistune==0.8.4
nbclient==0.5.2
nbconvert==6.0.7

View file

@ -22,6 +22,7 @@ REQUIREMENTS: dict = {
'ipywidgets',
'Jinja2',
'matplotlib',
'memoization',
'mistune',
'numpy',
'qrcode[pil]',