diff --git a/cara/apps/calculator/report_generator.py b/cara/apps/calculator/report_generator.py index 7f225c06..a064a0e1 100644 --- a/cara/apps/calculator/report_generator.py +++ b/cara/apps/calculator/report_generator.py @@ -29,11 +29,80 @@ def model_start_end(model: models.ExposureModel): return t_start, t_end -def calculate_report_data(model: models.ExposureModel): - resolution = 600 +def fill_big_gaps(array, gap_size): + """ + Insert values into the given sorted list if there is a gap of more than ``gap_size``. + All values in the given array are preserved, even if they are within the ``gap_size`` of one another. + + >>> fill_big_gaps([1, 2, 4], gap_size=0.75) + [1, 1.75, 2, 2.75, 3.5, 4] + + """ + result = [] + if len(array) == 0: + raise ValueError("Input array must be len > 0") + + last_value = array[0] + for value in array: + while value - last_value > gap_size: + last_value = last_value + gap_size + result.append(last_value) + result.append(value) + last_value = value + return result + + +def non_temp_transition_times(model: models.ExposureModel): + """ + Return the non-temperature (and PiecewiseConstant) based transition times. + + """ + def walk_model(model, name=""): + # Extend walk_dataclass to handle lists of dataclasses + # (e.g. in MultipleVentilation). + for name, obj in dataclass_utils.walk_dataclass(model, name=name): + if name.endswith('.ventilations') and isinstance(obj, (list, tuple)): + for i, item in enumerate(obj): + fq_name_i = f'{name}[{i}]' + yield fq_name_i, item + if dataclasses.is_dataclass(item): + yield from dataclass_utils.walk_dataclass(item, name=fq_name_i) + else: + yield name, obj t_start, t_end = model_start_end(model) - times = np.linspace(t_start, t_end, resolution) + + change_times = {t_start, t_end} + for name, obj in walk_model(model, name="exposure"): + if isinstance(obj, models.Interval): + change_times |= obj.transition_times() + + # Only choose times that are in the range of the model (removes things + # such as PeriodicIntervals, which extend beyond the model itself). + return sorted(time for time in change_times if (t_start <= time <= t_end)) + + +def interesting_times(model: models.ExposureModel, approx_n_pts=100) -> typing.List[float]: + """ + Pick approximately ``approx_n_pts`` time points which are interesting for the + given model. + + Initially the times are seeded by important state change times (excluding + outside temperature), and the times are then subsequently expanded to ensure + that the step size is at most ``(t_end - t_start) / approx_n_pts``. + + """ + times = non_temp_transition_times(model) + + # Expand the times list to ensure that we have a maximum gap size between + # the key times. + nice_times = fill_big_gaps(times, gap_size=(max(times) - min(times)) / approx_n_pts) + return nice_times + + +def calculate_report_data(model: models.ExposureModel): + times = interesting_times(model) + concentrations = [ np.array(model.concentration_model.concentration(float(time))).mean() for time in times @@ -212,7 +281,7 @@ def manufacture_alternative_scenarios(form: FormData) -> typing.Dict[str, mc.Exp return scenarios -def comparison_plot(scenarios: typing.Dict[str, dict], sample_times: np.ndarray): +def comparison_plot(scenarios: typing.Dict[str, dict], sample_times: typing.List[float]): fig = plt.figure() ax = fig.add_subplot(1, 1, 1) @@ -244,7 +313,7 @@ def comparison_plot(scenarios: typing.Dict[str, dict], sample_times: np.ndarray) return fig -def scenario_statistics(mc_model: mc.ExposureModel, sample_times: np.ndarray): +def scenario_statistics(mc_model: mc.ExposureModel, sample_times: typing.List[float]): model = mc_model.build_model(size=_DEFAULT_MC_SAMPLE_SIZE) return { 'probability_of_infection': np.mean(model.infection_probability()), @@ -258,7 +327,7 @@ def scenario_statistics(mc_model: mc.ExposureModel, sample_times: np.ndarray): def comparison_report( scenarios: typing.Dict[str, mc.ExposureModel], - sample_times: np.ndarray, + sample_times: typing.List[float], executor_factory: typing.Callable[[], concurrent.futures.Executor], ): statistics = {} @@ -309,8 +378,7 @@ class ReportGenerator: 'creation_date': time, } - t_start, t_end = model_start_end(model) - scenario_sample_times = np.linspace(t_start, t_end, 350) + scenario_sample_times = interesting_times(model) context.update(calculate_report_data(model)) alternative_scenarios = manufacture_alternative_scenarios(form) @@ -323,7 +391,7 @@ class ReportGenerator: 'level': 'Yellow - 2', 'incidence_rate': 'lower than 25 new cases per 100 000 inhabitants', 'onsite_access': 'of about 8000', - 'threshold' : '' + 'threshold': '' } return context diff --git a/cara/dataclass_utils.py b/cara/dataclass_utils.py index 3915b9e8..0b261afd 100644 --- a/cara/dataclass_utils.py +++ b/cara/dataclass_utils.py @@ -43,3 +43,24 @@ def replace(obj, **changes): new = dataclasses.replace(obj, **changes) object.__setattr__(obj, '__dataclass_fields__', orig) return new + + +def walk_dataclass(model, name=""): + """ + Recursively walk a dataclass instance, generating (name, obj) pairs for + attributes and decending into nested dataclasses. + + >>> list(walk_dataclass(obj), 'my_obj') + [('my_obj.attr_a', ), ('my_obj.attr_a.sub_attr', )] + + """ + if name: + name = name + '.' + if not dataclasses.is_dataclass(model): + raise TypeError(f'Not a dataclass based model: {type(model)}') + for field in dataclasses.fields(model): + obj = getattr(model, field.name) + fq_name = f'{name}{field.name}' + yield fq_name, obj + if dataclasses.is_dataclass(obj): + yield from walk_dataclass(obj, name=fq_name) diff --git a/cara/tests/apps/calculator/test_report_generator.py b/cara/tests/apps/calculator/test_report_generator.py index 91c6a02c..e3975a59 100644 --- a/cara/tests/apps/calculator/test_report_generator.py +++ b/cara/tests/apps/calculator/test_report_generator.py @@ -2,10 +2,13 @@ import concurrent.futures from functools import partial import time +import numpy.testing +import numpy as np import pytest -from cara.apps.calculator.report_generator import ReportGenerator, readable_minutes from cara.apps.calculator import make_app +from cara.apps.calculator.report_generator import ReportGenerator, readable_minutes +import cara.apps.calculator.report_generator as rep_gen def test_generate_report(baseline_form): @@ -38,3 +41,28 @@ def test_generate_report(baseline_form): ) def test_readable_minutes(test_input, expected): assert readable_minutes(test_input) == expected + + +def test_fill_big_gaps(): + expected = [1, 1.75, 2, 2.75, 3.5, 4] + assert rep_gen.fill_big_gaps([1, 2, 4], gap_size=0.75) == expected + + +def test_non_temp_transition_times(baseline_exposure_model): + expected = [0.0, 4.0, 5.0, 8.0] + result = rep_gen.non_temp_transition_times(baseline_exposure_model) + assert result == expected + + +def test_interesting_times_many(baseline_exposure_model): + result = rep_gen.interesting_times(baseline_exposure_model, approx_n_pts=100) + assert 100 <= len(result) <= 120 + assert np.abs(np.diff(result)).max() < 8.1/100. + + +def test_interesting_times_small(baseline_exposure_model): + expected = [0.0, 0.8, 1.6, 2.4, 3.2, 4.0, 4.8, 5.0, 5.8, 6.6, 7.4, 8.0] + # Ask for more data than there is in the transition times. + result = rep_gen.interesting_times(baseline_exposure_model, approx_n_pts=10) + + np.testing.assert_allclose(result, expected, atol=1e-04) diff --git a/cara/tests/test_dataclass_utils.py b/cara/tests/test_dataclass_utils.py index ec0da382..270066ee 100644 --- a/cara/tests/test_dataclass_utils.py +++ b/cara/tests/test_dataclass_utils.py @@ -1,6 +1,6 @@ import dataclasses -from cara.dataclass_utils import nested_replace +from cara.dataclass_utils import nested_replace, walk_dataclass @dataclasses.dataclass(frozen=True) @@ -25,3 +25,15 @@ def test_nested_replace(): inst = One(1, two=Two(3, Four(4))) new_inst = nested_replace(inst, {'two.four': Four(5)}) assert new_inst == One(1, two=Two(3, Four(5))) + + +def test_walk(): + inst = One(1, two=Two(3, Four(4))) + expected = [ + ('inst.one', inst.one), + ('inst.two', inst.two), + ('inst.two.three', inst.two.three), + ('inst.two.four', inst.two.four), + ('inst.two.four.four', inst.two.four.four), + ] + assert list(walk_dataclass(inst, name='inst')) == expected