diff --git a/cara/monte_carlo/models.py b/cara/monte_carlo/models.py index acdc1dad..7348e3be 100644 --- a/cara/monte_carlo/models.py +++ b/cara/monte_carlo/models.py @@ -22,6 +22,19 @@ class MCModelBase(typing.Generic[_ModelType]): """ _base_cls: typing.Type[_ModelType] + @classmethod + def _to_vectorized_form(cls, item, size): + if isinstance(item, SampleableDistribution): + return item.generate_samples(size) + elif isinstance(item, MCModelBase): + # Recurse into other MCModelBase instances by calling their + # build_model method. + return item.build_model(size) + elif isinstance(item, tuple): + return tuple(cls._to_vectorized_form(sub, size) for sub in item) + else: + return item + def build_model(self, size: int) -> _ModelType: """ Turn this MCModelBase subclass into a cara.models Model instance @@ -31,13 +44,7 @@ class MCModelBase(typing.Generic[_ModelType]): kwargs = {} for field in dataclasses.fields(self._base_cls): attr = getattr(self, field.name) - if isinstance(attr, SampleableDistribution): - attr = attr.generate_samples(size) - elif isinstance(attr, MCModelBase): - # Recurse into other MCModelBase instances by calling their - # build_model method. - attr = attr.build_model(size) - kwargs[field.name] = attr + kwargs[field.name] = self._to_vectorized_form(attr, size) return self._base_cls(**kwargs) # type: ignore @@ -53,12 +60,43 @@ def _build_mc_model(model: _ModelType) -> typing.Type[MCModelBase[_ModelType]]: new_field = copy.copy(field) if field.type is cara.models._VectorisedFloat: # noqa new_field.type = _VectorisedFloatOrSampleable # type: ignore - # TODO: Update the type annotation to support the new model classes that exist. - fields.append((new_field.name, new_field.type, new_field)) + + field_type: typing.Any = new_field.type + + if getattr(field_type, '__origin__', None) in [typing.Union, typing.Tuple]: + # It is challenging to generalise this code, so we provide specific transformations, + # and raise for unforseen cases. + if new_field.type == typing.Tuple[cara.models._VentilationBase, ...]: + VB = getattr(sys.modules[__name__], "_VentilationBase") + field_type = typing.Tuple[typing.Union[cara.models._VentilationBase, VB], ...] + elif new_field.type == typing.Tuple[cara.models._ExpirationBase, ...]: + EB = getattr(sys.modules[__name__], "_ExpirationBase") + field_type = typing.Tuple[typing.Union[cara.models._ExpirationBase, EB], ...] + else: + # Check that we don't need to do anything with this type. + for item in new_field.type.__args__: + if getattr(item, '__module__', None) == 'cara.models': + raise ValueError( + f"unsupported type annotation transformation required for {new_field.type}") + elif field_type.__module__ == 'cara.models': + mc_model = getattr(sys.modules[__name__], new_field.type.__name__) + field_type = typing.Union[new_field.type, mc_model] + + fields.append((new_field.name, field_type, new_field)) + + bases = [] + # Update the inheritance/based to use the new MC classes, rather than the cara.models ones. + for model_base in model.__bases__: # type: ignore + if model_base is object: + bases.append(MCModelBase) + else: + mc_model = getattr(sys.modules[__name__], model_base.__name__) + bases.append(mc_model) + cls = dataclasses.make_dataclass( model.__name__, # type: ignore fields, # type: ignore - bases=(MCModelBase, ), + bases=bases, # type: ignore namespace={'_base_cls': model}, # This thing can be mutable - the calculations live on # the wrapped class, not on the MCModelBase.