Fix up type annotation issues with the cara.state module.
This commit is contained in:
parent
64fa2b60a4
commit
e9f9b7fde9
2 changed files with 30 additions and 28 deletions
|
|
@ -481,12 +481,12 @@ baseline_model = models.ExposureModel(
|
|||
class CARAStateBuilder(state.StateBuilder):
|
||||
def build_type_Mask(self, _: dataclasses.Field):
|
||||
return state.DataclassStatePredefined(
|
||||
models.Mask,
|
||||
dataclass=models.Mask,
|
||||
choices=models.Mask.types,
|
||||
)
|
||||
|
||||
def build_type_Ventilation(self, _: dataclasses.Field):
|
||||
s = state.DataclassStateNamed(
|
||||
s: state.DataclassStateNamed = state.DataclassStateNamed(
|
||||
states={
|
||||
'Natural': self.build_generic(models.WindowOpening),
|
||||
'Mechanical': self.build_generic(models.HVACMechanical),
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import dataclasses
|
|||
import typing
|
||||
|
||||
|
||||
Datamodel_T = typing.Type
|
||||
Datamodel_T = typing.TypeVar('Datamodel_T')
|
||||
dataclass_instance = typing.Any
|
||||
|
||||
|
||||
|
|
@ -34,11 +34,11 @@ class StateBuilder:
|
|||
return method
|
||||
return self.build_generic
|
||||
|
||||
def build_generic(self, type_to_build: typing.Type):
|
||||
def build_generic(self, type_to_build: typing.Type) -> "DataclassInstanceState":
|
||||
return DataclassInstanceState(type_to_build, state_builder=self)
|
||||
|
||||
|
||||
class DataclassState:
|
||||
class DataclassState(typing.Generic[Datamodel_T]):
|
||||
def __init__(self, state_builder=StateBuilder()):
|
||||
with self._object_setattr():
|
||||
self._state_builder = state_builder
|
||||
|
|
@ -54,7 +54,7 @@ class DataclassState:
|
|||
yield
|
||||
object.__setattr__(self, '_use_base_setattr', False)
|
||||
|
||||
def dcs_instance(self):
|
||||
def dcs_instance(self) -> Datamodel_T:
|
||||
"""
|
||||
Return the instance that this state represents. The instance returned
|
||||
is immutable, so it is advised to call this method each time that
|
||||
|
|
@ -81,7 +81,7 @@ class DataclassState:
|
|||
"""
|
||||
yield
|
||||
|
||||
def dcs_update_from(self, data: dataclass_instance):
|
||||
def dcs_update_from(self, data: Datamodel_T):
|
||||
"""
|
||||
Update the state based on the values of the given dataclass instance.
|
||||
|
||||
|
|
@ -94,7 +94,7 @@ class DataclassState:
|
|||
|
||||
"""
|
||||
|
||||
def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
|
||||
def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
|
||||
"""
|
||||
Update the current instance of the state to this type.
|
||||
|
||||
|
|
@ -103,7 +103,7 @@ class DataclassState:
|
|||
"""
|
||||
|
||||
|
||||
class DataclassInstanceState(DataclassState):
|
||||
class DataclassInstanceState(DataclassState[Datamodel_T]):
|
||||
"""
|
||||
Represents the state of a frozen dataclass.
|
||||
No type checking of the attributes is attempted.
|
||||
|
|
@ -142,11 +142,12 @@ class DataclassInstanceState(DataclassState):
|
|||
|
||||
#: The instance of dataclass which this state represents. Undefined until
|
||||
#: sufficient data is provided.
|
||||
self._instance = None
|
||||
self._data = {}
|
||||
self._observers: typing.List[callable] = []
|
||||
self._instance: typing.Optional[Datamodel_T] = None
|
||||
#: The underlying state data mapping attribute name to object.
|
||||
self._data: typing.Dict[str, typing.Any] = {}
|
||||
self._observers: typing.List[typing.Callable] = []
|
||||
self._state_builder = state_builder
|
||||
self._held_events = []
|
||||
self._held_events: typing.List[bool] = []
|
||||
self._hold_fire = False
|
||||
|
||||
self.dcs_set_instance_type(dataclass)
|
||||
|
|
@ -169,7 +170,7 @@ class DataclassInstanceState(DataclassState):
|
|||
self._held_events.clear()
|
||||
self._fire_observers()
|
||||
|
||||
def dcs_update_from(self, data: dataclass_instance):
|
||||
def dcs_update_from(self, data: Datamodel_T):
|
||||
with self.dcs_state_transaction():
|
||||
self.dcs_set_instance_type(data.__class__)
|
||||
for field in dataclasses.fields(data):
|
||||
|
|
@ -225,7 +226,7 @@ class DataclassInstanceState(DataclassState):
|
|||
else:
|
||||
raise AttributeError(f"No attribute {attr_name} on a {self._instance_type.__name__}")
|
||||
|
||||
def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
|
||||
def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
|
||||
if not dataclasses.is_dataclass(instance_dataclass):
|
||||
raise TypeError("The given class is not a valid dataclass")
|
||||
if not issubclass(instance_dataclass, self._base):
|
||||
|
|
@ -266,7 +267,7 @@ class DataclassInstanceState(DataclassState):
|
|||
return self._instance
|
||||
|
||||
|
||||
class DataclassStatePredefined(DataclassInstanceState):
|
||||
class DataclassStatePredefined(DataclassInstanceState[Datamodel_T]):
|
||||
"""
|
||||
Only a pre-defined selection of states for the given type are allowed.
|
||||
Selected by name (the keys in the dictionary).
|
||||
|
|
@ -278,18 +279,20 @@ class DataclassStatePredefined(DataclassInstanceState):
|
|||
"""
|
||||
def __init__(self,
|
||||
dataclass: Datamodel_T,
|
||||
choices: typing.Dict[str, dataclass_instance],
|
||||
choices: typing.Dict[str, Datamodel_T],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(dataclass=dataclass, **kwargs)
|
||||
|
||||
# Pick the first choice until we know otherwise.
|
||||
default_selection = list(choices.keys())[0]
|
||||
with self._object_setattr():
|
||||
self._choices = choices
|
||||
self._selected = None
|
||||
# Pick the first choice until we know otherwise.
|
||||
self.dcs_select(list(choices.keys())[0])
|
||||
self._selected: str = default_selection
|
||||
|
||||
def dcs_select(self, name: typing.Hashable):
|
||||
self.dcs_select(default_selection)
|
||||
|
||||
def dcs_select(self, name: str):
|
||||
if name not in self._choices:
|
||||
raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._choices)}')
|
||||
self._selected = name
|
||||
|
|
@ -309,24 +312,23 @@ class DataclassStatePredefined(DataclassInstanceState):
|
|||
return dataclasses.asdict(self.dcs_instance())
|
||||
|
||||
|
||||
class DataclassStateNamed(DataclassState):
|
||||
class DataclassStateNamed(DataclassState[Datamodel_T]):
|
||||
"""
|
||||
A collection of instances of the given type, switchable by name, but each
|
||||
instance is still mutable.
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
states: typing.Dict[typing.Hashable, DataclassState],
|
||||
states: typing.Dict[str, DataclassState],
|
||||
**kwargs
|
||||
):
|
||||
# TODO: This is effectively a container type. We shouldn't use the standard constructor for this.
|
||||
enabled = list(states.keys())[0]
|
||||
t = states[enabled]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
with self._object_setattr():
|
||||
self._states = states.copy()
|
||||
self._selected = None
|
||||
self._selected: str = enabled
|
||||
# Pick the first choice until we know otherwise.
|
||||
self.dcs_select(enabled)
|
||||
|
||||
|
|
@ -348,7 +350,7 @@ class DataclassStateNamed(DataclassState):
|
|||
return object.__setattr__(self, name, value)
|
||||
setattr(self._selected_state(), name, value)
|
||||
|
||||
def dcs_select(self, name: typing.Hashable):
|
||||
def dcs_select(self, name: str):
|
||||
if name not in self._states:
|
||||
raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._states)}')
|
||||
self._selected = name
|
||||
|
|
@ -369,10 +371,10 @@ class DataclassStateNamed(DataclassState):
|
|||
for state in self._states.values():
|
||||
state.dcs_observe(callback)
|
||||
|
||||
def dcs_update_from(self, data: dataclass_instance):
|
||||
def dcs_update_from(self, data: Datamodel_T):
|
||||
return self._selected_state().dcs_update_from(data)
|
||||
|
||||
def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
|
||||
def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
|
||||
return self._selected_state().dcs_set_instance_type(instance_dataclass)
|
||||
|
||||
@contextmanager
|
||||
|
|
|
|||
Loading…
Reference in a new issue