finished injection of data_service param

This commit is contained in:
Luis Aleixo 2023-12-12 11:58:19 +01:00
parent 20e8bf1df7
commit d79ef934cb
13 changed files with 27 additions and 21 deletions

View file

@ -2,7 +2,6 @@ import dataclasses
import logging
import typing
import numpy as np
from caimira.store.data_registry import DataRegistry
import ruptures as rpt
import matplotlib.pyplot as plt
import re
@ -177,6 +176,7 @@ class CO2FormData(FormData):
for _, stop in zip(all_state_changes[:-1], all_state_changes[1:])]
return models.CO2DataModel(
data_registry=self.data_registry,
room_volume=self.room_volume,
number=models.IntPiecewiseConstant(transition_times=tuple(all_state_changes), values=tuple(total_people)),
presence=None,

View file

@ -212,10 +212,10 @@ class VirusFormData(FormData):
for interaction in self.short_range_interactions:
short_range.append(mc.ShortRangeModel(
data_registry=self.data_registry,
expiration=short_range_expiration_distributions[interaction['expiration']],
expiration=short_range_expiration_distributions(self.data_registry)[interaction['expiration']],
activity=infected_population.activity,
presence=self.short_range_interval(interaction),
distance=short_range_distances,
distance=short_range_distances(self.data_registry),
))
return mc.ExposureModel(
@ -473,7 +473,7 @@ class VirusFormData(FormData):
def build_expiration(data_registry, expiration_definition) -> mc._ExpirationBase:
if isinstance(expiration_definition, str):
return expiration_distributions[expiration_definition]
return expiration_distributions(data_registry)[expiration_definition]
elif isinstance(expiration_definition, dict):
total_weight = sum(expiration_definition.values())
BLO_factors = np.sum([

View file

@ -840,6 +840,7 @@ class ModelWidgets(View):
def baseline_model(data_registry: DataRegistry):
return models.ExposureModel(
data_registry=data_registry,
concentration_model=models.ConcentrationModel(
data_registry=data_registry,
room=models.Room(volume=75, inside_temp=models.PiecewiseConstant((0., 24.), (293.15,))),

View file

@ -22,8 +22,6 @@ def baseline_model(data_registry: DataRegistry):
presence=models.SpecificInterval(((8., 12.), (13., 17.))),
activity=models.Activity.types['Seated'],
),
CO2_atmosphere_concentration=440.44,
CO2_fraction_exhaled=0.042,
)

View file

@ -344,7 +344,7 @@ class SlidingWindow(WindowOpening):
Sliding window, or side-hung window (with the hinge perpendicular to
the horizontal plane).
"""
data_registry: DataRegistry = None
data_registry: DataRegistry = DataRegistry()
@property
def discharge_coefficient(self) -> _VectorisedFloat:
@ -1507,6 +1507,7 @@ class CO2DataModel:
It uses optimization techniques to fit the model's parameters and estimate the exhalation rate and ventilation
values that best match the measured CO2 concentrations.
'''
data_registry: DataRegistry
room_volume: float
number: typing.Union[int, IntPiecewiseConstant]
presence: typing.Optional[Interval]
@ -1518,6 +1519,7 @@ class CO2DataModel:
exhalation_rate: float,
ventilation_values: typing.Tuple[float, ...]) -> typing.List[_VectorisedFloat]:
CO2_concentrations = CO2ConcentrationModel(
data_registry=self.data_registry,
room=Room(volume=self.room_volume),
ventilation=CustomVentilation(PiecewiseConstant(
self.ventilation_transition_times, ventilation_values)),

View file

@ -389,7 +389,7 @@ def expiration_distribution(
BLO_factors,
d_min=0.1,
d_max=30.,
) -> mc.Expiration:
):
"""
Returns an Expiration with an aerosol diameter distribution, defined
by the BLO factors (a length-3 tuple).

View file

@ -35,14 +35,14 @@ def test_model_from_dict_invalid(baseline_form_data, data_registry):
["Cloth"],
]
)
def test_blend_expiration(mask_type):
def test_blend_expiration(data_registry, mask_type):
SAMPLE_SIZE = 250000
TOLERANCE = 0.02
blend = {'Breathing': 2, 'Speaking': 1}
r = model_generator.build_expiration(blend).build_model(SAMPLE_SIZE)
r = model_generator.build_expiration(data_registry, blend).build_model(SAMPLE_SIZE)
mask = models.Mask.types[mask_type]
expected = (expiration_distributions['Breathing'].build_model(SAMPLE_SIZE).aerosols(mask).mean()*2/3. +
expiration_distributions['Speaking'].build_model(SAMPLE_SIZE).aerosols(mask).mean()/3.)
expected = (expiration_distributions(data_registry)['Breathing'].build_model(SAMPLE_SIZE).aerosols(mask).mean()*2/3. +
expiration_distributions(data_registry)['Speaking'].build_model(SAMPLE_SIZE).aerosols(mask).mean()/3.)
npt.assert_allclose(r.aerosols(mask).mean(), expected, rtol=TOLERANCE)
@ -555,6 +555,8 @@ def test_default_types():
raise TypeError(f'{field} has type {field_type}, got {type(value)}')
for field in fields.values():
if field.name == "data_registry":
continue # Skip the assertion for the "data_registry" field
assert field.name in model_generator.VirusFormData._DEFAULTS, f"No default set for field name {field.name}"

View file

@ -60,7 +60,7 @@ def baseline_exposure_model(data_registry, baseline_concentration_model, baselin
@pytest.fixture
def exposure_model_w_outside_temp_changes(baseline_exposure_model: models.ExposureModel):
def exposure_model_w_outside_temp_changes(data_registry, baseline_exposure_model: models.ExposureModel):
exp_model = caimira.dataclass_utils.nested_replace(
baseline_exposure_model, {
'concentration_model.ventilation': models.SlidingWindow(

View file

@ -39,6 +39,7 @@ def test_fitting_algorithm(data_registry, activity_type, ventilation_active, air
# Generate CO2DataModel
data_model = models.CO2DataModel(
data_registry=data_registry,
room_volume=75,
number=models.IntPiecewiseConstant(transition_times=tuple(
[8, 12, 13, 17]), values=tuple([2, 1, 2])),

View file

@ -47,7 +47,7 @@ def baseline_exposure_model(data_registry):
@retry(tries=3)
def test_conditional_prob_inf_given_vl_dist(baseline_exposure_model):
def test_conditional_prob_inf_given_vl_dist(data_registry, baseline_exposure_model):
viral_loads = np.array([3., 5., 7., 9.,])
mc_model: models.ExposureModel = baseline_exposure_model.build_model(2_000_000)
@ -72,7 +72,7 @@ def test_conditional_prob_inf_given_vl_dist(baseline_exposure_model):
specific_vl = np.log10(mc_model.concentration_model.infected.virus.viral_load_in_sputum)
step = 8/100
actual_pi_means, actual_lower_percentiles, actual_upper_percentiles = (
report_generator.conditional_prob_inf_given_vl_dist(infection_probability, viral_loads, specific_vl, step)
report_generator.conditional_prob_inf_given_vl_dist(data_registry, infection_probability, viral_loads, specific_vl, step)
)
assert np.allclose(actual_pi_means, expected_pi_means, atol=0.002)

View file

@ -69,8 +69,9 @@ def baseline_mc_sr_model() -> caimira.monte_carlo.ShortRangeModel:
@pytest.fixture
def baseline_mc_exposure_model(baseline_mc_concentration_model, baseline_mc_sr_model) -> caimira.monte_carlo.ExposureModel:
def baseline_mc_exposure_model(data_registry, baseline_mc_concentration_model, baseline_mc_sr_model) -> caimira.monte_carlo.ExposureModel:
return caimira.monte_carlo.ExposureModel(
data_registry,
baseline_mc_concentration_model,
baseline_mc_sr_model,
exposed=caimira.models.Population(

View file

@ -186,8 +186,8 @@ def skagit_chorale_mc(data_registry):
presence=models.SpecificInterval(((0, 2.5), )),
virus=mc.SARSCoV2(
viral_load_in_sputum=10**9,
infectious_dose=infectious_dose_distribution,
viable_to_RNA_ratio=viable_to_RNA_ratio_distribution,
infectious_dose=infectious_dose_distribution(data_registry),
viable_to_RNA_ratio=viable_to_RNA_ratio_distribution(data_registry),
transmissibility_factor=1.,
),
mask=models.Mask.types['No mask'],
@ -230,8 +230,8 @@ def bus_ride_mc(data_registry):
presence=models.SpecificInterval(((0, 1.67), )),
virus=mc.SARSCoV2(
viral_load_in_sputum=5*10**8,
infectious_dose=infectious_dose_distribution,
viable_to_RNA_ratio=viable_to_RNA_ratio_distribution,
infectious_dose=infectious_dose_distribution(data_registry),
viable_to_RNA_ratio=viable_to_RNA_ratio_distribution(data_registry),
transmissibility_factor=1.,
),
mask=models.Mask.types['No mask'],
@ -411,6 +411,7 @@ def test_small_shared_office_Geneva(data_registry, mask_type, month, expected_pi
evaporation_factor=0.3,
)
exposure_mc = mc.ExposureModel(
data_registry=data_registry,
concentration_model=concentration_mc,
short_range=(),
exposed=mc.Population(

View file

@ -39,7 +39,7 @@ def test_activity_distributions(data_registry, distribution, mean, std):
['SARS_CoV_2_GAMMA', 6.22, 1.80],
]
)
def test_viral_load_logdistribution(distribution, mean, std):
def test_viral_load_logdistribution(data_registry, distribution, mean, std):
virus = virus_distributions(data_registry)[distribution].build_model(size=1000000)
npt.assert_allclose(np.log10(virus.viral_load_in_sputum).mean(), mean, atol=0.01)
npt.assert_allclose(np.log10(virus.viral_load_in_sputum).std(), std, atol=0.01)