Ensure that piecewise constant values all have the same shape.

This commit is contained in:
Phil Elson 2021-04-07 10:25:00 +02:00
parent eccc57702d
commit 910ed4f3db
3 changed files with 13 additions and 1 deletions

View file

@ -142,6 +142,9 @@ class PiecewiseConstant:
raise ValueError("transition_times should contain one more element than values")
if tuple(sorted(set(self.transition_times))) != self.transition_times:
raise ValueError("transition_times should not contain duplicated elements and should be sorted")
shapes = [np.array(v).shape for v in self.values]
if not all(shapes[0] == shape for shape in shapes):
raise ValueError("All values must have the same shape")
def value(self, time) -> _VectorisedFloat:
if time <= self.transition_times[0]:

View file

@ -14,6 +14,13 @@ def test_piecewiseconstantfunction_wrongarguments():
# unsorted transition times are not allowed
pytest.raises(ValueError, models.PiecewiseConstant, (2, 0), (0, 0))
# If vectors, must all be same length.
with pytest.raises(ValueError, match="All values must have the same shape"):
models.PiecewiseConstant(
(0, 8, 16), (np.array([5, 7]), np.array([8, 9, 10])),
)
@pytest.mark.parametrize(
"time, expected_value",

View file

@ -60,7 +60,9 @@ def test_hinged_window(baseline_hingedwindow, window_width,
{'window_height': np.array([0.15, 0.20])},
{'window_width': np.array([0.15, 0.20])},
{'opening_length': np.array([0.15, 0.20])},
{'outside_temp': models.PiecewiseConstant((1, 2, 3), (np.array([20, 30]), 25))},
{'outside_temp': models.PiecewiseConstant(
(1, 2, 3), (np.array([20, 30]), np.array([25, 30]))),
},
{'outside_temp': np.array([20, 30])},
{'inside_temp': np.array([20, 30])},
]