Fix up type annotation issues with the cara.state module.

This commit is contained in:
Phil Elson 2021-01-05 16:46:43 +01:00
parent 64fa2b60a4
commit e9f9b7fde9
2 changed files with 30 additions and 28 deletions

View file

@ -481,12 +481,12 @@ baseline_model = models.ExposureModel(
class CARAStateBuilder(state.StateBuilder):
def build_type_Mask(self, _: dataclasses.Field):
return state.DataclassStatePredefined(
models.Mask,
dataclass=models.Mask,
choices=models.Mask.types,
)
def build_type_Ventilation(self, _: dataclasses.Field):
s = state.DataclassStateNamed(
s: state.DataclassStateNamed = state.DataclassStateNamed(
states={
'Natural': self.build_generic(models.WindowOpening),
'Mechanical': self.build_generic(models.HVACMechanical),

View file

@ -14,7 +14,7 @@ import dataclasses
import typing
Datamodel_T = typing.Type
Datamodel_T = typing.TypeVar('Datamodel_T')
dataclass_instance = typing.Any
@ -34,11 +34,11 @@ class StateBuilder:
return method
return self.build_generic
def build_generic(self, type_to_build: typing.Type):
def build_generic(self, type_to_build: typing.Type) -> "DataclassInstanceState":
return DataclassInstanceState(type_to_build, state_builder=self)
class DataclassState:
class DataclassState(typing.Generic[Datamodel_T]):
def __init__(self, state_builder=StateBuilder()):
with self._object_setattr():
self._state_builder = state_builder
@ -54,7 +54,7 @@ class DataclassState:
yield
object.__setattr__(self, '_use_base_setattr', False)
def dcs_instance(self):
def dcs_instance(self) -> Datamodel_T:
"""
Return the instance that this state represents. The instance returned
is immutable, so it is advised to call this method each time that
@ -81,7 +81,7 @@ class DataclassState:
"""
yield
def dcs_update_from(self, data: dataclass_instance):
def dcs_update_from(self, data: Datamodel_T):
"""
Update the state based on the values of the given dataclass instance.
@ -94,7 +94,7 @@ class DataclassState:
"""
def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
"""
Update the current instance of the state to this type.
@ -103,7 +103,7 @@ class DataclassState:
"""
class DataclassInstanceState(DataclassState):
class DataclassInstanceState(DataclassState[Datamodel_T]):
"""
Represents the state of a frozen dataclass.
No type checking of the attributes is attempted.
@ -142,11 +142,12 @@ class DataclassInstanceState(DataclassState):
#: The instance of dataclass which this state represents. Undefined until
#: sufficient data is provided.
self._instance = None
self._data = {}
self._observers: typing.List[callable] = []
self._instance: typing.Optional[Datamodel_T] = None
#: The underlying state data mapping attribute name to object.
self._data: typing.Dict[str, typing.Any] = {}
self._observers: typing.List[typing.Callable] = []
self._state_builder = state_builder
self._held_events = []
self._held_events: typing.List[bool] = []
self._hold_fire = False
self.dcs_set_instance_type(dataclass)
@ -169,7 +170,7 @@ class DataclassInstanceState(DataclassState):
self._held_events.clear()
self._fire_observers()
def dcs_update_from(self, data: dataclass_instance):
def dcs_update_from(self, data: Datamodel_T):
with self.dcs_state_transaction():
self.dcs_set_instance_type(data.__class__)
for field in dataclasses.fields(data):
@ -225,7 +226,7 @@ class DataclassInstanceState(DataclassState):
else:
raise AttributeError(f"No attribute {attr_name} on a {self._instance_type.__name__}")
def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
if not dataclasses.is_dataclass(instance_dataclass):
raise TypeError("The given class is not a valid dataclass")
if not issubclass(instance_dataclass, self._base):
@ -266,7 +267,7 @@ class DataclassInstanceState(DataclassState):
return self._instance
class DataclassStatePredefined(DataclassInstanceState):
class DataclassStatePredefined(DataclassInstanceState[Datamodel_T]):
"""
Only a pre-defined selection of states for the given type are allowed.
Selected by name (the keys in the dictionary).
@ -278,18 +279,20 @@ class DataclassStatePredefined(DataclassInstanceState):
"""
def __init__(self,
dataclass: Datamodel_T,
choices: typing.Dict[str, dataclass_instance],
choices: typing.Dict[str, Datamodel_T],
**kwargs,
):
super().__init__(dataclass=dataclass, **kwargs)
# Pick the first choice until we know otherwise.
default_selection = list(choices.keys())[0]
with self._object_setattr():
self._choices = choices
self._selected = None
# Pick the first choice until we know otherwise.
self.dcs_select(list(choices.keys())[0])
self._selected: str = default_selection
def dcs_select(self, name: typing.Hashable):
self.dcs_select(default_selection)
def dcs_select(self, name: str):
if name not in self._choices:
raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._choices)}')
self._selected = name
@ -309,24 +312,23 @@ class DataclassStatePredefined(DataclassInstanceState):
return dataclasses.asdict(self.dcs_instance())
class DataclassStateNamed(DataclassState):
class DataclassStateNamed(DataclassState[Datamodel_T]):
"""
A collection of instances of the given type, switchable by name, but each
instance is still mutable.
"""
def __init__(self,
states: typing.Dict[typing.Hashable, DataclassState],
states: typing.Dict[str, DataclassState],
**kwargs
):
# TODO: This is effectively a container type. We shouldn't use the standard constructor for this.
enabled = list(states.keys())[0]
t = states[enabled]
super().__init__(**kwargs)
with self._object_setattr():
self._states = states.copy()
self._selected = None
self._selected: str = enabled
# Pick the first choice until we know otherwise.
self.dcs_select(enabled)
@ -348,7 +350,7 @@ class DataclassStateNamed(DataclassState):
return object.__setattr__(self, name, value)
setattr(self._selected_state(), name, value)
def dcs_select(self, name: typing.Hashable):
def dcs_select(self, name: str):
if name not in self._states:
raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._states)}')
self._selected = name
@ -369,10 +371,10 @@ class DataclassStateNamed(DataclassState):
for state in self._states.values():
state.dcs_observe(callback)
def dcs_update_from(self, data: dataclass_instance):
def dcs_update_from(self, data: Datamodel_T):
return self._selected_state().dcs_update_from(data)
def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
return self._selected_state().dcs_set_instance_type(instance_dataclass)
@contextmanager