""" This module is entirely in support of providing a convenient mutable counterpart to frozen dataclasses. Significant effort went into to trying to use traitlets for this purpose, but the need to define class-level attributes proved to be a limitation that meant we could not mutate the state from one subclass to another after the state was instantiated. This module MUST not import other parts of caimira as this would point at a leaky abstraction. """ from contextlib import contextmanager import dataclasses import typing Datamodel_T = typing.TypeVar('Datamodel_T') dataclass_instance = typing.Any class StateBuilder: def visit(self, field: dataclasses.Field): builder = self.resolve_builder(field) return builder(field.type) def resolve_builder(self, field: dataclasses.Field): method_name = [ f'build_name_{field.name}', f'build_type_{field.type.__name__}', ] for name in method_name: method = getattr(self, name, None) if method is not None: return method return self.build_generic def build_generic(self, type_to_build: typing.Type) -> "DataclassInstanceState": return DataclassInstanceState(type_to_build, state_builder=self) class DataclassState(typing.Generic[Datamodel_T]): def __init__(self, state_builder=StateBuilder()): with self._object_setattr(): self._state_builder = state_builder @contextmanager def _object_setattr(self): """ For the lifetime of this contextmanager, don't do anything other than standard object.__setattr__ when setting attributes. """ object.__setattr__(self, '_use_base_setattr', True) yield object.__setattr__(self, '_use_base_setattr', False) def dcs_instance(self) -> typing.Union[None, Datamodel_T]: """ Return the instance that this state represents. The instance returned is immutable, so it is advised to call this method each time that you want the instance so that it reflects the most up-to-date state. """ pass def dcs_observe(self, callback: typing.Callable): """ If any changes are made to the state, call the given callback. """ pass @contextmanager def dcs_state_transaction(self): """ For the lifetime of this context manager, do not fire observer notifications. If any notifications would have been fired during the lifetime of this context manager, then an event will be fired once exiting the context. """ yield def dcs_update_from(self, data: Datamodel_T): """ Update the state based on the values of the given dataclass instance. """ pass def _dcs_set_value(self, attr_name, value): """ Set the state of the given attribute to the given value. """ def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]): """ Update the current instance of the state to this type. Note: This currently wipes all downstream observers. """ class DataclassInstanceState(DataclassState[Datamodel_T]): """ Represents the state of a frozen dataclass. No type checking of the attributes is attempted. Setting the state can be done with: setattr(state, attr, value) Accessing the instance that this state represents can be done with: state.dcs_instance() Changing the type to a subclass of the base that this state represents can be done with: state.dcs_set_instance_type(ASubclassOfBase) """ def __init__(self, dataclass: typing.Type[Datamodel_T], state_builder=StateBuilder()): super().__init__(state_builder=state_builder) # Note that the constructor does *not* insert any data by default. It # therefore doesn't build nested DataclassState instances when a dataclass contains another. # For that, use the build classmethod. if not dataclasses.is_dataclass(dataclass): raise TypeError("The given class is not a valid dataclass") if not isinstance(dataclass, type): raise TypeError("A dataclass type must be provided, not an instance of one") with self._object_setattr(): #: The base instance which this state must support. self._base = dataclass #: The actual instance type that this state represents (i.e. may be a #: subclass of _base). self._instance_type: typing.Type[dataclass_instance] = dataclass #: The instance of dataclass which this state represents. Undefined until #: sufficient data is provided. self._instance: typing.Optional[Datamodel_T] = None #: The underlying state data mapping attribute name to object. self._data: typing.Dict[str, typing.Any] = {} self._observers: typing.List[typing.Callable] = [] self._state_builder = state_builder self._held_events: typing.List[bool] = [] self._hold_fire = False self.dcs_set_instance_type(dataclass) def __repr__(self): return f"" def _instance_attrs(self): return [field.name for field in dataclasses.fields(self._instance_type)] def dcs_observe(self, callback: typing.Callable): self._observers.append(callback) @contextmanager def dcs_state_transaction(self): self._hold_fire = True yield self._hold_fire = False if self._held_events: self._held_events.clear() self._fire_observers() def dcs_update_from(self, data: dataclass_instance): with self.dcs_state_transaction(): self.dcs_set_instance_type(data.__class__) for field in dataclasses.fields(data): attr = field.name current_value = self._data.get(attr, None) new_value = getattr(data, attr) if dataclasses.is_dataclass(field.type): assert isinstance(current_value, DataclassState) current_value.dcs_update_from(new_value) else: self._data[attr] = new_value self._fire_observers() def _fire_observers(self): if self._hold_fire: self._held_events.append(True) else: self._instance = None for observer in self._observers: observer() def __getattr__(self, name): try: return object.__getattribute__(self, name) except AttributeError: pass if name in self._data: return self._data[name] elif name in self._instance_attrs(): raise ValueError(f"State not yet set for {name}") else: raise AttributeError(f"Attribute {name} does not exist on {self._instance_type.__name__}") def __setattr__(self, name, value): if name in self.__dict__ or self.__dict__.get('_use_base_setattr', True): return object.__setattr__(self, name, value) if name in self._instance_attrs(): self._dcs_set_value(name, value) def _dcs_set_value(self, attr_name, value): valid_attrs = self._instance_attrs() if isinstance(value, DataclassState): # TODO: We need to check that the value is acceptable. (needs # thinking about) # assert value._base == pass # TODO: Inject some notifications here to tell any holding DataclassState # instances that we've changed. if attr_name in valid_attrs: self._data[attr_name] = value self._fire_observers() else: raise AttributeError(f"No attribute {attr_name} on a {self._instance_type.__name__}") def dcs_set_instance_type(self, instance_dataclass: typing.Type[dataclass_instance]): if not dataclasses.is_dataclass(instance_dataclass): raise TypeError("The given class is not a valid dataclass") if not issubclass(instance_dataclass, self._base): raise TypeError(f"The dataclass type provided ({instance_dataclass}) must be a subclass of the base ({self._base})") self._instance_type = instance_dataclass # TODO: It is possible to cut observer connections by clearing like this. self._data.clear() for field in dataclasses.fields(instance_dataclass): if dataclasses.is_dataclass(field.type): self._data[field.name] = self._state_builder.visit(field) self._data[field.name].dcs_observe(self._fire_observers) def _instance_state(self): # Note: this method should not validate that the args are complete or # overspecified. kwargs = {} for name, data in self._data.items(): if isinstance(data, DataclassState): data = data._instance_state() kwargs[name] = data return kwargs def _instance_kwargs(self): # Note: this method should not validate that the args are complete or # overspecified. kwargs = {} for name, data in self._data.items(): if isinstance(data, DataclassState): data = data.dcs_instance() kwargs[name] = data return kwargs def dcs_instance(self): if self._instance is None: # TODO: Check if we are able to create an instance with our data... self._instance = self._instance_type(**self._instance_kwargs()) return self._instance class DataclassStatePredefined(DataclassInstanceState[Datamodel_T]): """ Only a pre-defined selection of states for the given type are allowed. Selected by name (the keys in the dictionary). You can change the chosen state with: state.dcs_select(name) """ def __init__(self, dataclass: typing.Type[Datamodel_T], choices: typing.Dict[str, Datamodel_T], **kwargs, ): super().__init__(dataclass=dataclass, **kwargs) with self._object_setattr(): self._choices = choices self._selected: str = None # type: ignore # Pick the first choice until we know otherwise. default_selection = list(choices.keys())[0] self.dcs_select(default_selection) def dcs_select(self, name: str): if name not in self._choices: raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._choices)}') self._selected = name self._instance = self._choices[name] self._fire_observers() def dcs_instance(self): return self._choices[self._selected] def __repr__(self): return f"" def _instance_state(self): return dataclasses.asdict(self.dcs_instance()) def _instance_kwargs(self): return dataclasses.asdict(self.dcs_instance()) class DataclassStateNamed(DataclassState[Datamodel_T]): """ A collection of instances of the given type, switchable by name, but each instance is still mutable. """ def __init__(self, states: typing.Dict[str, DataclassState[Datamodel_T]], base_type: str, **kwargs ): # TODO: This is effectively a container type. We shouldn't use the standard constructor for this. super().__init__(**kwargs) with self._object_setattr(): self._states = states.copy() self._selected: str = None # type: ignore self.dcs_select(base_type) def __getattr__(self, name): try: return object.__getattribute__(self, name) except AttributeError: pass return getattr(self._selected_state(), name) # if name in self._data: # return self._data[name] # elif name in self._instance_attrs(): # raise ValueError(f"State not yet set for {name}") # else: # raise AttributeError(f"Attribute {name} does not exist on {self._instance_type.__name__}") def __setattr__(self, name, value): if name in self.__dict__ or self.__dict__.get('_use_base_setattr', True): return object.__setattr__(self, name, value) setattr(self._selected_state(), name, value) def dcs_select(self, name: str): if name not in self._states: raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._states)}') self._selected = name self._fire_observers() def _selected_state(self): return self._states[self._selected] def dcs_instance(self): return self._selected_state().dcs_instance() def __repr__(self): return f"" def dcs_observe(self, callback: typing.Callable): # Note there is no way to observe the selected state change currently. # You can only watch for the individual selected states being changed. for state in self._states.values(): state.dcs_observe(callback) def dcs_update_from(self, data: Datamodel_T): return self._selected_state().dcs_update_from(data) def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]): return self._selected_state().dcs_set_instance_type(instance_dataclass) @contextmanager def dcs_state_transaction(self): orig = [s._hold_fire for s in self._states.values()] for s in self._states.values(): s._hold_fire = True yield for orig_hold, s in zip(orig, self._states.values()): s._hold_fire = orig_hold if s._held_events: s._held_events.clear() s._fire_observers()