From 910ed4f3db4112804ba15a3c5ca97376a929172b Mon Sep 17 00:00:00 2001 From: Phil Elson Date: Wed, 7 Apr 2021 10:25:00 +0200 Subject: [PATCH] Ensure that piecewise constant values all have the same shape. --- cara/models.py | 3 +++ cara/tests/models/test_piecewiseconstant.py | 7 +++++++ cara/tests/test_ventilation.py | 4 +++- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/cara/models.py b/cara/models.py index b5a32d47..3d9a730e 100644 --- a/cara/models.py +++ b/cara/models.py @@ -142,6 +142,9 @@ class PiecewiseConstant: raise ValueError("transition_times should contain one more element than values") if tuple(sorted(set(self.transition_times))) != self.transition_times: raise ValueError("transition_times should not contain duplicated elements and should be sorted") + shapes = [np.array(v).shape for v in self.values] + if not all(shapes[0] == shape for shape in shapes): + raise ValueError("All values must have the same shape") def value(self, time) -> _VectorisedFloat: if time <= self.transition_times[0]: diff --git a/cara/tests/models/test_piecewiseconstant.py b/cara/tests/models/test_piecewiseconstant.py index fabd6296..a0ba14f0 100644 --- a/cara/tests/models/test_piecewiseconstant.py +++ b/cara/tests/models/test_piecewiseconstant.py @@ -14,6 +14,13 @@ def test_piecewiseconstantfunction_wrongarguments(): # unsorted transition times are not allowed pytest.raises(ValueError, models.PiecewiseConstant, (2, 0), (0, 0)) + # If vectors, must all be same length. + with pytest.raises(ValueError, match="All values must have the same shape"): + models.PiecewiseConstant( + (0, 8, 16), (np.array([5, 7]), np.array([8, 9, 10])), + ) + + @pytest.mark.parametrize( "time, expected_value", diff --git a/cara/tests/test_ventilation.py b/cara/tests/test_ventilation.py index 73bb86f9..624203ef 100644 --- a/cara/tests/test_ventilation.py +++ b/cara/tests/test_ventilation.py @@ -60,7 +60,9 @@ def test_hinged_window(baseline_hingedwindow, window_width, {'window_height': np.array([0.15, 0.20])}, {'window_width': np.array([0.15, 0.20])}, {'opening_length': np.array([0.15, 0.20])}, - {'outside_temp': models.PiecewiseConstant((1, 2, 3), (np.array([20, 30]), 25))}, + {'outside_temp': models.PiecewiseConstant( + (1, 2, 3), (np.array([20, 30]), np.array([25, 30]))), + }, {'outside_temp': np.array([20, 30])}, {'inside_temp': np.array([20, 30])}, ]