Improve the type handling of the MC model generation. This is tested more thoroughly later when generating type stubs.
This commit is contained in:
parent
38fe6e734e
commit
604422fbb5
1 changed files with 48 additions and 10 deletions
|
|
@ -22,6 +22,19 @@ class MCModelBase(typing.Generic[_ModelType]):
|
|||
"""
|
||||
_base_cls: typing.Type[_ModelType]
|
||||
|
||||
@classmethod
|
||||
def _to_vectorized_form(cls, item, size):
|
||||
if isinstance(item, SampleableDistribution):
|
||||
return item.generate_samples(size)
|
||||
elif isinstance(item, MCModelBase):
|
||||
# Recurse into other MCModelBase instances by calling their
|
||||
# build_model method.
|
||||
return item.build_model(size)
|
||||
elif isinstance(item, tuple):
|
||||
return tuple(cls._to_vectorized_form(sub, size) for sub in item)
|
||||
else:
|
||||
return item
|
||||
|
||||
def build_model(self, size: int) -> _ModelType:
|
||||
"""
|
||||
Turn this MCModelBase subclass into a cara.models Model instance
|
||||
|
|
@ -31,13 +44,7 @@ class MCModelBase(typing.Generic[_ModelType]):
|
|||
kwargs = {}
|
||||
for field in dataclasses.fields(self._base_cls):
|
||||
attr = getattr(self, field.name)
|
||||
if isinstance(attr, SampleableDistribution):
|
||||
attr = attr.generate_samples(size)
|
||||
elif isinstance(attr, MCModelBase):
|
||||
# Recurse into other MCModelBase instances by calling their
|
||||
# build_model method.
|
||||
attr = attr.build_model(size)
|
||||
kwargs[field.name] = attr
|
||||
kwargs[field.name] = self._to_vectorized_form(attr, size)
|
||||
return self._base_cls(**kwargs) # type: ignore
|
||||
|
||||
|
||||
|
|
@ -53,12 +60,43 @@ def _build_mc_model(model: _ModelType) -> typing.Type[MCModelBase[_ModelType]]:
|
|||
new_field = copy.copy(field)
|
||||
if field.type is cara.models._VectorisedFloat: # noqa
|
||||
new_field.type = _VectorisedFloatOrSampleable # type: ignore
|
||||
# TODO: Update the type annotation to support the new model classes that exist.
|
||||
fields.append((new_field.name, new_field.type, new_field))
|
||||
|
||||
field_type: typing.Any = new_field.type
|
||||
|
||||
if getattr(field_type, '__origin__', None) in [typing.Union, typing.Tuple]:
|
||||
# It is challenging to generalise this code, so we provide specific transformations,
|
||||
# and raise for unforseen cases.
|
||||
if new_field.type == typing.Tuple[cara.models._VentilationBase, ...]:
|
||||
VB = getattr(sys.modules[__name__], "_VentilationBase")
|
||||
field_type = typing.Tuple[typing.Union[cara.models._VentilationBase, VB], ...]
|
||||
elif new_field.type == typing.Tuple[cara.models._ExpirationBase, ...]:
|
||||
EB = getattr(sys.modules[__name__], "_ExpirationBase")
|
||||
field_type = typing.Tuple[typing.Union[cara.models._ExpirationBase, EB], ...]
|
||||
else:
|
||||
# Check that we don't need to do anything with this type.
|
||||
for item in new_field.type.__args__:
|
||||
if getattr(item, '__module__', None) == 'cara.models':
|
||||
raise ValueError(
|
||||
f"unsupported type annotation transformation required for {new_field.type}")
|
||||
elif field_type.__module__ == 'cara.models':
|
||||
mc_model = getattr(sys.modules[__name__], new_field.type.__name__)
|
||||
field_type = typing.Union[new_field.type, mc_model]
|
||||
|
||||
fields.append((new_field.name, field_type, new_field))
|
||||
|
||||
bases = []
|
||||
# Update the inheritance/based to use the new MC classes, rather than the cara.models ones.
|
||||
for model_base in model.__bases__: # type: ignore
|
||||
if model_base is object:
|
||||
bases.append(MCModelBase)
|
||||
else:
|
||||
mc_model = getattr(sys.modules[__name__], model_base.__name__)
|
||||
bases.append(mc_model)
|
||||
|
||||
cls = dataclasses.make_dataclass(
|
||||
model.__name__, # type: ignore
|
||||
fields, # type: ignore
|
||||
bases=(MCModelBase, ),
|
||||
bases=bases, # type: ignore
|
||||
namespace={'_base_cls': model},
|
||||
# This thing can be mutable - the calculations live on
|
||||
# the wrapped class, not on the MCModelBase.
|
||||
|
|
|
|||
Loading…
Reference in a new issue