From a3bd951d36e3176333eec8e87d9ac138ee682eb9 Mon Sep 17 00:00:00 2001 From: Nicolas Mounet Date: Thu, 27 May 2021 12:33:12 +0200 Subject: [PATCH 1/8] Adding tests for MultipleExpiration --- cara/tests/test_expiration.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 cara/tests/test_expiration.py diff --git a/cara/tests/test_expiration.py b/cara/tests/test_expiration.py new file mode 100644 index 00000000..7638d613 --- /dev/null +++ b/cara/tests/test_expiration.py @@ -0,0 +1,29 @@ +import re + +import numpy as np +import numpy.testing as npt +import pytest + +from cara import models + + +def test_multiple_wrong_weight_size(): + weights = (1., 2., 3.) + e_base = models.Expiration((0.084, 0.009, 0.003, 0.002)) + with pytest.raises( + ValueError, + match=re.escape("expirations and weigths should contain the" + "same number of elements") + ): + e = models.MultipleExpiration([e_base, e_base], weights) + + +def test_multiple(): + weights = (1., 2.) + e1 = models.Expiration((0.03, 0.02, 0.01, 0.005)) + e2 = models.Expiration((0.05, 0.04, 0.03, 0.01)) + e = models.MultipleExpiration([e1, e2], weights) + assert e.aerosols(models.Mask.types['No mask']) == ( + e1.aerosols(models.Mask.types['No mask'])/3. + + 2*e2.aerosols(models.Mask.types['No mask'])/3. + ) From 720bf1a56af54d7e5f952fde445e1ff063956298 Mon Sep 17 00:00:00 2001 From: Nicolas Mounet Date: Thu, 27 May 2021 12:34:14 +0200 Subject: [PATCH 2/8] Adapting tests for model_generator --- cara/tests/apps/calculator/test_model_generator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cara/tests/apps/calculator/test_model_generator.py b/cara/tests/apps/calculator/test_model_generator.py index 6c5050e4..eb216727 100644 --- a/cara/tests/apps/calculator/test_model_generator.py +++ b/cara/tests/apps/calculator/test_model_generator.py @@ -8,6 +8,7 @@ from cara.apps.calculator.model_generator import minutes_since_midnight from cara import models from cara import data import numpy as np +import numpy.testing as npt def test_model_from_dict(baseline_form_data): @@ -24,10 +25,11 @@ def test_model_from_dict_invalid(baseline_form_data): def test_blend_expiration(): blend = {'Breathing': 2, 'Talking': 1} r = model_generator.build_expiration(blend) + mask = models.Mask.types['Type I'] expected = models.Expiration( (0.13466666666666668, 0.02866666666666667, 0.004333333333333334, 0.005) ) - assert r == expected + npt.assert_almost_equal(r.aerosols(mask), expected.aerosols(mask)) def test_ventilation_slidingwindow(baseline_form: model_generator.FormData): From b8422aaf1bf805697e0a7abcf2e671a73f54264d Mon Sep 17 00:00:00 2001 From: Nicolas Mounet Date: Thu, 27 May 2021 13:40:46 +0200 Subject: [PATCH 3/8] Introducing _ExpirationBase and MultipleExpiration classes; adapting tests and model_generator accordingly (removing now obsolete expiration_blend function) --- cara/apps/calculator/model_generator.py | 32 +++------------ cara/models.py | 53 +++++++++++++++++++++---- 2 files changed, 51 insertions(+), 34 deletions(-) diff --git a/cara/apps/calculator/model_generator.py b/cara/apps/calculator/model_generator.py index fb34392d..ad65795e 100644 --- a/cara/apps/calculator/model_generator.py +++ b/cara/apps/calculator/model_generator.py @@ -537,38 +537,16 @@ class FormData: ) -def build_expiration(expiration_definition) -> models.Expiration: +def build_expiration(expiration_definition) -> models._ExpirationBase: if isinstance(expiration_definition, str): - return models.Expiration.types[expiration_definition] + return models._ExpirationBase.types[expiration_definition] elif isinstance(expiration_definition, dict): - return expiration_blend({ - build_expiration(exp): amount - for exp, amount in expiration_definition.items() - } + return models.MultipleExpiration( + tuple([build_expiration(exp) for exp in expiration_definition.keys()]), + tuple(expiration_definition.values()) ) -def expiration_blend(expiration_weights: typing.Dict[models.Expiration, int]) -> models.Expiration: - """ - Combine together multiple types of Expiration, using a weighted mean to - compute their ejection factor and particle sizes. - - """ - ejection_factor = np.zeros(4) - particle_sizes = np.zeros(4) - - total_weight = 0 - for expiration, weight in expiration_weights.items(): - total_weight += weight - ejection_factor += np.array(expiration.ejection_factor) * weight - particle_sizes += np.array(expiration.particle_sizes) * weight - - r_ejection_factor: typing.Tuple[float, float, float, float] = tuple(ejection_factor/total_weight) # type: ignore - r_particle_sizes: typing.Tuple[float, float, float, float] = tuple(particle_sizes/total_weight) # type: ignore - - return models.Expiration(ejection_factor=r_ejection_factor, particle_sizes=r_particle_sizes) - - def model_from_form(form: FormData) -> models.ExposureModel: # Initializes room with volume either given directly or as product of area and height if form.volume_type == 'room_volume_explicit': diff --git a/cara/models.py b/cara/models.py index 0943bd87..ecbccad3 100644 --- a/cara/models.py +++ b/cara/models.py @@ -511,14 +511,29 @@ Mask.types = { @dataclass(frozen=True) -class Expiration: +class _ExpirationBase: + """ + Represents the expiration of aerosols by a person. + Subclasses of _ExpirationBase represent different models. + """ + #: Pre-populated examples of Masks. + types: typing.ClassVar[typing.Dict[str, "_ExpirationBase"]] + + def aerosols(self, mask: _MaskBase): + # total volume of aerosols expired (cm^3). + raise NotImplementedError("Subclass must implement") + + +@dataclass(frozen=True) +class Expiration(_ExpirationBase): + """ + Simple model based on four different sizes of particles emitted, + with different ejection factors. + """ ejection_factor: typing.Tuple[float, ...] particle_sizes: typing.Tuple[float, ...] = (0.8e-4, 1.8e-4, 3.5e-4, 5.5e-4) # In cm. - #: Pre-populated examples of Expiration. - types: typing.ClassVar[typing.Dict[str, "Expiration"]] - - def aerosols(self, mask: Mask): + def aerosols(self, mask: _MaskBase): def volume(diameter): return (4 * np.pi * (diameter/2)**3) / 3 total = 0 @@ -529,7 +544,31 @@ class Expiration: return total -Expiration.types = { +@dataclass(frozen=True) +class MultipleExpiration(_ExpirationBase): + """ + Represents an expiration of aerosols. + Group together different modes of expiration, that represent + each the main expiration mode for a certain fraction of time (given by + the weights). + + """ + expirations: typing.Tuple[_ExpirationBase, ...] + weights: typing.Tuple[float, ...] + + def __post_init__(self): + if len(self.expirations) != len(self.weights): + raise ValueError("expirations and weigths should contain the" + "same number of elements") + + def aerosols(self, mask: _MaskBase): + return np.array([ + weight * expiration.aerosols(mask) / sum(self.weights) + for weight,expiration in zip(self.weights,self.expirations) + ]).sum(axis=0) + + +_ExpirationBase.types = { 'Breathing': Expiration((0.084, 0.009, 0.003, 0.002)), 'Whispering': Expiration((0.11, 0.014, 0.004, 0.002)), 'Talking': Expiration((0.236, 0.068, 0.007, 0.011)), @@ -585,7 +624,7 @@ class InfectedPopulation(Population): virus: Virus #: The type of expiration that is being emitted whilst doing the activity. - expiration: Expiration + expiration: _ExpirationBase def emission_rate_when_present(self) -> _VectorisedFloat: """ From f050214237a94446cc84053bb1069d2eaf500da4 Mon Sep 17 00:00:00 2001 From: Nicolas Mounet Date: Fri, 28 May 2021 06:56:13 +0200 Subject: [PATCH 4/8] Improving docstrings in expiration classes --- cara/models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cara/models.py b/cara/models.py index ecbccad3..3f7e629a 100644 --- a/cara/models.py +++ b/cara/models.py @@ -520,7 +520,7 @@ class _ExpirationBase: types: typing.ClassVar[typing.Dict[str, "_ExpirationBase"]] def aerosols(self, mask: _MaskBase): - # total volume of aerosols expired (cm^3). + # total volume of aerosols expired per volume of air (mL/cm^3). raise NotImplementedError("Subclass must implement") @@ -528,7 +528,10 @@ class _ExpirationBase: class Expiration(_ExpirationBase): """ Simple model based on four different sizes of particles emitted, - with different ejection factors. + with different ejection factors. See Fig. 4 in L. Morawska et al, + Size distribution and sites of origin of droplets expelled from the + human respiratory tract during expiratory activities, + Aerosol Science 40 (2009) pp. 256 - 269. """ ejection_factor: typing.Tuple[float, ...] particle_sizes: typing.Tuple[float, ...] = (0.8e-4, 1.8e-4, 3.5e-4, 5.5e-4) # In cm. From 8ac3b4cbd966bc091c2ef70acce16adc436bf4e4 Mon Sep 17 00:00:00 2001 From: Phil Elson Date: Fri, 28 May 2021 10:50:33 +0200 Subject: [PATCH 5/8] Avoid the use of the updates from #184. --- cara/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cara/models.py b/cara/models.py index 3f7e629a..c2500d68 100644 --- a/cara/models.py +++ b/cara/models.py @@ -519,7 +519,7 @@ class _ExpirationBase: #: Pre-populated examples of Masks. types: typing.ClassVar[typing.Dict[str, "_ExpirationBase"]] - def aerosols(self, mask: _MaskBase): + def aerosols(self, mask: Mask): # total volume of aerosols expired per volume of air (mL/cm^3). raise NotImplementedError("Subclass must implement") @@ -536,7 +536,7 @@ class Expiration(_ExpirationBase): ejection_factor: typing.Tuple[float, ...] particle_sizes: typing.Tuple[float, ...] = (0.8e-4, 1.8e-4, 3.5e-4, 5.5e-4) # In cm. - def aerosols(self, mask: _MaskBase): + def aerosols(self, mask: Mask): def volume(diameter): return (4 * np.pi * (diameter/2)**3) / 3 total = 0 @@ -564,7 +564,7 @@ class MultipleExpiration(_ExpirationBase): raise ValueError("expirations and weigths should contain the" "same number of elements") - def aerosols(self, mask: _MaskBase): + def aerosols(self, mask: Mask): return np.array([ weight * expiration.aerosols(mask) / sum(self.weights) for weight,expiration in zip(self.weights,self.expirations) From 9669e5afd0f331c80188a1bb3ec5c2b1d58622b0 Mon Sep 17 00:00:00 2001 From: Phil Elson Date: Fri, 16 Apr 2021 07:47:31 +0200 Subject: [PATCH 6/8] Add a cara.monte_carlo submodule as syntactic sugar on top of the existing cara.models vectorisation. This allows us to define SampleableDistributions for key variables (and in the future, good default values for these), as well as giving us an exact mirror of the non-MC models which we can ultimately generate those models. --- cara/monte_carlo/__init__.py | 1 + cara/monte_carlo/__init__.pyi | 4 ++ cara/monte_carlo/models.py | 85 ++++++++++++++++++++++++++++++++++ cara/monte_carlo/sampleable.py | 29 ++++++++++++ cara/tests/test_monte_carlo.py | 83 +++++++++++++++++++++++++++++++++ 5 files changed, 202 insertions(+) create mode 100644 cara/monte_carlo/__init__.py create mode 100644 cara/monte_carlo/__init__.pyi create mode 100644 cara/monte_carlo/models.py create mode 100644 cara/monte_carlo/sampleable.py create mode 100644 cara/tests/test_monte_carlo.py diff --git a/cara/monte_carlo/__init__.py b/cara/monte_carlo/__init__.py new file mode 100644 index 00000000..aed4fa32 --- /dev/null +++ b/cara/monte_carlo/__init__.py @@ -0,0 +1 @@ +from .models import * diff --git a/cara/monte_carlo/__init__.pyi b/cara/monte_carlo/__init__.pyi new file mode 100644 index 00000000..5a184e17 --- /dev/null +++ b/cara/monte_carlo/__init__.pyi @@ -0,0 +1,4 @@ +from typing import Any + +# For now we disable all type-checking in the monte-carlo submodule. +def __getattr__(name) -> Any: ... diff --git a/cara/monte_carlo/models.py b/cara/monte_carlo/models.py new file mode 100644 index 00000000..acdc1dad --- /dev/null +++ b/cara/monte_carlo/models.py @@ -0,0 +1,85 @@ +import copy +import dataclasses +import sys +import typing + +import cara.models + +from .sampleable import SampleableDistribution, _VectorisedFloatOrSampleable + + +_ModelType = typing.TypeVar('_ModelType') + + +class MCModelBase(typing.Generic[_ModelType]): + """ + A model base class for monte carlo types. + + This base class is essentially a declarative description of a cara.models + model with a :meth:`.build_model` method to generate an appropriate + ``cara.models` model instance on demand. + + """ + _base_cls: typing.Type[_ModelType] + + def build_model(self, size: int) -> _ModelType: + """ + Turn this MCModelBase subclass into a cara.models Model instance + from which you can then run the model. + + """ + kwargs = {} + for field in dataclasses.fields(self._base_cls): + attr = getattr(self, field.name) + if isinstance(attr, SampleableDistribution): + attr = attr.generate_samples(size) + elif isinstance(attr, MCModelBase): + # Recurse into other MCModelBase instances by calling their + # build_model method. + attr = attr.build_model(size) + kwargs[field.name] = attr + return self._base_cls(**kwargs) # type: ignore + + +def _build_mc_model(model: _ModelType) -> typing.Type[MCModelBase[_ModelType]]: + """ + Generate a new MCModelBase subclass for the given cara.models model. + + """ + fields = [] + for field in dataclasses.fields(model): + # Note: deepcopy not needed here as we aren't mutating entities beyond + # the top level. + new_field = copy.copy(field) + if field.type is cara.models._VectorisedFloat: # noqa + new_field.type = _VectorisedFloatOrSampleable # type: ignore + # TODO: Update the type annotation to support the new model classes that exist. + fields.append((new_field.name, new_field.type, new_field)) + cls = dataclasses.make_dataclass( + model.__name__, # type: ignore + fields, # type: ignore + bases=(MCModelBase, ), + namespace={'_base_cls': model}, + # This thing can be mutable - the calculations live on + # the wrapped class, not on the MCModelBase. + frozen=False, + ) + # Update the module of the generated class to be this one. Without this the + # module will be "types". + cls.__module__ = __name__ + return cls + + +_MODEL_CLASSES = [ + cls for cls in vars(cara.models).values() + if dataclasses.is_dataclass(cls) +] + + +# Inject the runtime generated MC types into this module. +for _model in _MODEL_CLASSES: + setattr(sys.modules[__name__], _model.__name__, _build_mc_model(_model)) + + +# Make sure that each of the models is imported if you do a ``import *``. +__all__ = [_model.__name__ for _model in _MODEL_CLASSES] + ["MCModelBase"] diff --git a/cara/monte_carlo/sampleable.py b/cara/monte_carlo/sampleable.py new file mode 100644 index 00000000..6f53c43c --- /dev/null +++ b/cara/monte_carlo/sampleable.py @@ -0,0 +1,29 @@ +import typing + +import numpy as np + +import cara.models + + +# Declare a float array type of a given size. +# There is no better way to declare this currently, unfortunately. +float_array_size_n = np.ndarray + + +class SampleableDistribution: + def generate_samples(self, size: int) -> float_array_size_n: + raise NotImplementedError() + + +class Normal(SampleableDistribution): + def __init__(self, mean: float, scale: float): + self.mean = mean + self.scale = scale + + def generate_samples(self, size: int) -> float_array_size_n: + return np.random.normal(self.mean, self.scale, size=size) + + +_VectorisedFloatOrSampleable = typing.Union[ + SampleableDistribution, cara.models._VectorisedFloat, +] diff --git a/cara/tests/test_monte_carlo.py b/cara/tests/test_monte_carlo.py new file mode 100644 index 00000000..28f7dc28 --- /dev/null +++ b/cara/tests/test_monte_carlo.py @@ -0,0 +1,83 @@ +import dataclasses + +import pytest + +import cara.models +import cara.monte_carlo.models as mc_models +import cara.monte_carlo.sampleable + + +MODEL_CLASSES = [ + cls for cls in vars(cara.models).values() + if dataclasses.is_dataclass(cls) +] + + +def test_type_annotations(): + # Check that there are appropriate type annotations for all of the model + # classes in cara.models. Note that these must be statically defined in + # cara.monte_carlo, rather than being dynamically generated, in order to + # allow the type system to be able to see their definition without needing + # runtime execution. + missing = [] + for cls in MODEL_CLASSES: + if not hasattr(cara.monte_carlo, cls.__name__): + missing.append(cls.__name__) + continue + mc_cls = getattr(cara.monte_carlo, cls.__name__) + assert issubclass(mc_cls, cara.monte_carlo.MCModelBase) + + if missing: + msg = ( + 'There are missing model implementations in cara.monte_carlo. ' + 'The following definitions are needed:\n ' + + '\n '.join([f'{model} = build_mc_model(cara.models.{model})' for model in missing]) + ) + pytest.fail(msg) + + +@pytest.fixture +def baseline_mc_model() -> cara.monte_carlo.ConcentrationModel: + mc_model = cara.monte_carlo.ConcentrationModel( + room=cara.monte_carlo.Room(volume=cara.monte_carlo.sampleable.Normal(75, 20)), + ventilation=cara.monte_carlo.SlidingWindow( + active=cara.models.PeriodicInterval(period=120, duration=120), + inside_temp=cara.models.PiecewiseConstant((0, 24), (293,)), + outside_temp=cara.models.PiecewiseConstant((0, 24), (283,)), + window_height=1.6, opening_length=0.6, + ), + infected=cara.models.InfectedPopulation( + number=1, + virus=cara.models.Virus.types['SARS_CoV_2'], + presence=cara.models.SpecificInterval(((0, 4), (5, 8))), + mask=cara.models.Mask.types['No mask'], + activity=cara.models.Activity.types['Light activity'], + expiration=cara.models.Expiration.types['Unmodulated Vocalization'], + ), + ) + return mc_model + + +@pytest.fixture +def baseline_mc_exposure_model(baseline_mc_model) -> cara.monte_carlo.ExposureModel: + return cara.monte_carlo.ExposureModel( + baseline_mc_model, + exposed=cara.models.Population( + number=10, + presence=baseline_mc_model.infected.presence, + activity=baseline_mc_model.infected.activity, + mask=baseline_mc_model.infected.mask, + ) + ) + + +def test_build_concentration_model(baseline_mc_model: cara.monte_carlo.ConcentrationModel): + model = baseline_mc_model.build_model(7) + assert isinstance(model, cara.models.ConcentrationModel) + assert isinstance(model.concentration(time=0), float) + assert model.concentration(time=1).shape == (7, ) + + +def test_build_exposure_model(baseline_mc_exposure_model: cara.monte_carlo.ExposureModel): + model = baseline_mc_exposure_model.build_model(7) + assert isinstance(model, cara.models.ExposureModel) From 38fe6e734e56168692479167960fc58d0c87c2d2 Mon Sep 17 00:00:00 2001 From: Phil Elson Date: Fri, 28 May 2021 17:23:43 +0200 Subject: [PATCH 7/8] Review actions for monte carlo models. --- cara/monte_carlo/sampleable.py | 6 +++--- cara/tests/test_monte_carlo.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/cara/monte_carlo/sampleable.py b/cara/monte_carlo/sampleable.py index 6f53c43c..4ed49d82 100644 --- a/cara/monte_carlo/sampleable.py +++ b/cara/monte_carlo/sampleable.py @@ -16,12 +16,12 @@ class SampleableDistribution: class Normal(SampleableDistribution): - def __init__(self, mean: float, scale: float): + def __init__(self, mean: float, standard_deviation: float): self.mean = mean - self.scale = scale + self.standard_deviation = standard_deviation def generate_samples(self, size: int) -> float_array_size_n: - return np.random.normal(self.mean, self.scale, size=size) + return np.random.normal(self.mean, self.standard_deviation, size=size) _VectorisedFloatOrSampleable = typing.Union[ diff --git a/cara/tests/test_monte_carlo.py b/cara/tests/test_monte_carlo.py index 28f7dc28..f3d9fb08 100644 --- a/cara/tests/test_monte_carlo.py +++ b/cara/tests/test_monte_carlo.py @@ -1,5 +1,6 @@ import dataclasses +import numpy as np import pytest import cara.models @@ -81,3 +82,6 @@ def test_build_concentration_model(baseline_mc_model: cara.monte_carlo.Concentra def test_build_exposure_model(baseline_mc_exposure_model: cara.monte_carlo.ExposureModel): model = baseline_mc_exposure_model.build_model(7) assert isinstance(model, cara.models.ExposureModel) + prob = model.quanta_exposure() + assert isinstance(prob, np.ndarray) + assert prob.shape == (7, ) From 604422fbb5fcc424a4f490f79d28d0f2733e0160 Mon Sep 17 00:00:00 2001 From: Phil Elson Date: Fri, 28 May 2021 17:34:47 +0200 Subject: [PATCH 8/8] Improve the type handling of the MC model generation. This is tested more thoroughly later when generating type stubs. --- cara/monte_carlo/models.py | 58 +++++++++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/cara/monte_carlo/models.py b/cara/monte_carlo/models.py index acdc1dad..7348e3be 100644 --- a/cara/monte_carlo/models.py +++ b/cara/monte_carlo/models.py @@ -22,6 +22,19 @@ class MCModelBase(typing.Generic[_ModelType]): """ _base_cls: typing.Type[_ModelType] + @classmethod + def _to_vectorized_form(cls, item, size): + if isinstance(item, SampleableDistribution): + return item.generate_samples(size) + elif isinstance(item, MCModelBase): + # Recurse into other MCModelBase instances by calling their + # build_model method. + return item.build_model(size) + elif isinstance(item, tuple): + return tuple(cls._to_vectorized_form(sub, size) for sub in item) + else: + return item + def build_model(self, size: int) -> _ModelType: """ Turn this MCModelBase subclass into a cara.models Model instance @@ -31,13 +44,7 @@ class MCModelBase(typing.Generic[_ModelType]): kwargs = {} for field in dataclasses.fields(self._base_cls): attr = getattr(self, field.name) - if isinstance(attr, SampleableDistribution): - attr = attr.generate_samples(size) - elif isinstance(attr, MCModelBase): - # Recurse into other MCModelBase instances by calling their - # build_model method. - attr = attr.build_model(size) - kwargs[field.name] = attr + kwargs[field.name] = self._to_vectorized_form(attr, size) return self._base_cls(**kwargs) # type: ignore @@ -53,12 +60,43 @@ def _build_mc_model(model: _ModelType) -> typing.Type[MCModelBase[_ModelType]]: new_field = copy.copy(field) if field.type is cara.models._VectorisedFloat: # noqa new_field.type = _VectorisedFloatOrSampleable # type: ignore - # TODO: Update the type annotation to support the new model classes that exist. - fields.append((new_field.name, new_field.type, new_field)) + + field_type: typing.Any = new_field.type + + if getattr(field_type, '__origin__', None) in [typing.Union, typing.Tuple]: + # It is challenging to generalise this code, so we provide specific transformations, + # and raise for unforseen cases. + if new_field.type == typing.Tuple[cara.models._VentilationBase, ...]: + VB = getattr(sys.modules[__name__], "_VentilationBase") + field_type = typing.Tuple[typing.Union[cara.models._VentilationBase, VB], ...] + elif new_field.type == typing.Tuple[cara.models._ExpirationBase, ...]: + EB = getattr(sys.modules[__name__], "_ExpirationBase") + field_type = typing.Tuple[typing.Union[cara.models._ExpirationBase, EB], ...] + else: + # Check that we don't need to do anything with this type. + for item in new_field.type.__args__: + if getattr(item, '__module__', None) == 'cara.models': + raise ValueError( + f"unsupported type annotation transformation required for {new_field.type}") + elif field_type.__module__ == 'cara.models': + mc_model = getattr(sys.modules[__name__], new_field.type.__name__) + field_type = typing.Union[new_field.type, mc_model] + + fields.append((new_field.name, field_type, new_field)) + + bases = [] + # Update the inheritance/based to use the new MC classes, rather than the cara.models ones. + for model_base in model.__bases__: # type: ignore + if model_base is object: + bases.append(MCModelBase) + else: + mc_model = getattr(sys.modules[__name__], model_base.__name__) + bases.append(mc_model) + cls = dataclasses.make_dataclass( model.__name__, # type: ignore fields, # type: ignore - bases=(MCModelBase, ), + bases=bases, # type: ignore namespace={'_base_cls': model}, # This thing can be mutable - the calculations live on # the wrapped class, not on the MCModelBase.