Merge branch 'optimisation/last_state_change_bisection' into 'master'
Use bisection method to compute the last state change for improved performance See merge request cara/cara!239
This commit is contained in:
commit
c745823768
2 changed files with 42 additions and 36 deletions
|
|
@ -762,26 +762,34 @@ class ConcentrationModel:
|
|||
|
||||
return (self.infected.emission_rate(time)) / (IVRR * V)
|
||||
|
||||
@method_cache
|
||||
def state_change_times(self) -> typing.List[float]:
|
||||
"""
|
||||
All time dependent entities on this model must provide information about
|
||||
the times at which their state changes.
|
||||
|
||||
"""
|
||||
state_change_times = set()
|
||||
state_change_times = {0.}
|
||||
state_change_times.update(self.infected.presence.transition_times())
|
||||
state_change_times.update(self.ventilation.transition_times())
|
||||
return sorted(state_change_times)
|
||||
|
||||
def last_state_change(self, time: float) -> float:
|
||||
"""
|
||||
Find the most recent state change.
|
||||
Find the most recent/previous state change.
|
||||
|
||||
Find the nearest time less than the given one. If there is a state
|
||||
change exactly at ``time`` the previous state change is returned
|
||||
(except at ``time == 0``).
|
||||
|
||||
"""
|
||||
for change_time in self.state_change_times()[::-1]:
|
||||
if change_time < time:
|
||||
return change_time
|
||||
return 0.
|
||||
times = self.state_change_times()
|
||||
t_index: int = np.searchsorted(times, time) # type: ignore
|
||||
# Search sorted gives us the index to insert the given time. Instead we
|
||||
# want to get the index of the most recent time, so reduce the index by
|
||||
# one unless we are already at 0.
|
||||
t_index = max([t_index - 1, 0])
|
||||
return times[t_index]
|
||||
|
||||
def _next_state_change(self, time: float) -> float:
|
||||
"""
|
||||
|
|
@ -796,14 +804,6 @@ class ConcentrationModel:
|
|||
f"state change time ({change_time})"
|
||||
)
|
||||
|
||||
def _is_interval_between_state_changes(self, start: float, stop: float) -> bool:
|
||||
"""
|
||||
Check that the times start and stop are in-between two state
|
||||
changes of the concentration model (to ensure sure that all
|
||||
model parameters stay constant between start and stop).
|
||||
"""
|
||||
return (self.last_state_change(stop) <= start)
|
||||
|
||||
@method_cache
|
||||
def _concentration_cached(self, time: float) -> _VectorisedFloat:
|
||||
# A cached version of the concentration method. Use this method if you
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ def test_concentration_model_vectorisation(override_params):
|
|||
|
||||
@pytest.fixture
|
||||
def simple_conc_model():
|
||||
interesting_times = models.SpecificInterval(([0., 1.], [1.1, 1.999], [2., 3.]), )
|
||||
interesting_times = models.SpecificInterval(([0.5, 1.], [1.1, 2], [2., 3.]), )
|
||||
return models.ConcentrationModel(
|
||||
models.Room(75),
|
||||
models.AirChange(interesting_times, 100),
|
||||
|
|
@ -68,14 +68,38 @@ def simple_conc_model():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"time, expected_last_state_change", [
|
||||
[-15., 0.], # Out of range goes to the first state.
|
||||
[0., 0.],
|
||||
[0.5, 0.0],
|
||||
[0.51, 0.5],
|
||||
[1., 0.5],
|
||||
[1.05, 1.],
|
||||
[1.1, 1.],
|
||||
[1.11, 1.1],
|
||||
[2., 1.1],
|
||||
[2.1, 2],
|
||||
[3., 2],
|
||||
[15., 3.], # Out of range goes to the last state.
|
||||
]
|
||||
)
|
||||
def test_last_state_change_time(
|
||||
simple_conc_model: models.ConcentrationModel,
|
||||
time,
|
||||
expected_last_state_change,
|
||||
):
|
||||
assert simple_conc_model.last_state_change(float(time)) == expected_last_state_change
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"time, expected_next_state_change", [
|
||||
[0, 0],
|
||||
[0.0, 0.0],
|
||||
[0.5, 0.5],
|
||||
[1, 1],
|
||||
[1.05, 1.1],
|
||||
[1.1, 1.1],
|
||||
[1.11, 1.999],
|
||||
[1.9991, 2],
|
||||
[1.11, 2],
|
||||
[2, 2],
|
||||
[2.1, 3],
|
||||
[3, 3],
|
||||
|
|
@ -97,24 +121,6 @@ def test_next_state_change_time_out_of_range(simple_conc_model: models.Concentra
|
|||
simple_conc_model._next_state_change(3.1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"start, stop, is_valid", [
|
||||
[0, 1.05, False],
|
||||
[0.99, 1.1, False],
|
||||
[0.5, 1.01, False],
|
||||
[0, 1, True],
|
||||
[1.01, 1.1, True],
|
||||
[0.01, 1, True],
|
||||
[1.11, 1.99, True],
|
||||
]
|
||||
)
|
||||
def test_valid_interval(
|
||||
start, stop, is_valid,
|
||||
simple_conc_model: models.ConcentrationModel
|
||||
):
|
||||
assert simple_conc_model._is_interval_between_state_changes(start, stop) == is_valid
|
||||
|
||||
|
||||
def test_integrated_concentration(simple_conc_model):
|
||||
c1 = simple_conc_model.integrated_concentration(0, 2)
|
||||
c2 = simple_conc_model.integrated_concentration(0, 1)
|
||||
|
|
|
|||
Loading…
Reference in a new issue