diff --git a/cara/apps/calculator/model_generator.py b/cara/apps/calculator/model_generator.py index 2e10b1f2..c6eeeeae 100644 --- a/cara/apps/calculator/model_generator.py +++ b/cara/apps/calculator/model_generator.py @@ -185,6 +185,10 @@ class FormData: return form_dict def validate(self): + # Validate number of infected <= number of total people + if self.infected_people > self.total_people: + raise ValueError('Number of infected people should be less than number of total people.') + # Validate time intervals selected by user time_intervals = [ ['exposed_start', 'exposed_finish'], diff --git a/cara/tests/apps/calculator/test_model_generator.py b/cara/tests/apps/calculator/test_model_generator.py index c1d1f2c2..534d038d 100644 --- a/cara/tests/apps/calculator/test_model_generator.py +++ b/cara/tests/apps/calculator/test_model_generator.py @@ -167,6 +167,13 @@ def test_ventilation_window_hepa(baseline_form: model_generator.FormData): assert ventilation == baseline_vent +def test_infected_less_than_total_people(baseline_form: model_generator.FormData): + baseline_form.total_people = 10 + baseline_form.infected_people = 11 + with pytest.raises(ValueError, match='Number of infected people should be less than number of total people.'): + baseline_form.validate() + + def present_times(interval: models.Interval) -> models.BoundarySequence_t: assert isinstance(interval, models.SpecificInterval) return interval.present_times