cara/cara/state.py
Phil Elson c8f39dc26b Track down the issue with the expert app and typing.
The visitor pattern in the StateBuilder means that it is easy to miss type changes in the dataclass definition.
2021-01-06 16:21:33 +01:00

391 lines
14 KiB
Python

"""
This module is entirely in support of providing a convenient mutable counterpart
to frozen dataclasses. Significant effort went into trying to use traitlets
for this purpose, but the need to define class-level attributes proved to be a
limitation that meant we could not mutate the state from one subclass to another
after the state was instantiated.
This module MUST not import other parts of cara as this would point at a
leaky abstraction.
"""
from contextlib import contextmanager
import dataclasses
import typing
# Type variable for the dataclass type that a state object represents; it
# parameterises the generic DataclassState classes below.
Datamodel_T = typing.TypeVar('Datamodel_T')
# Annotation alias for "an instance of some dataclass". It is simply Any, so
# it documents intent rather than enforcing anything.
# NOTE(review): not referenced in the visible part of this module - confirm usage.
dataclass_instance = typing.Any
class StateBuilder:
    """
    Visitor which decides how the state object for a dataclass field is built.

    For a given field the builder is looked up on this object, in order, as:

        build_name_<field name>
        build_type_<name of the field's type>

    and :meth:`build_generic` is used when neither exists. Subclasses
    customise state construction by defining methods with those names; each
    is called with the field's type as its only argument.
    """

    def visit(self, field: dataclasses.Field):
        """
        Build and return the state for the given dataclass field.
        """
        builder = self.resolve_builder(field)
        return builder(field.type)

    def resolve_builder(self, field: dataclasses.Field):
        """
        Return the callable which should build the state for ``field``.

        Name-specific builders take precedence over type-specific ones.
        """
        candidates = [f'build_name_{field.name}']
        type_name = self._type_name(field.type)
        if type_name is not None:
            candidates.append(f'build_type_{type_name}')

        for name in candidates:
            method = getattr(self, name, None)
            if method is not None:
                return method
        return self.build_generic

    @staticmethod
    def _type_name(field_type) -> typing.Optional[str]:
        """
        Return a name for the given field type, or None if it has none.

        ``field.type`` is not always a class: with postponed/string
        annotations it is the annotation text itself, and some typing
        constructs have no ``__name__``. Neither case may raise here,
        otherwise a matching ``build_name_*`` method could never be reached.
        """
        if isinstance(field_type, str):
            return field_type
        return getattr(field_type, '__name__', None)

    def build_generic(self, type_to_build: typing.Type) -> "DataclassInstanceState":
        """
        Fallback builder: a plain DataclassInstanceState sharing this builder.
        """
        return DataclassInstanceState(type_to_build, state_builder=self)
class DataclassState(typing.Generic[Datamodel_T]):
    """
    Base interface for the mutable counterpart of a frozen dataclass.

    The methods here are deliberately no-ops; concrete subclasses provide the
    behaviour. All public methods are prefixed with ``dcs_`` so that they do
    not clash with the attribute names of the dataclass being represented.
    """

    def __init__(self, state_builder=None):
        """
        :param state_builder: the builder used to construct nested states.
            A fresh :class:`StateBuilder` is created when not given.
        """
        # Build the default lazily rather than in the signature so that no
        # builder instance is created at import time and shared by all states.
        if state_builder is None:
            state_builder = StateBuilder()
        with self._object_setattr():
            self._state_builder = state_builder

    @contextmanager
    def _object_setattr(self):
        """
        For the lifetime of this contextmanager, don't do anything other than
        standard object.__setattr__ when setting attributes.

        Subclasses which define ``__setattr__`` consult the
        ``_use_base_setattr`` flag that is toggled here.
        """
        object.__setattr__(self, '_use_base_setattr', True)
        try:
            yield
        finally:
            # Always switch back, even if the body raised, so that a failure
            # cannot leave the state permanently bypassing its own __setattr__.
            object.__setattr__(self, '_use_base_setattr', False)

    def dcs_instance(self) -> Datamodel_T:
        """
        Return the instance that this state represents. The instance returned
        is immutable, so it is advised to call this method each time that
        you want the instance so that it reflects the most up-to-date state.
        """
        pass

    def dcs_observe(self, callback: typing.Callable):
        """
        If any changes are made to the state, call the given callback.
        """
        pass

    @contextmanager
    def dcs_state_transaction(self):
        """
        For the lifetime of this context manager, do not fire observer
        notifications. If any notifications would have been fired during the
        lifetime of this context manager, then an event will be fired once
        exiting the context.
        """
        yield

    def dcs_update_from(self, data: Datamodel_T):
        """
        Update the state based on the values of the given dataclass instance.
        """
        pass

    def _dcs_set_value(self, attr_name, value):
        """
        Set the state of the given attribute to the given value.
        """

    def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
        """
        Update the current instance of the state to this type.
        Note: This currently wipes all downstream observers.
        """
class DataclassInstanceState(DataclassState[Datamodel_T]):
    """
    Represents the state of a frozen dataclass.

    No type checking of the attributes is attempted.

    Setting the state can be done with:

        setattr(state, attr, value)

    Accessing the instance that this state represents can be done with:

        state.dcs_instance()

    Changing the type to a subclass of the base that this state represents can be done with:

        state.dcs_set_instance_type(ASubclassOfBase)

    """

    def __init__(self, dataclass: typing.Type[Datamodel_T], state_builder=None):
        """
        :param dataclass: the dataclass *type* (not an instance) to represent.
        :param state_builder: builder used for nested dataclass fields. A
            fresh :class:`StateBuilder` is created when not given.
        :raises TypeError: if ``dataclass`` is not a dataclass type.
        """
        # Resolve the default here (not in the signature) so that no builder
        # is instantiated at import time and shared between all states.
        if state_builder is None:
            state_builder = StateBuilder()
        super().__init__(state_builder=state_builder)

        # Note that the constructor does *not* insert any data by default. It
        # therefore doesn't build nested DataclassState instances when a dataclass contains another.
        # For that, use the build classmethod.
        if not dataclasses.is_dataclass(dataclass):
            raise TypeError("The given class is not a valid dataclass")
        if not isinstance(dataclass, type):
            raise TypeError("A dataclass type must be provided, not an instance of one")

        with self._object_setattr():
            #: The base instance which this state must support.
            self._base = dataclass

            #: The actual instance type that this state represents (i.e. may be a
            #: subclass of _base).
            self._instance_type: typing.Type[Datamodel_T] = dataclass

            #: The instance of dataclass which this state represents. Undefined until
            #: sufficient data is provided.
            self._instance: typing.Optional[Datamodel_T] = None

            #: The underlying state data mapping attribute name to object.
            self._data: typing.Dict[str, typing.Any] = {}

            #: Callables (taking no arguments) invoked whenever the state changes.
            self._observers: typing.List[typing.Callable] = []
            self._state_builder = state_builder

            #: Notifications suppressed while ``_hold_fire`` is set.
            self._held_events: typing.List[bool] = []
            self._hold_fire = False

        self.dcs_set_instance_type(dataclass)

    def __repr__(self):
        return f"<state for {self._instance_type.__name__}(**{self._instance_state()})>"

    def _instance_attrs(self):
        """Return the field names of the currently selected instance type."""
        return [field.name for field in dataclasses.fields(self._instance_type)]

    def dcs_observe(self, callback: typing.Callable):
        self._observers.append(callback)

    @contextmanager
    def dcs_state_transaction(self):
        # Remember whether an enclosing transaction already holds the
        # notifications, so that leaving a nested transaction does not
        # release (and fire) on behalf of the outer one.
        was_held = self._hold_fire
        self._hold_fire = True
        try:
            yield
        finally:
            # Restore even if the body raised; otherwise the state would
            # never notify its observers (nor refresh its instance) again.
            self._hold_fire = was_held
            if not was_held and self._held_events:
                self._held_events.clear()
                self._fire_observers()

    def dcs_update_from(self, data: Datamodel_T):
        with self.dcs_state_transaction():
            # Switching type resets the data and rebuilds the nested states.
            self.dcs_set_instance_type(data.__class__)

            for field in dataclasses.fields(data):
                attr = field.name
                current_value = self._data.get(attr, None)
                new_value = getattr(data, attr)
                if dataclasses.is_dataclass(field.type):
                    # Nested dataclasses are represented by nested states
                    # (created by dcs_set_instance_type), so update in place.
                    assert isinstance(current_value, DataclassState)
                    current_value.dcs_update_from(new_value)
                else:
                    self._data[attr] = new_value
            # Held by the surrounding transaction; results in a single
            # notification when the transaction is left.
            self._fire_observers()

    def _fire_observers(self):
        """
        Notify all observers, or record the event if notifications are held.

        The cached instance is dropped here, since a notification means that
        the underlying data has changed.
        """
        if self._hold_fire:
            self._held_events.append(True)
            return
        self._instance = None
        for observer in self._observers:
            observer()

    def __getattr__(self, name):
        # __getattr__ is only invoked once normal attribute lookup has
        # failed. Fetch the data store without going through __getattr__
        # again: on an object whose __init__ has not run (e.g. while being
        # copied or unpickled) ``self._data`` would otherwise recurse forever.
        try:
            data = object.__getattribute__(self, '_data')
        except AttributeError:
            raise AttributeError(name) from None

        if name in data:
            return data[name]
        elif name in self._instance_attrs():
            raise ValueError(f"State not yet set for {name}")
        else:
            raise AttributeError(f"Attribute {name} does not exist on {self._instance_type.__name__}")

    def __setattr__(self, name, value):
        # Attributes which already exist on the object itself (the private
        # bookkeeping set up in __init__), or anything set whilst inside
        # _object_setattr, are plain attributes rather than dataclass state.
        if name in self.__dict__ or self.__dict__.get('_use_base_setattr', True):
            return object.__setattr__(self, name, value)

        # Note: names which are not fields of the instance type are silently
        # ignored here. Use _dcs_set_value directly to get an AttributeError.
        if name in self._instance_attrs():
            self._dcs_set_value(name, value)

    def _dcs_set_value(self, attr_name, value):
        """
        Store ``value`` for the given field and notify the observers.

        :raises AttributeError: if the instance type has no such field.
        """
        # TODO: When value is a DataclassState we need to check that it is
        #  acceptable (e.g. that its base is compatible). Needs thinking about.

        # TODO: Inject some notifications here to tell any holding DataclassState
        #  instances that we've changed.
        if attr_name not in self._instance_attrs():
            raise AttributeError(f"No attribute {attr_name} on a {self._instance_type.__name__}")
        self._data[attr_name] = value
        self._fire_observers()

    def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
        """
        Switch the represented type to the given subclass of the base type.

        All existing data is discarded, and a new nested state is built for
        each field whose type is itself a dataclass.

        :raises TypeError: if the type is not a dataclass, or not a subclass
            of the base type given at construction.
        """
        if not dataclasses.is_dataclass(instance_dataclass):
            raise TypeError("The given class is not a valid dataclass")
        if not issubclass(instance_dataclass, self._base):
            raise TypeError(f"The dataclass type provided ({instance_dataclass}) must be a subclass of the base ({self._base})")

        self._instance_type = instance_dataclass
        # Any cached instance belongs to the previous type / data, so it must
        # not be handed out by dcs_instance any more.
        self._instance = None

        # TODO: It is possible to cut observer connections by clearing like this.
        self._data.clear()

        for field in dataclasses.fields(instance_dataclass):
            if dataclasses.is_dataclass(field.type):
                nested_state = self._state_builder.visit(field)
                self._data[field.name] = nested_state
                # Changes in the nested state count as changes of this state.
                nested_state.dcs_observe(self._fire_observers)

    def _instance_state(self):
        """Return the data as a (recursively) plain dictionary, for display."""
        # Note: this method should not validate that the args are complete or
        # overspecified.
        return {
            name: data._instance_state() if isinstance(data, DataclassState) else data
            for name, data in self._data.items()
        }

    def _instance_kwargs(self):
        """Return the keyword arguments from which the instance is created."""
        # Note: this method should not validate that the args are complete or
        # overspecified.
        return {
            name: data.dcs_instance() if isinstance(data, DataclassState) else data
            for name, data in self._data.items()
        }

    def dcs_instance(self):
        if self._instance is None:
            # TODO: Check if we are able to create an instance with our data...
            self._instance = self._instance_type(**self._instance_kwargs())
        return self._instance
class DataclassStatePredefined(DataclassInstanceState[Datamodel_T]):
    """
    Only a pre-defined selection of states for the given type are allowed.
    Selected by name (the keys in the dictionary).

    You can change the chosen state with:

        state.dcs_select(name)

    """

    def __init__(self,
                 dataclass: typing.Type[Datamodel_T],
                 choices: typing.Dict[str, Datamodel_T],
                 **kwargs,
                 ):
        """
        :param dataclass: the dataclass type that all choices are instances of.
        :param choices: mapping of choice name to ready-made instance. The
            first key is selected initially.
        :param kwargs: passed on to :class:`DataclassInstanceState`.
        """
        super().__init__(dataclass=dataclass, **kwargs)
        with self._object_setattr():
            self._choices = choices
            self._selected: str = None  # type: ignore

        # Pick the first choice until we know otherwise.
        default_selection = list(choices.keys())[0]
        self.dcs_select(default_selection)

    def dcs_select(self, name: str):
        """
        Make the named choice the current one and notify the observers.

        :raises ValueError: if ``name`` is not one of the choices.
        """
        if name not in self._choices:
            raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._choices)}')
        self._selected = name
        self._instance = self._choices[name]
        self._fire_observers()

    def dcs_instance(self):
        return self._choices[self._selected]

    def __repr__(self):
        return f"<state for {self._instance_type.__name__}. '{self._selected}' selected>"

    def _instance_state(self):
        # Recursively converted to plain dictionaries; this is for display.
        return dataclasses.asdict(self.dcs_instance())

    def _instance_kwargs(self):
        # These are constructor arguments, so nested dataclass values must
        # remain dataclass instances. dataclasses.asdict would recursively
        # convert them into dictionaries, hence the shallow mapping here.
        instance = self.dcs_instance()
        return {
            field.name: getattr(instance, field.name)
            for field in dataclasses.fields(instance)
        }
class DataclassStateNamed(DataclassState[Datamodel_T]):
    """
    A collection of instances of the given type, switchable by name, but each
    instance is still mutable.

    Attribute access and assignment, as well as the ``dcs_`` methods, are
    forwarded to the currently selected state.
    """

    def __init__(self,
                 states: typing.Dict[str, DataclassState[Datamodel_T]],
                 **kwargs
                 ):
        """
        :param states: mapping of name to state. The mapping is copied
            (shallowly), and the first key is selected initially.
        :param kwargs: passed on to :class:`DataclassState`.
        """
        # TODO: This is effectively a container type. We shouldn't use the standard constructor for this.
        enabled = list(states.keys())[0]

        super().__init__(**kwargs)
        with self._object_setattr():
            self._states = states.copy()
            self._selected: str = None  # type: ignore

        # Pick the first choice until we know otherwise.
        self.dcs_select(enabled)

    def __getattr__(self, name):
        # __getattr__ is only invoked once normal attribute lookup has
        # failed. If the container itself is not set up (e.g. the object is
        # being copied or unpickled, so __init__ has not run) there is
        # nothing to delegate to, and looking up ``self._states`` through
        # this very method would recurse forever.
        try:
            object.__getattribute__(self, '_states')
            object.__getattribute__(self, '_selected')
        except AttributeError:
            raise AttributeError(name) from None

        # Everything else (including private helpers such as
        # _fire_observers and _instance_type) comes from the selected state.
        return getattr(self._selected_state(), name)

    def __setattr__(self, name, value):
        # Attributes which already exist on the container, or anything set
        # whilst inside _object_setattr, belong to the container itself.
        if name in self.__dict__ or self.__dict__.get('_use_base_setattr', True):
            return object.__setattr__(self, name, value)
        setattr(self._selected_state(), name, value)

    def dcs_select(self, name: str):
        """
        Make the named state the current one and fire its observers.

        :raises ValueError: if ``name`` is not one of the held states.
        """
        if name not in self._states:
            raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._states)}')
        self._selected = name
        self._fire_observers()

    def _selected_state(self):
        return self._states[self._selected]

    def dcs_instance(self):
        return self._selected_state().dcs_instance()

    def __repr__(self):
        return f"<state for {self._instance_type.__name__}. Holding {len(self._states)} state(s). '{self._selected}' selected>"

    def dcs_observe(self, callback: typing.Callable):
        # The callback is registered on every held state, so it is called
        # whichever of them changes.
        for state in self._states.values():
            state.dcs_observe(callback)

    def dcs_update_from(self, data: Datamodel_T):
        return self._selected_state().dcs_update_from(data)

    def _dcs_set_value(self, attr_name, value):
        # Without this the no-op of the base class would be used, and the
        # value would silently be dropped.
        return self._selected_state()._dcs_set_value(attr_name, value)

    def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
        return self._selected_state().dcs_set_instance_type(instance_dataclass)

    @contextmanager
    def dcs_state_transaction(self):
        # Work on a snapshot so that each state is paired with its own
        # original flag even if the collection changes in the meantime.
        states = list(self._states.values())
        originally_held = [state._hold_fire for state in states]
        for state in states:
            state._hold_fire = True
        try:
            yield
        finally:
            # Restore even if the body raised, so that the states are not
            # left permanently silenced.
            for was_held, state in zip(originally_held, states):
                state._hold_fire = was_held
                # A state which was already held before this transaction
                # keeps its pending events for whoever holds it.
                if not was_held and state._held_events:
                    state._held_events.clear()
                    state._fire_observers()