From e9f9b7fde93546062ff2e990eab42283b4ff4b47 Mon Sep 17 00:00:00 2001 From: Phil Elson Date: Tue, 5 Jan 2021 16:46:43 +0100 Subject: [PATCH] Fix up type annotation issues with the cara.state module. --- cara/apps/expert.py | 4 ++-- cara/state.py | 54 +++++++++++++++++++++++---------------------- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/cara/apps/expert.py b/cara/apps/expert.py index a9b0223f..c284f578 100644 --- a/cara/apps/expert.py +++ b/cara/apps/expert.py @@ -481,12 +481,12 @@ baseline_model = models.ExposureModel( class CARAStateBuilder(state.StateBuilder): def build_type_Mask(self, _: dataclasses.Field): return state.DataclassStatePredefined( - models.Mask, + dataclass=models.Mask, choices=models.Mask.types, ) def build_type_Ventilation(self, _: dataclasses.Field): - s = state.DataclassStateNamed( + s: state.DataclassStateNamed = state.DataclassStateNamed( states={ 'Natural': self.build_generic(models.WindowOpening), 'Mechanical': self.build_generic(models.HVACMechanical), diff --git a/cara/state.py b/cara/state.py index 00cfb3bc..10af075d 100644 --- a/cara/state.py +++ b/cara/state.py @@ -14,7 +14,7 @@ import dataclasses import typing -Datamodel_T = typing.Type +Datamodel_T = typing.TypeVar('Datamodel_T') dataclass_instance = typing.Any @@ -34,11 +34,11 @@ class StateBuilder: return method return self.build_generic - def build_generic(self, type_to_build: typing.Type): + def build_generic(self, type_to_build: typing.Type) -> "DataclassInstanceState": return DataclassInstanceState(type_to_build, state_builder=self) -class DataclassState: +class DataclassState(typing.Generic[Datamodel_T]): def __init__(self, state_builder=StateBuilder()): with self._object_setattr(): self._state_builder = state_builder @@ -54,7 +54,7 @@ class DataclassState: yield object.__setattr__(self, '_use_base_setattr', False) - def dcs_instance(self): + def dcs_instance(self) -> Datamodel_T: """ Return the instance that this state represents. The instance returned is immutable, so it is advised to call this method each time that @@ -81,7 +81,7 @@ class DataclassState: """ yield - def dcs_update_from(self, data: dataclass_instance): + def dcs_update_from(self, data: Datamodel_T): """ Update the state based on the values of the given dataclass instance. @@ -94,7 +94,7 @@ class DataclassState: """ - def dcs_set_instance_type(self, instance_dataclass: Datamodel_T): + def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]): """ Update the current instance of the state to this type. @@ -103,7 +103,7 @@ class DataclassState: """ -class DataclassInstanceState(DataclassState): +class DataclassInstanceState(DataclassState[Datamodel_T]): """ Represents the state of a frozen dataclass. No type checking of the attributes is attempted. @@ -142,11 +142,12 @@ class DataclassInstanceState(DataclassState): #: The instance of dataclass which this state represents. Undefined until #: sufficient data is provided. - self._instance = None - self._data = {} - self._observers: typing.List[callable] = [] + self._instance: typing.Optional[Datamodel_T] = None + #: The underlying state data mapping attribute name to object. + self._data: typing.Dict[str, typing.Any] = {} + self._observers: typing.List[typing.Callable] = [] self._state_builder = state_builder - self._held_events = [] + self._held_events: typing.List[bool] = [] self._hold_fire = False self.dcs_set_instance_type(dataclass) @@ -169,7 +170,7 @@ class DataclassInstanceState(DataclassState): self._held_events.clear() self._fire_observers() - def dcs_update_from(self, data: dataclass_instance): + def dcs_update_from(self, data: Datamodel_T): with self.dcs_state_transaction(): self.dcs_set_instance_type(data.__class__) for field in dataclasses.fields(data): @@ -225,7 +226,7 @@ class DataclassInstanceState(DataclassState): else: raise AttributeError(f"No attribute {attr_name} on a {self._instance_type.__name__}") - def dcs_set_instance_type(self, instance_dataclass: Datamodel_T): + def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]): if not dataclasses.is_dataclass(instance_dataclass): raise TypeError("The given class is not a valid dataclass") if not issubclass(instance_dataclass, self._base): @@ -266,7 +267,7 @@ class DataclassInstanceState(DataclassState): return self._instance -class DataclassStatePredefined(DataclassInstanceState): +class DataclassStatePredefined(DataclassInstanceState[Datamodel_T]): """ Only a pre-defined selection of states for the given type are allowed. Selected by name (the keys in the dictionary). @@ -278,18 +279,20 @@ class DataclassStatePredefined(DataclassInstanceState): """ def __init__(self, dataclass: Datamodel_T, - choices: typing.Dict[str, dataclass_instance], + choices: typing.Dict[str, Datamodel_T], **kwargs, ): super().__init__(dataclass=dataclass, **kwargs) + # Pick the first choice until we know otherwise. + default_selection = list(choices.keys())[0] with self._object_setattr(): self._choices = choices - self._selected = None - # Pick the first choice until we know otherwise. - self.dcs_select(list(choices.keys())[0]) + self._selected: str = default_selection - def dcs_select(self, name: typing.Hashable): + self.dcs_select(default_selection) + + def dcs_select(self, name: str): if name not in self._choices: raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._choices)}') self._selected = name @@ -309,24 +312,23 @@ class DataclassStatePredefined(DataclassInstanceState): return dataclasses.asdict(self.dcs_instance()) -class DataclassStateNamed(DataclassState): +class DataclassStateNamed(DataclassState[Datamodel_T]): """ A collection of instances of the given type, switchable by name, but each instance is still mutable. """ def __init__(self, - states: typing.Dict[typing.Hashable, DataclassState], + states: typing.Dict[str, DataclassState], **kwargs ): # TODO: This is effectively a container type. We shouldn't use the standard constructor for this. enabled = list(states.keys())[0] - t = states[enabled] super().__init__(**kwargs) with self._object_setattr(): self._states = states.copy() - self._selected = None + self._selected: str = enabled # Pick the first choice until we know otherwise. self.dcs_select(enabled) @@ -348,7 +350,7 @@ class DataclassStateNamed(DataclassState): return object.__setattr__(self, name, value) setattr(self._selected_state(), name, value) - def dcs_select(self, name: typing.Hashable): + def dcs_select(self, name: str): if name not in self._states: raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._states)}') self._selected = name @@ -369,10 +371,10 @@ class DataclassStateNamed(DataclassState): for state in self._states.values(): state.dcs_observe(callback) - def dcs_update_from(self, data: dataclass_instance): + def dcs_update_from(self, data: Datamodel_T): return self._selected_state().dcs_update_from(data) - def dcs_set_instance_type(self, instance_dataclass: Datamodel_T): + def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]): return self._selected_state().dcs_set_instance_type(instance_dataclass) @contextmanager