diff --git a/cara/apps/calculator/report_generator.py b/cara/apps/calculator/report_generator.py
index 2f73c05d..261dd8d8 100644
--- a/cara/apps/calculator/report_generator.py
+++ b/cara/apps/calculator/report_generator.py
@@ -29,11 +29,104 @@ def model_start_end(model: models.ExposureModel):
     return t_start, t_end
 
 
-def calculate_report_data(model: models.ExposureModel):
-    resolution = 600
+def fill_big_gaps(array, gap_size):
+    """
+    Insert values into the given sorted sequence wherever two consecutive
+    values are more than ``gap_size`` apart.
+
+    All values in the given array are preserved, even if they are within
+    ``gap_size`` of one another. A gap is only filled when it exceeds
+    ``gap_size`` by more than an absolute tolerance of ``1e-15``.
+
+    Raises ``ValueError`` if ``array`` is empty, or if ``gap_size`` is too
+    small (zero, negative or lost to float rounding) to step across a gap
+    that needs filling.
+
+    >>> fill_big_gaps([1, 2, 4], gap_size=0.75)
+    [1, 1.75, 2, 2.75, 3.5, 4]
+
+    """
+    if len(array) == 0:
+        raise ValueError("Input array must be len > 0")
+
+    result = []
+    last_value = array[0]
+    for value in array:
+        while value - last_value > gap_size + 1e-15:
+            next_value = last_value + gap_size
+            if not next_value > last_value:
+                # The step makes no progress, so the gap could never be
+                # closed: fail loudly instead of looping forever.
+                raise ValueError(
+                    f"gap_size={gap_size!r} is too small to fill the gap "
+                    f"before {value!r}"
+                )
+            last_value = next_value
+            result.append(last_value)
+        result.append(value)
+        last_value = value
+    return result
+
+
+def non_temp_transition_times(model: models.ExposureModel):
+    """
+    Return the non-temperature (and PiecewiseConstant) based transition times.
+
+    The result is a sorted list made of the model's start and end times plus
+    the transition times of every ``models.Interval`` found while walking the
+    model, restricted to the ``[t_start, t_end]`` range.
+
+    """
+    def walk_model(model, name=""):
+        # Extend walk_dataclass to handle lists of dataclasses
+        # (e.g. in MultipleVentilation).
+        for fq_name, obj in dataclass_utils.walk_dataclass(model, name=name):
+            if fq_name.endswith('.ventilations') and isinstance(obj, (list, tuple)):
+                for i, item in enumerate(obj):
+                    fq_name_i = f'{fq_name}[{i}]'
+                    yield fq_name_i, item
+                    if dataclasses.is_dataclass(item):
+                        yield from dataclass_utils.walk_dataclass(item, name=fq_name_i)
+            else:
+                yield fq_name, obj
 
     t_start, t_end = model_start_end(model)
-    times = np.linspace(t_start, t_end, resolution)
+
+    change_times = {t_start, t_end}
+    for _, obj in walk_model(model, name="exposure"):
+        if isinstance(obj, models.Interval):
+            change_times |= obj.transition_times()
+
+    # Only choose times that are in the range of the model (removes things
+    # such as PeriodicIntervals, which extend beyond the model itself).
+    return sorted(time for time in change_times if (t_start <= time <= t_end))
+
+
+def interesting_times(model: models.ExposureModel, approx_n_pts=100) -> typing.List[float]:
+    """
+    Pick approximately ``approx_n_pts`` time points which are interesting for the
+    given model.
+
+    Initially the times are seeded by important state change times (excluding
+    outside temperature), and the times are then subsequently expanded to ensure
+    that the step size is at most ``(t_end - t_start) / approx_n_pts``.
+
+    Raises ``ValueError`` if ``approx_n_pts`` is not positive.
+
+    """
+    if approx_n_pts <= 0:
+        raise ValueError(f"approx_n_pts must be > 0, got {approx_n_pts!r}")
+
+    times = non_temp_transition_times(model)
+
+    # Expand the times list to ensure that we have a maximum gap size between
+    # the key times.
+    return fill_big_gaps(times, gap_size=(max(times) - min(times)) / approx_n_pts)
+
+
+def calculate_report_data(model: models.ExposureModel):
+    times = interesting_times(model)
+
     concentrations = [
         np.array(model.concentration_model.concentration(float(time))).mean()
         for time in times
@@ -212,7 +305,7 @@ def manufacture_alternative_scenarios(form: FormData) -> typing.Dict[str, mc.Exp
     return scenarios
 
 
-def comparison_plot(scenarios: typing.Dict[str, dict], sample_times: np.ndarray):
+def comparison_plot(scenarios: typing.Dict[str, dict], sample_times: typing.List[float]):
     fig = plt.figure()
     ax = fig.add_subplot(1, 1, 1)
 
@@ -244,7 +337,7 @@ def comparison_plot(scenarios: typing.Dict[str, dict], sample_times: np.ndarray)
     return fig
 
 
-def scenario_statistics(mc_model: mc.ExposureModel, sample_times: np.ndarray):
+def scenario_statistics(mc_model: mc.ExposureModel, sample_times: typing.List[float]):
     model = mc_model.build_model(size=_DEFAULT_MC_SAMPLE_SIZE)
     return {
         'probability_of_infection': np.mean(model.infection_probability()),
@@ -258,7 +351,7 @@ def scenario_statistics(mc_model: mc.ExposureModel, sample_times: np.ndarray):
 
 def comparison_report(
         scenarios: typing.Dict[str, mc.ExposureModel],
-        sample_times: np.ndarray,
+        sample_times: typing.List[float],
         executor_factory: typing.Callable[[], concurrent.futures.Executor],
 ):
     statistics = {}
@@ -309,8 +402,7 @@ class ReportGenerator:
             'creation_date': time,
         }
 
-        t_start, t_end = model_start_end(model)
-        scenario_sample_times = np.linspace(t_start, t_end, 350)
+        scenario_sample_times = interesting_times(model)
 
         context.update(calculate_report_data(model))
         alternative_scenarios = manufacture_alternative_scenarios(form)
@@ -323,7 +415,7 @@ class ReportGenerator:
             'level': 'Yellow - 2',
             'incidence_rate': 'lower than 25 new cases per 100 000 inhabitants',
             'onsite_access': 'of about 8000',
-            'threshold' : ''
+            'threshold': ''
         }
         return context
 
diff --git a/cara/dataclass_utils.py b/cara/dataclass_utils.py
index 3915b9e8..0b261afd 100644
--- a/cara/dataclass_utils.py
+++ b/cara/dataclass_utils.py
@@ -43,3 +43,31 @@ def replace(obj, **changes):
     new = dataclasses.replace(obj, **changes)
     object.__setattr__(obj, '__dataclass_fields__', orig)
     return new
+
+
+def walk_dataclass(model, name=""):
+    """
+    Recursively walk a dataclass instance, generating (name, obj) pairs for
+    attributes and descending into nested dataclass instances.
+
+    Names are dotted paths, prefixed by ``name`` when it is given. Attributes
+    holding a dataclass *class* (rather than an instance) are yielded but not
+    descended into. Raises ``TypeError`` if ``model`` is not a dataclass
+    instance.
+
+    >>> list(walk_dataclass(obj, name='my_obj'))  # doctest: +SKIP
+    [('my_obj.attr_a', obj.attr_a), ('my_obj.attr_a.sub_attr', obj.attr_a.sub_attr)]
+
+    """
+    # ``dataclasses.is_dataclass`` is also true for dataclass *classes*; only
+    # instances are guaranteed to have a value for every field.
+    if not dataclasses.is_dataclass(model) or isinstance(model, type):
+        raise TypeError(f'Not a dataclass based model: {type(model)}')
+    if name:
+        name = name + '.'
+    for field in dataclasses.fields(model):
+        obj = getattr(model, field.name)
+        fq_name = f'{name}{field.name}'
+        yield fq_name, obj
+        if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
+            yield from walk_dataclass(obj, name=fq_name)
diff --git a/cara/tests/apps/calculator/test_report_generator.py b/cara/tests/apps/calculator/test_report_generator.py
index 91c6a02c..c18fdd45 100644
--- a/cara/tests/apps/calculator/test_report_generator.py
+++ b/cara/tests/apps/calculator/test_report_generator.py
@@ -2,10 +2,13 @@ import concurrent.futures
 from functools import partial
 import time
 
+import numpy.testing
+import numpy as np
 import pytest
 
-from cara.apps.calculator.report_generator import ReportGenerator, readable_minutes
 from cara.apps.calculator import make_app
+from cara.apps.calculator.report_generator import ReportGenerator, readable_minutes
+import cara.apps.calculator.report_generator as rep_gen
 
 
 def test_generate_report(baseline_form):
@@ -38,3 +41,64 @@ def test_generate_report(baseline_form):
 )
 def test_readable_minutes(test_input, expected):
     assert readable_minutes(test_input) == expected
+
+
+def test_fill_big_gaps():
+    expected = [1, 1.75, 2, 2.75, 3.5, 4]
+    assert rep_gen.fill_big_gaps([1, 2, 4], gap_size=0.75) == expected
+
+
+def test_fill_big_gaps__float_tolerance():
+    # Ensure that there is some float tolerance to the gap size check.
+    assert rep_gen.fill_big_gaps([0, 2 + 1e-15, 4], gap_size=2) == [0, 2 + 1e-15, 4]
+    assert rep_gen.fill_big_gaps([0, 2 + 1e-14, 4], gap_size=2) == [0, 2, 2 + 1e-14, 4]
+
+
+def test_fill_big_gaps__invalid_gap_size():
+    # A gap size which makes no progress can never close a gap, so it must
+    # be rejected instead of looping forever.
+    with pytest.raises(ValueError):
+        rep_gen.fill_big_gaps([1, 2], gap_size=0)
+    with pytest.raises(ValueError):
+        rep_gen.fill_big_gaps([1, 2], gap_size=-1)
+
+
+def test_non_temp_transition_times(baseline_exposure_model):
+    expected = [0.0, 4.0, 5.0, 8.0]
+    result = rep_gen.non_temp_transition_times(baseline_exposure_model)
+    assert result == expected
+
+
+def test_interesting_times_many(baseline_exposure_model):
+    result = rep_gen.interesting_times(baseline_exposure_model, approx_n_pts=100)
+    assert 100 <= len(result) <= 120
+    assert np.abs(np.diff(result)).max() < 8.1/100.
+
+
+def test_interesting_times_small(baseline_exposure_model):
+    expected = [0.0, 0.8, 1.6, 2.4, 3.2, 4.0, 4.8, 5.0, 5.8, 6.6, 7.4, 8.0]
+    # Ask for more data than there is in the transition times.
+    result = rep_gen.interesting_times(baseline_exposure_model, approx_n_pts=10)
+
+    np.testing.assert_allclose(result, expected, atol=1e-04)
+
+
+def test_interesting_times_w_temp(exposure_model_w_outside_temp_changes):
+    # Ensure that the state change times are returned (minus the temperature changes) by
+    # requesting n_points=1.
+    result = rep_gen.interesting_times(exposure_model_w_outside_temp_changes, approx_n_pts=1)
+    expected = [0., 1.8, 2.2, 4., 4.4, 5., 6.2, 6.6, 8.]
+    np.testing.assert_allclose(result, expected)
+
+    # Now request more than the state-change times.
+    result = rep_gen.interesting_times(exposure_model_w_outside_temp_changes, approx_n_pts=20)
+    expected = [
+        0., 0.4, 0.8, 1.2, 1.6, 1.8, 2.2, 2.6, 3., 3.4, 3.8, 4., 4.4, 4.8,
+        5., 5.4, 5.8, 6.2, 6.6, 7., 7.4, 7.8, 8.
+    ]
+    np.testing.assert_allclose(result, expected)
+
+
+def test_interesting_times__invalid_n_pts(baseline_exposure_model):
+    with pytest.raises(ValueError):
+        rep_gen.interesting_times(baseline_exposure_model, approx_n_pts=0)
diff --git a/cara/tests/conftest.py b/cara/tests/conftest.py
index e69b44c5..9499dfcf 100644
--- a/cara/tests/conftest.py
+++ b/cara/tests/conftest.py
@@ -1,4 +1,6 @@
 from cara import models
+import cara.data
+import cara.dataclass_utils
 
 import pytest
 
@@ -33,5 +35,20 @@ def baseline_exposure_model(baseline_model):
             activity=baseline_model.infected.activity,
             mask=baseline_model.infected.mask,
         ),
-        fraction_deposited = 1.,
+        fraction_deposited=1.,
     )
+
+
+@pytest.fixture
+def exposure_model_w_outside_temp_changes(baseline_exposure_model: models.ExposureModel):
+    exp_model = cara.dataclass_utils.nested_replace(
+        baseline_exposure_model, {
+            'concentration_model.ventilation': models.SlidingWindow(
+                active=models.PeriodicInterval(2.2 * 60, 1.8 * 60),
+                inside_temp=models.PiecewiseConstant((0., 24.), (293,)),
+                outside_temp=cara.data.GenevaTemperatures['Jan'],
+                window_height=1.6,
+                opening_length=0.6,
+            )
+        })
+    return exp_model
diff --git a/cara/tests/test_dataclass_utils.py b/cara/tests/test_dataclass_utils.py
index ec0da382..270066ee 100644
--- a/cara/tests/test_dataclass_utils.py
+++ b/cara/tests/test_dataclass_utils.py
@@ -1,6 +1,6 @@
 import dataclasses
 
-from cara.dataclass_utils import nested_replace
+from cara.dataclass_utils import nested_replace, walk_dataclass
 
 
 @dataclasses.dataclass(frozen=True)
@@ -25,3 +25,25 @@ def test_nested_replace():
     inst = One(1, two=Two(3, Four(4)))
     new_inst = nested_replace(inst, {'two.four': Four(5)})
     assert new_inst == One(1, two=Two(3, Four(5)))
+
+
+def test_walk():
+    inst = One(1, two=Two(3, Four(4)))
+    expected = [
+        ('inst.one', inst.one),
+        ('inst.two', inst.two),
+        ('inst.two.three', inst.two.three),
+        ('inst.two.four', inst.two.four),
+        ('inst.two.four.four', inst.two.four.four),
+    ]
+    assert list(walk_dataclass(inst, name='inst')) == expected
+
+
+def test_walk__dataclass_class_attribute():
+    @dataclasses.dataclass(frozen=True)
+    class Holder:
+        kind: type
+
+    # A dataclass *class* held as a value is reported but not walked into.
+    inst = Holder(kind=Four)
+    assert list(walk_dataclass(inst, name='inst')) == [('inst.kind', Four)]
diff --git a/cara/tests/test_known_quantities.py b/cara/tests/test_known_quantities.py
index 1556df75..89c0268d 100644
--- a/cara/tests/test_known_quantities.py
+++ b/cara/tests/test_known_quantities.py
@@ -312,7 +312,7 @@ def test_concentrations_hourly_dep_temp_vs_constant(month, temperatures, time):
 )
 def test_concentrations_hourly_dep_temp_startup(month, temperatures, time):
-    # The concentrations should be the zero up to the first presence time
-    # of an infecter person.
+    # The concentrations should be zero up to the first presence time
+    # of an infected person.
     m = build_hourly_dependent_model(
         month,
         ((0., 0.5), (1., 1.5), (4., 4.5), (7.5, 8), ),