From 64fa2b60a45305ae9323c31a1684fbb95e5860ac Mon Sep 17 00:00:00 2001 From: Phil Elson Date: Tue, 5 Jan 2021 16:30:16 +0100 Subject: [PATCH 1/3] Address and enable static type analysis checking of cara. --- cara/apps/calculator/__init__.py | 4 +- cara/apps/calculator/model_generator.py | 25 +++++---- cara/apps/calculator/report_generator.py | 15 +++--- cara/apps/expert.py | 54 +++++++++---------- cara/dataclass_utils.py | 10 ++-- cara/models.py | 43 +++++++++------ cara/state.py | 2 +- .../apps/calculator/test_report_generator.py | 14 +++++ cara/tests/test_state.py | 2 +- setup.cfg | 14 +++++ 10 files changed, 110 insertions(+), 73 deletions(-) create mode 100644 setup.cfg diff --git a/cara/apps/calculator/__init__.py b/cara/apps/calculator/__init__.py index f0463f15..674a6818 100644 --- a/cara/apps/calculator/__init__.py +++ b/cara/apps/calculator/__init__.py @@ -1,7 +1,7 @@ import html import json from pathlib import Path -from typing import Optional, Awaitable +from typing import Coroutine, Any, Optional, Awaitable import jinja2 import mistune @@ -14,7 +14,7 @@ from .user import AuthenticatedUser, AnonymousUser class BaseRequestHandler(RequestHandler): - async def prepare(self) -> Optional[Awaitable[None]]: + async def prepare(self): """Called at the beginning of a request before `get`/`post`/etc.""" username = self.request.headers.get("X-ADFS-LOGIN", None) if username: diff --git a/cara/apps/calculator/model_generator.py b/cara/apps/calculator/model_generator.py index b5c95ca7..c6ff4508 100644 --- a/cara/apps/calculator/model_generator.py +++ b/cara/apps/calculator/model_generator.py @@ -171,7 +171,7 @@ class FormData: def build_model(self) -> models.ExposureModel: return model_from_form(self) - def ventilation(self) -> models.Ventilation: + def ventilation(self) -> typing.Union[models.Ventilation, models.MultipleVentilation]: always_on = models.PeriodicInterval(period=120, duration=120) # Initializes a ventilation instance as a window if 'natural' is selected, or as a HEPA-filter otherwise if self.ventilation_type == 'natural': @@ -189,10 +189,12 @@ class FormData: inside_temp = models.PiecewiseConstant((0, 24), (293,)) outside_temp = data.GenevaTemperatures[month] + ventilation: models.Ventilation if self.window_type == 'sliding': ventilation = models.SlidingWindow( active=window_interval, - inside_temp=inside_temp, outside_temp=outside_temp, + inside_temp=inside_temp, + outside_temp=outside_temp, window_height=self.window_height, opening_length=self.opening_distance, number_of_windows=self.windows_number, @@ -200,7 +202,8 @@ class FormData: elif self.window_type == 'hinged': ventilation = models.HingedWindow( active=window_interval, - inside_temp=inside_temp, outside_temp=outside_temp, + inside_temp=inside_temp, + outside_temp=outside_temp, window_height=self.window_height, window_width=self.window_width, opening_length=self.opening_distance, @@ -218,7 +221,7 @@ class FormData: if self.hepa_option: hepa = models.HEPAFilter(active=always_on, q_air_mech=self.hepa_amount) - return models.MultipleVentilation((ventilation,hepa)) + return models.MultipleVentilation((ventilation, hepa)) else: return ventilation @@ -301,7 +304,7 @@ class FormData: ) return exposed - def _compute_breaks_in_interval(self, start, finish, n_breaks) -> typing.Tuple[typing.Tuple[int, int]]: + def _compute_breaks_in_interval(self, start, finish, n_breaks) -> models.BoundarySequence_t: break_delay = ((finish - start) - (n_breaks * self.coffee_duration)) // (n_breaks+1) break_times = [] end = start @@ -311,13 +314,13 @@ class FormData: break_times.append((begin, end)) return tuple(break_times) - def lunch_break_times(self) -> typing.Tuple[typing.Tuple[int, int]]: + def lunch_break_times(self) -> models.BoundarySequence_t: result = [] if self.lunch_option: result.append((self.lunch_start, self.lunch_finish)) return tuple(result) - def coffee_break_times(self) -> typing.Tuple[typing.Tuple[int, int]]: + def coffee_break_times(self) -> models.BoundarySequence_t: if not self.coffee_breaks: return () if self.lunch_option: @@ -475,10 +478,10 @@ def expiration_blend(expiration_weights: typing.Dict[models.Expiration, int]) -> ejection_factor += np.array(expiration.ejection_factor) * weight particle_sizes += np.array(expiration.particle_sizes) * weight - return models.Expiration( - ejection_factor=tuple(ejection_factor/total_weight), - particle_sizes=tuple(particle_sizes/total_weight), - ) + r_ejection_factor: typing.Tuple[float, float, float, float] = tuple(ejection_factor/total_weight) # type: ignore + r_particle_sizes: typing.Tuple[float, float, float, float] = tuple(particle_sizes/total_weight) # type: ignore + + return models.Expiration(ejection_factor=r_ejection_factor, particle_sizes=r_particle_sizes) def model_from_form(form: FormData) -> models.ExposureModel: diff --git a/cara/apps/calculator/report_generator.py b/cara/apps/calculator/report_generator.py index 27abbf80..11eab671 100644 --- a/cara/apps/calculator/report_generator.py +++ b/cara/apps/calculator/report_generator.py @@ -113,9 +113,9 @@ def minutes_to_time(minutes: int) -> str: return f"{hour_string}:{minute_string}" -def readable_minutes(minutes: int) -> str: - time = minutes +def readable_minutes(minutes: int) -> str: + time = float(minutes) unit = " minute" if time % 60 == 0: time = minutes/60 @@ -124,19 +124,20 @@ def readable_minutes(minutes: int) -> str: unit += "s" if time.is_integer(): - time = "{:0.0f}".format(time) + time_str = "{:0.0f}".format(time) else: - time = "{0:.2f}".format(time) + time_str = "{0:.2f}".format(time) - return time + unit + return time_str + unit -def non_zero_percentage (percentage: int) -> str: +def non_zero_percentage(percentage: int) -> str: if percentage < 0.01: return "<0.01%" else: return "{:0.0f}%".format(percentage) + def manufacture_alternative_scenarios(form: FormData) -> typing.Dict[str, models.ExposureModel]: scenarios = {} @@ -243,7 +244,7 @@ def build_report(model: models.ExposureModel, form: FormData): cara_templates = Path(__file__).parent.parent / "templates" calculator_templates = Path(__file__).parent / "templates" env = jinja2.Environment( - loader=jinja2.FileSystemLoader([cara_templates, calculator_templates]), + loader=jinja2.FileSystemLoader([str(cara_templates), str(calculator_templates)]), undefined=jinja2.StrictUndefined, ) env.filters['non_zero_percentage'] = non_zero_percentage diff --git a/cara/apps/expert.py b/cara/apps/expert.py index 3dbd634d..a9b0223f 100644 --- a/cara/apps/expert.py +++ b/cara/apps/expert.py @@ -32,9 +32,9 @@ WidgetPairType = typing.Tuple[widgets.Widget, widgets.Widget] class WidgetGroup: - def __init__(self, label_widget_pairs: typing.Sequence[WidgetPairType]): - self.labels = [] - self.widgets = [] + def __init__(self, label_widget_pairs: typing.Iterable[WidgetPairType]): + self.labels: typing.List[widgets.Widget] = [] + self.widgets: typing.List[widgets.Widget] = [] self.add_pairs(label_widget_pairs) def set_visible(self, visible: bool): @@ -46,10 +46,10 @@ class WidgetGroup: widget.layout.visible = False widget.layout.display = 'none' - def pairs(self) -> typing.Sequence[WidgetPairType]: + def pairs(self) -> typing.Iterable[WidgetPairType]: return zip(*[self.labels, self.widgets]) - def add_pairs(self, label_widget_pairs: typing.Sequence[WidgetPairType]): + def add_pairs(self, label_widget_pairs: typing.Iterable[WidgetPairType]): labels, widgets_ = zip(*label_widget_pairs) self.labels.extend(labels) self.widgets.extend(widgets_) @@ -190,13 +190,13 @@ class ExposureComparissonResult(View): return ax def scenarios_updated(self, scenarios: typing.Sequence[ScenarioType], _): - labels, models = zip(*scenarios) - conc_models: typing.Tuple[models.ConcentrationModel] = tuple( - model.concentration_model.dcs_instance() for model in models + updated_labels, updated_models = zip(*scenarios) + conc_models: typing.Tuple[models.ConcentrationModel, ...] = tuple( + model.concentration_model.dcs_instance() for model in updated_models ) - self.update_plot(conc_models, labels) + self.update_plot(conc_models, updated_labels) - def update_plot(self, conc_models: typing.Tuple[models.ConcentrationModel], labels: typing.Tuple[str]): + def update_plot(self, conc_models: typing.Tuple[models.ConcentrationModel, ...], labels: typing.Tuple[str, ...]): self.ax.lines.clear() start, finish = models_start_end(conc_models) ts = np.linspace(start, finish, num=250) @@ -268,12 +268,12 @@ class ModelWidgets(View): outside_temp.observe(outsidetemp_change, names=['value']) auto_width = widgets.Layout(width='auto') - return WidgetGroup([ - [ + return WidgetGroup(( + ( widgets.Label('Outside temperature (℃)', layout=auto_width,), outside_temp, - ], - ]) + ), + )) def _build_window(self, node) -> WidgetGroup: period = widgets.IntSlider(value=node.active.period, min=0, max=240) @@ -314,24 +314,24 @@ class ModelWidgets(View): toggle_outsidetemp(outsidetemp_w.value) auto_width = widgets.Layout(width='auto') - result = WidgetGroup([ - [ + result = WidgetGroup(( + ( widgets.Label('Interval between openings (minutes)', layout=auto_width), period, - ], - [ + ), + ( widgets.Label('Duration of opening (minutes)', layout=auto_width), interval, - ], - [ + ), + ( widgets.Label('Inside temperature (℃)', layout=auto_width), inside_temp, - ], - [ + ), + ( widgets.Label('Outside temperature scheme', layout=auto_width), outsidetemp_w, - ] - ]) + ) + )) for sub_group in outsidetemp_widgets.values(): result.add_pairs(sub_group.pairs()) return result @@ -362,9 +362,9 @@ class ModelWidgets(View): month_choice.observe(on_month_change, names=['value']) return WidgetGroup( - [ - [widgets.Label("Month"), month_choice], - ] + ( + (widgets.Label("Month"), month_choice), + ) ) def _build_activity(self, node): diff --git a/cara/dataclass_utils.py b/cara/dataclass_utils.py index 50654d43..cdfd867b 100644 --- a/cara/dataclass_utils.py +++ b/cara/dataclass_utils.py @@ -2,13 +2,9 @@ import dataclasses import typing -DCInst = typing.TypeVar('T') - - -def nested_replace(obj: DCInst, new_values: typing.Dict[str, typing.Any]) -> DCInst: - """ - Replace an attribute on a dataclass, much like dataclasses.replace, except it - supports nested replacement definitions. For example: +def nested_replace(obj, new_values: typing.Dict[str, typing.Any]): + """Replace an attribute on a dataclass, much like dataclasses.replace, + except it supports nested replacement definitions. For example: >>> new_obj = nested_replace(obj, {'attr1.sub_attr2.sub_sub_attr3': 4}) >>> new_obj.attr1.sub_attr2.sub_sub_attr3 diff --git a/cara/models.py b/cara/models.py index 4478c1ce..dbdaa329 100644 --- a/cara/models.py +++ b/cara/models.py @@ -12,6 +12,11 @@ class Room: volume: float +Time_t = typing.TypeVar('Time_t', float, int) +BoundaryPair_t = typing.Tuple[Time_t, Time_t] +BoundarySequence_t = typing.Union[typing.Tuple[BoundaryPair_t, ...], typing.Tuple] + + @dataclass(frozen=True) class Interval: """ @@ -26,7 +31,7 @@ class Interval: start < t <= end """ - def boundaries(self) -> typing.Tuple[typing.Tuple[float, float], ...]: + def boundaries(self) -> BoundarySequence_t: return () def transition_times(self) -> typing.Set[float]: @@ -48,23 +53,23 @@ class SpecificInterval(Interval): #: A sequence of times (start, stop), in hours, that the infected person #: is present. The flattened list of times must be strictly monotonically #: increasing. - present_times: typing.Tuple[typing.Tuple[float, float], ...] + present_times: BoundarySequence_t - def boundaries(self): + def boundaries(self) -> BoundarySequence_t: return self.present_times @dataclass(frozen=True) class PeriodicInterval(Interval): #: How often does the interval occur (minutes). - period: int + period: float #: How long does the interval occur for (minutes). #: A value greater than :data:`period` signifies the event is permanently #: occurring, a value of 0 signifies that the event never happens. - duration: int + duration: float - def boundaries(self) -> typing.Tuple[typing.Tuple[float, float], ...]: + def boundaries(self) -> BoundarySequence_t: if self.period == 0 or self.duration == 0: return tuple() result = [] @@ -91,16 +96,17 @@ class PiecewiseConstant: if tuple(sorted(set(self.transition_times))) != self.transition_times: raise ValueError("transition_times should not contain duplicated elements and should be sorted") - def value(self,time) -> float: + def value(self, time) -> float: if time <= self.transition_times[0]: return self.values[0] - if time > self.transition_times[-1]: + elif time > self.transition_times[-1]: return self.values[-1] - for t1,t2,value in zip(self.transition_times[:-1], - self.transition_times[1:],self.values): - if time > t1 and time <= t2: - return value + for t1, t2, value in zip(self.transition_times[:-1], + self.transition_times[1:], self.values): + if t1 < time <= t2: + break + return value def interval(self) -> Interval: # build an Interval object @@ -109,7 +115,7 @@ class PiecewiseConstant: self.transition_times[1:],self.values): if value: present_times.append((t1,t2)) - return SpecificInterval(present_times=present_times) + return SpecificInterval(present_times=tuple(present_times)) def refine(self,refine_factor=10): # build a new PiecewiseConstant object with a refined mesh, @@ -258,10 +264,10 @@ class HingedWindow(WindowOpening): horizontal plane). """ #: Window width (m). - window_width: float = None + window_width: float = 0.0 def __post_init__(self): - if self.window_width is None: + if not self.window_width > 0: raise ValueError('window_width must be set') @property @@ -387,7 +393,10 @@ class Mask: #: Filtration efficiency of masks when inhaling. η_inhale: float - particle_sizes: typing.Tuple[float] = (0.8e-4, 1.8e-4, 3.5e-4, 5.5e-4) # In cm. + #: Particle sizes in cm. + particle_sizes: typing.Tuple[float, float, float, float] = ( + 0.8e-4, 1.8e-4, 3.5e-4, 5.5e-4 + ) #: Pre-populated examples of Masks. types: typing.ClassVar[typing.Dict[str, "Mask"]] @@ -542,7 +551,7 @@ class InfectedPopulation(Population): @dataclass(frozen=True) class ConcentrationModel: room: Room - ventilation: Ventilation + ventilation: typing.Union[Ventilation, MultipleVentilation] infected: InfectedPopulation @property diff --git a/cara/state.py b/cara/state.py index 145a1c2b..00cfb3bc 100644 --- a/cara/state.py +++ b/cara/state.py @@ -278,7 +278,7 @@ class DataclassStatePredefined(DataclassInstanceState): """ def __init__(self, dataclass: Datamodel_T, - choices: typing.Dict[typing.Hashable, dataclass_instance], + choices: typing.Dict[str, dataclass_instance], **kwargs, ): super().__init__(dataclass=dataclass, **kwargs) diff --git a/cara/tests/apps/calculator/test_report_generator.py b/cara/tests/apps/calculator/test_report_generator.py index 5bb7d719..3ba1a007 100644 --- a/cara/tests/apps/calculator/test_report_generator.py +++ b/cara/tests/apps/calculator/test_report_generator.py @@ -19,3 +19,17 @@ def test_generate_report(baseline_form): report = report_generator.build_report(model, baseline_form) assert report != "" + + +@pytest.mark.parametrize( + ["test_input", "expected"], + [ + [1, '1 minute'], + [2, '2 minutes'], + [60, '1 hour'], + [120, '2 hours'], + [150, '150 minutes'], + ], +) +def test_readable_minutes(test_input, expected): + assert report_generator.readable_minutes(test_input) == expected \ No newline at end of file diff --git a/cara/tests/test_state.py b/cara/tests/test_state.py index 2102b96e..814d2ae4 100644 --- a/cara/tests/test_state.py +++ b/cara/tests/test_state.py @@ -25,7 +25,7 @@ class DCSimpleSubclass(DCSimple): @dataclass class DCOverrideSubclass(DCSimple): - attr1: float + attr1: float # type: ignore @dataclass diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..75cae8ff --- /dev/null +++ b/setup.cfg @@ -0,0 +1,14 @@ +[mypy] +no_warn_no_return = True + +[mypy-matplotlib.*] +ignore_missing_imports = True + +[mypy-ipympl.*] +ignore_missing_imports = True + +[mypy-ipywidgets.*] +ignore_missing_imports = True + +[mypy-mistune.*] +ignore_missing_imports = True From e9f9b7fde93546062ff2e990eab42283b4ff4b47 Mon Sep 17 00:00:00 2001 From: Phil Elson Date: Tue, 5 Jan 2021 16:46:43 +0100 Subject: [PATCH 2/3] Fix up type annotation issues with the cara.state module. --- cara/apps/expert.py | 4 ++-- cara/state.py | 54 +++++++++++++++++++++++---------------------- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/cara/apps/expert.py b/cara/apps/expert.py index a9b0223f..c284f578 100644 --- a/cara/apps/expert.py +++ b/cara/apps/expert.py @@ -481,12 +481,12 @@ baseline_model = models.ExposureModel( class CARAStateBuilder(state.StateBuilder): def build_type_Mask(self, _: dataclasses.Field): return state.DataclassStatePredefined( - models.Mask, + dataclass=models.Mask, choices=models.Mask.types, ) def build_type_Ventilation(self, _: dataclasses.Field): - s = state.DataclassStateNamed( + s: state.DataclassStateNamed = state.DataclassStateNamed( states={ 'Natural': self.build_generic(models.WindowOpening), 'Mechanical': self.build_generic(models.HVACMechanical), diff --git a/cara/state.py b/cara/state.py index 00cfb3bc..10af075d 100644 --- a/cara/state.py +++ b/cara/state.py @@ -14,7 +14,7 @@ import dataclasses import typing -Datamodel_T = typing.Type +Datamodel_T = typing.TypeVar('Datamodel_T') dataclass_instance = typing.Any @@ -34,11 +34,11 @@ class StateBuilder: return method return self.build_generic - def build_generic(self, type_to_build: typing.Type): + def build_generic(self, type_to_build: typing.Type) -> "DataclassInstanceState": return DataclassInstanceState(type_to_build, state_builder=self) -class DataclassState: +class DataclassState(typing.Generic[Datamodel_T]): def __init__(self, state_builder=StateBuilder()): with self._object_setattr(): self._state_builder = state_builder @@ -54,7 +54,7 @@ class DataclassState: yield object.__setattr__(self, '_use_base_setattr', False) - def dcs_instance(self): + def dcs_instance(self) -> Datamodel_T: """ Return the instance that this state represents. The instance returned is immutable, so it is advised to call this method each time that @@ -81,7 +81,7 @@ class DataclassState: """ yield - def dcs_update_from(self, data: dataclass_instance): + def dcs_update_from(self, data: Datamodel_T): """ Update the state based on the values of the given dataclass instance. @@ -94,7 +94,7 @@ class DataclassState: """ - def dcs_set_instance_type(self, instance_dataclass: Datamodel_T): + def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]): """ Update the current instance of the state to this type. @@ -103,7 +103,7 @@ class DataclassState: """ -class DataclassInstanceState(DataclassState): +class DataclassInstanceState(DataclassState[Datamodel_T]): """ Represents the state of a frozen dataclass. No type checking of the attributes is attempted. @@ -142,11 +142,12 @@ class DataclassInstanceState(DataclassState): #: The instance of dataclass which this state represents. Undefined until #: sufficient data is provided. - self._instance = None - self._data = {} - self._observers: typing.List[callable] = [] + self._instance: typing.Optional[Datamodel_T] = None + #: The underlying state data mapping attribute name to object. + self._data: typing.Dict[str, typing.Any] = {} + self._observers: typing.List[typing.Callable] = [] self._state_builder = state_builder - self._held_events = [] + self._held_events: typing.List[bool] = [] self._hold_fire = False self.dcs_set_instance_type(dataclass) @@ -169,7 +170,7 @@ class DataclassInstanceState(DataclassState): self._held_events.clear() self._fire_observers() - def dcs_update_from(self, data: dataclass_instance): + def dcs_update_from(self, data: Datamodel_T): with self.dcs_state_transaction(): self.dcs_set_instance_type(data.__class__) for field in dataclasses.fields(data): @@ -225,7 +226,7 @@ class DataclassInstanceState(DataclassState): else: raise AttributeError(f"No attribute {attr_name} on a {self._instance_type.__name__}") - def dcs_set_instance_type(self, instance_dataclass: Datamodel_T): + def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]): if not dataclasses.is_dataclass(instance_dataclass): raise TypeError("The given class is not a valid dataclass") if not issubclass(instance_dataclass, self._base): @@ -266,7 +267,7 @@ class DataclassInstanceState(DataclassState): return self._instance -class DataclassStatePredefined(DataclassInstanceState): +class DataclassStatePredefined(DataclassInstanceState[Datamodel_T]): """ Only a pre-defined selection of states for the given type are allowed. Selected by name (the keys in the dictionary). @@ -278,18 +279,20 @@ class DataclassStatePredefined(DataclassInstanceState): """ def __init__(self, dataclass: Datamodel_T, - choices: typing.Dict[str, dataclass_instance], + choices: typing.Dict[str, Datamodel_T], **kwargs, ): super().__init__(dataclass=dataclass, **kwargs) + # Pick the first choice until we know otherwise. + default_selection = list(choices.keys())[0] with self._object_setattr(): self._choices = choices - self._selected = None - # Pick the first choice until we know otherwise. - self.dcs_select(list(choices.keys())[0]) + self._selected: str = default_selection - def dcs_select(self, name: typing.Hashable): + self.dcs_select(default_selection) + + def dcs_select(self, name: str): if name not in self._choices: raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._choices)}') self._selected = name @@ -309,24 +312,23 @@ class DataclassStatePredefined(DataclassInstanceState): return dataclasses.asdict(self.dcs_instance()) -class DataclassStateNamed(DataclassState): +class DataclassStateNamed(DataclassState[Datamodel_T]): """ A collection of instances of the given type, switchable by name, but each instance is still mutable. """ def __init__(self, - states: typing.Dict[typing.Hashable, DataclassState], + states: typing.Dict[str, DataclassState], **kwargs ): # TODO: This is effectively a container type. We shouldn't use the standard constructor for this. enabled = list(states.keys())[0] - t = states[enabled] super().__init__(**kwargs) with self._object_setattr(): self._states = states.copy() - self._selected = None + self._selected: str = enabled # Pick the first choice until we know otherwise. self.dcs_select(enabled) @@ -348,7 +350,7 @@ class DataclassStateNamed(DataclassState): return object.__setattr__(self, name, value) setattr(self._selected_state(), name, value) - def dcs_select(self, name: typing.Hashable): + def dcs_select(self, name: str): if name not in self._states: raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._states)}') self._selected = name @@ -369,10 +371,10 @@ class DataclassStateNamed(DataclassState): for state in self._states.values(): state.dcs_observe(callback) - def dcs_update_from(self, data: dataclass_instance): + def dcs_update_from(self, data: Datamodel_T): return self._selected_state().dcs_update_from(data) - def dcs_set_instance_type(self, instance_dataclass: Datamodel_T): + def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]): return self._selected_state().dcs_set_instance_type(instance_dataclass) @contextmanager From c8f39dc26b3c03772b960f618f509b23e3b70790 Mon Sep 17 00:00:00 2001 From: Phil Elson Date: Tue, 5 Jan 2021 18:59:43 +0100 Subject: [PATCH 3/3] Track down the issue with the expert app and typing. The visitor pattern in the StateBuilder means that it is easy to miss type changes in the dataclass definition. --- cara/apps/calculator/__init__.py | 1 - cara/apps/calculator/model_generator.py | 14 +++--- cara/apps/expert.py | 66 +++++++++++++++---------- cara/models.py | 45 ++++++++++------- cara/state.py | 17 ++++--- setup.cfg | 4 ++ setup.py | 4 +- 7 files changed, 91 insertions(+), 60 deletions(-) diff --git a/cara/apps/calculator/__init__.py b/cara/apps/calculator/__init__.py index 674a6818..437fffb1 100644 --- a/cara/apps/calculator/__init__.py +++ b/cara/apps/calculator/__init__.py @@ -1,7 +1,6 @@ import html import json from pathlib import Path -from typing import Coroutine, Any, Optional, Awaitable import jinja2 import mistune diff --git a/cara/apps/calculator/model_generator.py b/cara/apps/calculator/model_generator.py index c6ff4508..92def303 100644 --- a/cara/apps/calculator/model_generator.py +++ b/cara/apps/calculator/model_generator.py @@ -171,7 +171,7 @@ class FormData: def build_model(self) -> models.ExposureModel: return model_from_form(self) - def ventilation(self) -> typing.Union[models.Ventilation, models.MultipleVentilation]: + def ventilation(self) -> models._VentilationBase: always_on = models.PeriodicInterval(period=120, duration=120) # Initializes a ventilation instance as a window if 'natural' is selected, or as a HEPA-filter otherwise if self.ventilation_type == 'natural': @@ -344,7 +344,7 @@ class FormData: self, start: int, finish: int, - breaks: typing.Tuple[typing.Tuple[int, int], ...] = None, + breaks: typing.Optional[models.BoundarySequence_t] = None, ) -> models.Interval: """ Calculate the presence interval given the start and end times (in minutes), and @@ -357,14 +357,14 @@ class FormData: # Order the breaks by their start-time, and ensure that they are monotonic # and that the start of one break happens after the end of another. - breaks = sorted(breaks, key=lambda break_pair: break_pair[0]) + break_boundaries: models.BoundarySequence_t = tuple(sorted(breaks, key=lambda break_pair: break_pair[0])) - for break_start, break_end in breaks: + for break_start, break_end in break_boundaries: if break_start >= break_end: raise ValueError("Break ends before it begins.") - prev_break_end = breaks[0][1] - for break_start, break_end in breaks[1:]: + prev_break_end = break_boundaries[0][1] + for break_start, break_end in break_boundaries[1:]: if prev_break_end >= break_start: raise ValueError(f"A break starts before another ends ({break_start}, {break_end}, {prev_break_end}).") prev_break_end = break_end @@ -389,7 +389,7 @@ class FormData: # 5. The interval straddles the end of the break. Bs <= S < Be <= E # 6. The interval is entirely after the break. Bs < Be <= S < E - for current_break in breaks: + for current_break in break_boundaries: if current_time >= finish: break diff --git a/cara/apps/expert.py b/cara/apps/expert.py index c284f578..a1eefd76 100644 --- a/cara/apps/expert.py +++ b/cara/apps/expert.py @@ -191,7 +191,7 @@ class ExposureComparissonResult(View): def scenarios_updated(self, scenarios: typing.Sequence[ScenarioType], _): updated_labels, updated_models = zip(*scenarios) - conc_models: typing.Tuple[models.ConcentrationModel, ...] = tuple( + conc_models = tuple( model.concentration_model.dcs_instance() for model in updated_models ) self.update_plot(conc_models, updated_labels) @@ -268,12 +268,14 @@ class ModelWidgets(View): outside_temp.observe(outsidetemp_change, names=['value']) auto_width = widgets.Layout(width='auto') - return WidgetGroup(( + return WidgetGroup( ( - widgets.Label('Outside temperature (℃)', layout=auto_width,), - outside_temp, + ( + widgets.Label('Outside temperature (℃)', layout=auto_width,), + outside_temp, + ), ), - )) + ) def _build_window(self, node) -> WidgetGroup: period = widgets.IntSlider(value=node.active.period, min=0, max=240) @@ -314,24 +316,26 @@ class ModelWidgets(View): toggle_outsidetemp(outsidetemp_w.value) auto_width = widgets.Layout(width='auto') - result = WidgetGroup(( + result = WidgetGroup( ( - widgets.Label('Interval between openings (minutes)', layout=auto_width), - period, + ( + widgets.Label('Interval between openings (minutes)', layout=auto_width), + period, + ), + ( + widgets.Label('Duration of opening (minutes)', layout=auto_width), + interval, + ), + ( + widgets.Label('Inside temperature (℃)', layout=auto_width), + inside_temp, + ), + ( + widgets.Label('Outside temperature scheme', layout=auto_width), + outsidetemp_w, + ), ), - ( - widgets.Label('Duration of opening (minutes)', layout=auto_width), - interval, - ), - ( - widgets.Label('Inside temperature (℃)', layout=auto_width), - inside_temp, - ), - ( - widgets.Label('Outside temperature scheme', layout=auto_width), - outsidetemp_w, - ) - )) + ) for sub_group in outsidetemp_widgets.values(): result.add_pairs(sub_group.pairs()) return result @@ -364,7 +368,7 @@ class ModelWidgets(View): return WidgetGroup( ( (widgets.Label("Month"), month_choice), - ) + ), ) def _build_activity(self, node): @@ -414,7 +418,13 @@ class ModelWidgets(View): [[widgets.Label("Expiration"), expiration_choice]] ) - def _build_ventilation(self, node): + def _build_ventilation( + self, + node: typing.Union[ + state.DataclassStateNamed[models.Ventilation], + state.DataclassStateNamed[models.MultipleVentilation], + ], + ) -> widgets.Widget: ventilation_widgets = { 'Natural': self._build_window(node._states['Natural']).build(), 'Mechanical': self._build_mechanical(node._states['Mechanical']), @@ -479,13 +489,17 @@ baseline_model = models.ExposureModel( class CARAStateBuilder(state.StateBuilder): + # Note: The methods in this class must correspond to the *type* of the data classes. + # For example, build_type__VentilationBase is called when dealing with ConcentrationModel + # types as it has a ventilation: _VentilationBase field. + def build_type_Mask(self, _: dataclasses.Field): return state.DataclassStatePredefined( - dataclass=models.Mask, + models.Mask, choices=models.Mask.types, ) - def build_type_Ventilation(self, _: dataclasses.Field): + def build_type__VentilationBase(self, _: dataclasses.Field): s: state.DataclassStateNamed = state.DataclassStateNamed( states={ 'Natural': self.build_generic(models.WindowOpening), @@ -526,7 +540,7 @@ class ExpertApplication(Controller): ) self.add_scenario('Scenario 1') - def build_new_model(self): + def build_new_model(self) -> state.DataclassInstanceState[models.ExposureModel]: default_model = state.DataclassInstanceState( models.ExposureModel, state_builder=CARAStateBuilder(), diff --git a/cara/models.py b/cara/models.py index dbdaa329..c8f97d6c 100644 --- a/cara/models.py +++ b/cara/models.py @@ -111,24 +111,29 @@ class PiecewiseConstant: def interval(self) -> Interval: # build an Interval object present_times = [] - for t1,t2,value in zip(self.transition_times[:-1], - self.transition_times[1:],self.values): + for t1, t2, value in zip(self.transition_times[:-1], + self.transition_times[1:], self.values): if value: present_times.append((t1,t2)) return SpecificInterval(present_times=tuple(present_times)) - def refine(self,refine_factor=10): + def refine(self, refine_factor=10): # build a new PiecewiseConstant object with a refined mesh, # using a linear interpolation in-between the initial mesh points - refined_times = np.linspace(self.transition_times[0],self.transition_times[-1], - (len(self.transition_times)-1)*refine_factor+1) - return PiecewiseConstant(tuple(refined_times), - tuple(np.interp(refined_times[:-1],self.transition_times, - self.values+(self.values[-1],) ) ) ) + refined_times = np.linspace(self.transition_times[0], self.transition_times[-1], + (len(self.transition_times)-1) * refine_factor+1) + return PiecewiseConstant( + tuple(refined_times), + tuple(np.interp( + refined_times[:-1], + self.transition_times, + self.values + (self.values[-1], ), + )), + ) @dataclass(frozen=True) -class Ventilation: +class _VentilationBase: """ Represents a mechanism by which air can be exchanged (replaced/filtered) in a time dependent manner. @@ -139,11 +144,8 @@ class Ventilation: mechanical air exchange through a filter. """ - #: The times at which the air exchange is taking place. - active: Interval - def transition_times(self) -> typing.Set[float]: - return self.active.transition_times() + raise NotImplementedError("Subclass must implement") def air_exchange(self, room: Room, time: float) -> float: """ @@ -159,7 +161,16 @@ class Ventilation: @dataclass(frozen=True) -class MultipleVentilation: +class Ventilation(_VentilationBase): + #: The interval in which the ventilation is active. + active: Interval + + def transition_times(self) -> typing.Set[float]: + return self.active.transition_times() + + +@dataclass(frozen=True) +class MultipleVentilation(_VentilationBase): """ Represents a mechanism by which air can be exchanged (replaced/filtered) in a time dependent manner. @@ -167,7 +178,7 @@ class MultipleVentilation: Group together different sources of ventilations. """ - ventilations: typing.Tuple[Ventilation, ...] + ventilations: typing.Tuple[_VentilationBase, ...] def transition_times(self) -> typing.Set[float]: transitions = set() @@ -180,7 +191,7 @@ class MultipleVentilation: Returns the rate at which air is being exchanged in the given room at a given time (in hours). """ - return sum([ventilation.air_exchange(room,time) + return sum([ventilation.air_exchange(room, time) for ventilation in self.ventilations]) @@ -551,7 +562,7 @@ class InfectedPopulation(Population): @dataclass(frozen=True) class ConcentrationModel: room: Room - ventilation: typing.Union[Ventilation, MultipleVentilation] + ventilation: _VentilationBase infected: InfectedPopulation @property diff --git a/cara/state.py b/cara/state.py index 10af075d..4c2bf8f0 100644 --- a/cara/state.py +++ b/cara/state.py @@ -122,7 +122,7 @@ class DataclassInstanceState(DataclassState[Datamodel_T]): """ - def __init__(self, dataclass: Datamodel_T, state_builder=StateBuilder()): + def __init__(self, dataclass: typing.Type[Datamodel_T], state_builder=StateBuilder()): super().__init__(state_builder=state_builder) # Note that the constructor does *not* insert any data by default. It @@ -138,7 +138,7 @@ class DataclassInstanceState(DataclassState[Datamodel_T]): self._base = dataclass #: The actual instance type that this state represents (i.e. may be a #: subclass of _base). - self._instance_type = dataclass + self._instance_type: typing.Type[Datamodel_T] = dataclass #: The instance of dataclass which this state represents. Undefined until #: sufficient data is provided. @@ -278,18 +278,18 @@ class DataclassStatePredefined(DataclassInstanceState[Datamodel_T]): """ def __init__(self, - dataclass: Datamodel_T, + dataclass: typing.Type[Datamodel_T], choices: typing.Dict[str, Datamodel_T], **kwargs, ): super().__init__(dataclass=dataclass, **kwargs) - # Pick the first choice until we know otherwise. - default_selection = list(choices.keys())[0] with self._object_setattr(): self._choices = choices - self._selected: str = default_selection + self._selected: str = None # type: ignore + # Pick the first choice until we know otherwise. + default_selection = list(choices.keys())[0] self.dcs_select(default_selection) def dcs_select(self, name: str): @@ -319,16 +319,17 @@ class DataclassStateNamed(DataclassState[Datamodel_T]): """ def __init__(self, - states: typing.Dict[str, DataclassState], + states: typing.Dict[str, DataclassState[Datamodel_T]], **kwargs ): # TODO: This is effectively a container type. We shouldn't use the standard constructor for this. enabled = list(states.keys())[0] + t = states[enabled] super().__init__(**kwargs) with self._object_setattr(): self._states = states.copy() - self._selected: str = enabled + self._selected: str = None # type: ignore # Pick the first choice until we know otherwise. self.dcs_select(enabled) diff --git a/setup.cfg b/setup.cfg index 75cae8ff..83208a5f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,7 @@ +[tool:pytest] +addopts = --mypy + + [mypy] no_warn_no_return = True diff --git a/setup.py b/setup.py index 31c73acf..2f86ddc6 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,9 @@ REQUIREMENTS: dict = { 'app': [], 'test': [ 'pytest', - 'pytest-tornasync', # Unused, but needed because of a downstream dependency. + 'pytest-mypy', + 'pytest-tornasync', + 'numpy-stubs @ git+https://github.com/numpy/numpy-stubs.git', ], 'dev': [ 'jupyterlab',