Merge branch 'feature/type-checking' into 'master'

Fix the type annotations for the cara codebase

See merge request cara/cara!125
This commit is contained in:
Nicolas Mounet 2021-01-19 09:47:53 +00:00
commit e83371040a
11 changed files with 207 additions and 137 deletions

View file

@ -1,7 +1,6 @@
import html
import json
from pathlib import Path
from typing import Optional, Awaitable
import jinja2
import mistune
@ -14,7 +13,7 @@ from .user import AuthenticatedUser, AnonymousUser
class BaseRequestHandler(RequestHandler):
async def prepare(self) -> Optional[Awaitable[None]]:
async def prepare(self):
"""Called at the beginning of a request before `get`/`post`/etc."""
username = self.request.headers.get("X-ADFS-LOGIN", None)
if username:

View file

@ -171,7 +171,7 @@ class FormData:
def build_model(self) -> models.ExposureModel:
return model_from_form(self)
def ventilation(self) -> models.Ventilation:
def ventilation(self) -> models._VentilationBase:
always_on = models.PeriodicInterval(period=120, duration=120)
# Initializes a ventilation instance as a window if 'natural' is selected, or as a HEPA-filter otherwise
if self.ventilation_type == 'natural':
@ -189,10 +189,12 @@ class FormData:
inside_temp = models.PiecewiseConstant((0, 24), (293,))
outside_temp = data.GenevaTemperatures[month]
ventilation: models.Ventilation
if self.window_type == 'sliding':
ventilation = models.SlidingWindow(
active=window_interval,
inside_temp=inside_temp, outside_temp=outside_temp,
inside_temp=inside_temp,
outside_temp=outside_temp,
window_height=self.window_height,
opening_length=self.opening_distance,
number_of_windows=self.windows_number,
@ -200,7 +202,8 @@ class FormData:
elif self.window_type == 'hinged':
ventilation = models.HingedWindow(
active=window_interval,
inside_temp=inside_temp, outside_temp=outside_temp,
inside_temp=inside_temp,
outside_temp=outside_temp,
window_height=self.window_height,
window_width=self.window_width,
opening_length=self.opening_distance,
@ -218,7 +221,7 @@ class FormData:
if self.hepa_option:
hepa = models.HEPAFilter(active=always_on, q_air_mech=self.hepa_amount)
return models.MultipleVentilation((ventilation,hepa))
return models.MultipleVentilation((ventilation, hepa))
else:
return ventilation
@ -301,7 +304,7 @@ class FormData:
)
return exposed
def _compute_breaks_in_interval(self, start, finish, n_breaks) -> typing.Tuple[typing.Tuple[int, int]]:
def _compute_breaks_in_interval(self, start, finish, n_breaks) -> models.BoundarySequence_t:
break_delay = ((finish - start) - (n_breaks * self.coffee_duration)) // (n_breaks+1)
break_times = []
end = start
@ -311,13 +314,13 @@ class FormData:
break_times.append((begin, end))
return tuple(break_times)
def lunch_break_times(self) -> typing.Tuple[typing.Tuple[int, int]]:
def lunch_break_times(self) -> models.BoundarySequence_t:
result = []
if self.lunch_option:
result.append((self.lunch_start, self.lunch_finish))
return tuple(result)
def coffee_break_times(self) -> typing.Tuple[typing.Tuple[int, int]]:
def coffee_break_times(self) -> models.BoundarySequence_t:
if not self.coffee_breaks:
return ()
if self.lunch_option:
@ -341,7 +344,7 @@ class FormData:
self,
start: int,
finish: int,
breaks: typing.Tuple[typing.Tuple[int, int], ...] = None,
breaks: typing.Optional[models.BoundarySequence_t] = None,
) -> models.Interval:
"""
Calculate the presence interval given the start and end times (in minutes), and
@ -354,14 +357,14 @@ class FormData:
# Order the breaks by their start-time, and ensure that they are monotonic
# and that the start of one break happens after the end of another.
breaks = sorted(breaks, key=lambda break_pair: break_pair[0])
break_boundaries: models.BoundarySequence_t = tuple(sorted(breaks, key=lambda break_pair: break_pair[0]))
for break_start, break_end in breaks:
for break_start, break_end in break_boundaries:
if break_start >= break_end:
raise ValueError("Break ends before it begins.")
prev_break_end = breaks[0][1]
for break_start, break_end in breaks[1:]:
prev_break_end = break_boundaries[0][1]
for break_start, break_end in break_boundaries[1:]:
if prev_break_end >= break_start:
raise ValueError(f"A break starts before another ends ({break_start}, {break_end}, {prev_break_end}).")
prev_break_end = break_end
@ -386,7 +389,7 @@ class FormData:
# 5. The interval straddles the end of the break. Bs <= S < Be <= E
# 6. The interval is entirely after the break. Bs < Be <= S < E
for current_break in breaks:
for current_break in break_boundaries:
if current_time >= finish:
break
@ -475,10 +478,10 @@ def expiration_blend(expiration_weights: typing.Dict[models.Expiration, int]) ->
ejection_factor += np.array(expiration.ejection_factor) * weight
particle_sizes += np.array(expiration.particle_sizes) * weight
return models.Expiration(
ejection_factor=tuple(ejection_factor/total_weight),
particle_sizes=tuple(particle_sizes/total_weight),
)
r_ejection_factor: typing.Tuple[float, float, float, float] = tuple(ejection_factor/total_weight) # type: ignore
r_particle_sizes: typing.Tuple[float, float, float, float] = tuple(particle_sizes/total_weight) # type: ignore
return models.Expiration(ejection_factor=r_ejection_factor, particle_sizes=r_particle_sizes)
def model_from_form(form: FormData) -> models.ExposureModel:

View file

@ -113,9 +113,9 @@ def minutes_to_time(minutes: int) -> str:
return f"{hour_string}:{minute_string}"
def readable_minutes(minutes: int) -> str:
time = minutes
def readable_minutes(minutes: int) -> str:
time = float(minutes)
unit = " minute"
if time % 60 == 0:
time = minutes/60
@ -124,19 +124,20 @@ def readable_minutes(minutes: int) -> str:
unit += "s"
if time.is_integer():
time = "{:0.0f}".format(time)
time_str = "{:0.0f}".format(time)
else:
time = "{0:.2f}".format(time)
time_str = "{0:.2f}".format(time)
return time + unit
return time_str + unit
def non_zero_percentage (percentage: int) -> str:
def non_zero_percentage(percentage: int) -> str:
if percentage < 0.01:
return "<0.01%"
else:
return "{:0.2f}%".format(percentage)
def manufacture_alternative_scenarios(form: FormData) -> typing.Dict[str, models.ExposureModel]:
scenarios = {}
@ -243,7 +244,7 @@ def build_report(model: models.ExposureModel, form: FormData):
cara_templates = Path(__file__).parent.parent / "templates"
calculator_templates = Path(__file__).parent / "templates"
env = jinja2.Environment(
loader=jinja2.FileSystemLoader([cara_templates, calculator_templates]),
loader=jinja2.FileSystemLoader([str(cara_templates), str(calculator_templates)]),
undefined=jinja2.StrictUndefined,
)
env.filters['non_zero_percentage'] = non_zero_percentage

View file

@ -32,9 +32,9 @@ WidgetPairType = typing.Tuple[widgets.Widget, widgets.Widget]
class WidgetGroup:
def __init__(self, label_widget_pairs: typing.Sequence[WidgetPairType]):
self.labels = []
self.widgets = []
def __init__(self, label_widget_pairs: typing.Iterable[WidgetPairType]):
self.labels: typing.List[widgets.Widget] = []
self.widgets: typing.List[widgets.Widget] = []
self.add_pairs(label_widget_pairs)
def set_visible(self, visible: bool):
@ -46,10 +46,10 @@ class WidgetGroup:
widget.layout.visible = False
widget.layout.display = 'none'
def pairs(self) -> typing.Sequence[WidgetPairType]:
def pairs(self) -> typing.Iterable[WidgetPairType]:
return zip(*[self.labels, self.widgets])
def add_pairs(self, label_widget_pairs: typing.Sequence[WidgetPairType]):
def add_pairs(self, label_widget_pairs: typing.Iterable[WidgetPairType]):
labels, widgets_ = zip(*label_widget_pairs)
self.labels.extend(labels)
self.widgets.extend(widgets_)
@ -190,13 +190,13 @@ class ExposureComparissonResult(View):
return ax
def scenarios_updated(self, scenarios: typing.Sequence[ScenarioType], _):
labels, models = zip(*scenarios)
conc_models: typing.Tuple[models.ConcentrationModel] = tuple(
model.concentration_model.dcs_instance() for model in models
updated_labels, updated_models = zip(*scenarios)
conc_models = tuple(
model.concentration_model.dcs_instance() for model in updated_models
)
self.update_plot(conc_models, labels)
self.update_plot(conc_models, updated_labels)
def update_plot(self, conc_models: typing.Tuple[models.ConcentrationModel], labels: typing.Tuple[str]):
def update_plot(self, conc_models: typing.Tuple[models.ConcentrationModel, ...], labels: typing.Tuple[str, ...]):
self.ax.lines.clear()
start, finish = models_start_end(conc_models)
ts = np.linspace(start, finish, num=250)
@ -268,12 +268,14 @@ class ModelWidgets(View):
outside_temp.observe(outsidetemp_change, names=['value'])
auto_width = widgets.Layout(width='auto')
return WidgetGroup([
[
widgets.Label('Outside temperature (℃)', layout=auto_width,),
outside_temp,
],
])
return WidgetGroup(
(
(
widgets.Label('Outside temperature (℃)', layout=auto_width,),
outside_temp,
),
),
)
def _build_window(self, node) -> WidgetGroup:
period = widgets.IntSlider(value=node.active.period, min=0, max=240)
@ -314,24 +316,26 @@ class ModelWidgets(View):
toggle_outsidetemp(outsidetemp_w.value)
auto_width = widgets.Layout(width='auto')
result = WidgetGroup([
[
widgets.Label('Interval between openings (minutes)', layout=auto_width),
period,
],
[
widgets.Label('Duration of opening (minutes)', layout=auto_width),
interval,
],
[
widgets.Label('Inside temperature (℃)', layout=auto_width),
inside_temp,
],
[
widgets.Label('Outside temperature scheme', layout=auto_width),
outsidetemp_w,
]
])
result = WidgetGroup(
(
(
widgets.Label('Interval between openings (minutes)', layout=auto_width),
period,
),
(
widgets.Label('Duration of opening (minutes)', layout=auto_width),
interval,
),
(
widgets.Label('Inside temperature (℃)', layout=auto_width),
inside_temp,
),
(
widgets.Label('Outside temperature scheme', layout=auto_width),
outsidetemp_w,
),
),
)
for sub_group in outsidetemp_widgets.values():
result.add_pairs(sub_group.pairs())
return result
@ -362,9 +366,9 @@ class ModelWidgets(View):
month_choice.observe(on_month_change, names=['value'])
return WidgetGroup(
[
[widgets.Label("Month"), month_choice],
]
(
(widgets.Label("Month"), month_choice),
),
)
def _build_activity(self, node):
@ -414,7 +418,13 @@ class ModelWidgets(View):
[[widgets.Label("Expiration"), expiration_choice]]
)
def _build_ventilation(self, node):
def _build_ventilation(
self,
node: typing.Union[
state.DataclassStateNamed[models.Ventilation],
state.DataclassStateNamed[models.MultipleVentilation],
],
) -> widgets.Widget:
ventilation_widgets = {
'Natural': self._build_window(node._states['Natural']).build(),
'Mechanical': self._build_mechanical(node._states['Mechanical']),
@ -479,14 +489,18 @@ baseline_model = models.ExposureModel(
class CARAStateBuilder(state.StateBuilder):
# Note: The methods in this class must correspond to the *type* of the data classes.
# For example, build_type__VentilationBase is called when dealing with ConcentrationModel
# types as it has a ventilation: _VentilationBase field.
def build_type_Mask(self, _: dataclasses.Field):
return state.DataclassStatePredefined(
models.Mask,
choices=models.Mask.types,
)
def build_type_Ventilation(self, _: dataclasses.Field):
s = state.DataclassStateNamed(
def build_type__VentilationBase(self, _: dataclasses.Field):
s: state.DataclassStateNamed = state.DataclassStateNamed(
states={
'Natural': self.build_generic(models.WindowOpening),
'Mechanical': self.build_generic(models.HVACMechanical),
@ -526,7 +540,7 @@ class ExpertApplication(Controller):
)
self.add_scenario('Scenario 1')
def build_new_model(self):
def build_new_model(self) -> state.DataclassInstanceState[models.ExposureModel]:
default_model = state.DataclassInstanceState(
models.ExposureModel,
state_builder=CARAStateBuilder(),

View file

@ -2,13 +2,9 @@ import dataclasses
import typing
DCInst = typing.TypeVar('T')
def nested_replace(obj: DCInst, new_values: typing.Dict[str, typing.Any]) -> DCInst:
"""
Replace an attribute on a dataclass, much like dataclasses.replace, except it
supports nested replacement definitions. For example:
def nested_replace(obj, new_values: typing.Dict[str, typing.Any]):
"""Replace an attribute on a dataclass, much like dataclasses.replace,
except it supports nested replacement definitions. For example:
>>> new_obj = nested_replace(obj, {'attr1.sub_attr2.sub_sub_attr3': 4})
>>> new_obj.attr1.sub_attr2.sub_sub_attr3

View file

@ -12,6 +12,11 @@ class Room:
volume: float
Time_t = typing.TypeVar('Time_t', float, int)
BoundaryPair_t = typing.Tuple[Time_t, Time_t]
BoundarySequence_t = typing.Union[typing.Tuple[BoundaryPair_t, ...], typing.Tuple]
@dataclass(frozen=True)
class Interval:
"""
@ -26,7 +31,7 @@ class Interval:
start < t <= end
"""
def boundaries(self) -> typing.Tuple[typing.Tuple[float, float], ...]:
def boundaries(self) -> BoundarySequence_t:
return ()
def transition_times(self) -> typing.Set[float]:
@ -48,23 +53,23 @@ class SpecificInterval(Interval):
#: A sequence of times (start, stop), in hours, that the infected person
#: is present. The flattened list of times must be strictly monotonically
#: increasing.
present_times: typing.Tuple[typing.Tuple[float, float], ...]
present_times: BoundarySequence_t
def boundaries(self):
def boundaries(self) -> BoundarySequence_t:
return self.present_times
@dataclass(frozen=True)
class PeriodicInterval(Interval):
#: How often does the interval occur (minutes).
period: int
period: float
#: How long does the interval occur for (minutes).
#: A value greater than :data:`period` signifies the event is permanently
#: occurring, a value of 0 signifies that the event never happens.
duration: int
duration: float
def boundaries(self) -> typing.Tuple[typing.Tuple[float, float], ...]:
def boundaries(self) -> BoundarySequence_t:
if self.period == 0 or self.duration == 0:
return tuple()
result = []
@ -91,38 +96,44 @@ class PiecewiseConstant:
if tuple(sorted(set(self.transition_times))) != self.transition_times:
raise ValueError("transition_times should not contain duplicated elements and should be sorted")
def value(self,time) -> float:
def value(self, time) -> float:
if time <= self.transition_times[0]:
return self.values[0]
if time > self.transition_times[-1]:
elif time > self.transition_times[-1]:
return self.values[-1]
for t1,t2,value in zip(self.transition_times[:-1],
self.transition_times[1:],self.values):
if time > t1 and time <= t2:
return value
for t1, t2, value in zip(self.transition_times[:-1],
self.transition_times[1:], self.values):
if t1 < time <= t2:
break
return value
def interval(self) -> Interval:
# build an Interval object
present_times = []
for t1,t2,value in zip(self.transition_times[:-1],
self.transition_times[1:],self.values):
for t1, t2, value in zip(self.transition_times[:-1],
self.transition_times[1:], self.values):
if value:
present_times.append((t1,t2))
return SpecificInterval(present_times=present_times)
return SpecificInterval(present_times=tuple(present_times))
def refine(self,refine_factor=10):
def refine(self, refine_factor=10):
# build a new PiecewiseConstant object with a refined mesh,
# using a linear interpolation in-between the initial mesh points
refined_times = np.linspace(self.transition_times[0],self.transition_times[-1],
(len(self.transition_times)-1)*refine_factor+1)
return PiecewiseConstant(tuple(refined_times),
tuple(np.interp(refined_times[:-1],self.transition_times,
self.values+(self.values[-1],) ) ) )
refined_times = np.linspace(self.transition_times[0], self.transition_times[-1],
(len(self.transition_times)-1) * refine_factor+1)
return PiecewiseConstant(
tuple(refined_times),
tuple(np.interp(
refined_times[:-1],
self.transition_times,
self.values + (self.values[-1], ),
)),
)
@dataclass(frozen=True)
class Ventilation:
class _VentilationBase:
"""
Represents a mechanism by which air can be exchanged (replaced/filtered)
in a time dependent manner.
@ -133,11 +144,8 @@ class Ventilation:
mechanical air exchange through a filter.
"""
#: The times at which the air exchange is taking place.
active: Interval
def transition_times(self) -> typing.Set[float]:
return self.active.transition_times()
raise NotImplementedError("Subclass must implement")
def air_exchange(self, room: Room, time: float) -> float:
"""
@ -153,7 +161,16 @@ class Ventilation:
@dataclass(frozen=True)
class MultipleVentilation:
class Ventilation(_VentilationBase):
#: The interval in which the ventilation is active.
active: Interval
def transition_times(self) -> typing.Set[float]:
return self.active.transition_times()
@dataclass(frozen=True)
class MultipleVentilation(_VentilationBase):
"""
Represents a mechanism by which air can be exchanged (replaced/filtered)
in a time dependent manner.
@ -161,7 +178,7 @@ class MultipleVentilation:
Group together different sources of ventilations.
"""
ventilations: typing.Tuple[Ventilation, ...]
ventilations: typing.Tuple[_VentilationBase, ...]
def transition_times(self) -> typing.Set[float]:
transitions = set()
@ -174,7 +191,7 @@ class MultipleVentilation:
Returns the rate at which air is being exchanged in the given room
at a given time (in hours).
"""
return sum([ventilation.air_exchange(room,time)
return sum([ventilation.air_exchange(room, time)
for ventilation in self.ventilations])
@ -258,10 +275,10 @@ class HingedWindow(WindowOpening):
horizontal plane).
"""
#: Window width (m).
window_width: float = None
window_width: float = 0.0
def __post_init__(self):
if self.window_width is None:
if not self.window_width > 0:
raise ValueError('window_width must be set')
@property
@ -387,7 +404,10 @@ class Mask:
#: Filtration efficiency of masks when inhaling.
η_inhale: float
particle_sizes: typing.Tuple[float] = (0.8e-4, 1.8e-4, 3.5e-4, 5.5e-4) # In cm.
#: Particle sizes in cm.
particle_sizes: typing.Tuple[float, float, float, float] = (
0.8e-4, 1.8e-4, 3.5e-4, 5.5e-4
)
#: Pre-populated examples of Masks.
types: typing.ClassVar[typing.Dict[str, "Mask"]]
@ -542,7 +562,7 @@ class InfectedPopulation(Population):
@dataclass(frozen=True)
class ConcentrationModel:
room: Room
ventilation: Ventilation
ventilation: _VentilationBase
infected: InfectedPopulation
@property

View file

@ -14,7 +14,7 @@ import dataclasses
import typing
Datamodel_T = typing.Type
Datamodel_T = typing.TypeVar('Datamodel_T')
dataclass_instance = typing.Any
@ -34,11 +34,11 @@ class StateBuilder:
return method
return self.build_generic
def build_generic(self, type_to_build: typing.Type):
def build_generic(self, type_to_build: typing.Type) -> "DataclassInstanceState":
return DataclassInstanceState(type_to_build, state_builder=self)
class DataclassState:
class DataclassState(typing.Generic[Datamodel_T]):
def __init__(self, state_builder=StateBuilder()):
with self._object_setattr():
self._state_builder = state_builder
@ -54,7 +54,7 @@ class DataclassState:
yield
object.__setattr__(self, '_use_base_setattr', False)
def dcs_instance(self):
def dcs_instance(self) -> Datamodel_T:
"""
Return the instance that this state represents. The instance returned
is immutable, so it is advised to call this method each time that
@ -81,7 +81,7 @@ class DataclassState:
"""
yield
def dcs_update_from(self, data: dataclass_instance):
def dcs_update_from(self, data: Datamodel_T):
"""
Update the state based on the values of the given dataclass instance.
@ -94,7 +94,7 @@ class DataclassState:
"""
def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
"""
Update the current instance of the state to this type.
@ -103,7 +103,7 @@ class DataclassState:
"""
class DataclassInstanceState(DataclassState):
class DataclassInstanceState(DataclassState[Datamodel_T]):
"""
Represents the state of a frozen dataclass.
No type checking of the attributes is attempted.
@ -122,7 +122,7 @@ class DataclassInstanceState(DataclassState):
"""
def __init__(self, dataclass: Datamodel_T, state_builder=StateBuilder()):
def __init__(self, dataclass: typing.Type[Datamodel_T], state_builder=StateBuilder()):
super().__init__(state_builder=state_builder)
# Note that the constructor does *not* insert any data by default. It
@ -138,15 +138,16 @@ class DataclassInstanceState(DataclassState):
self._base = dataclass
#: The actual instance type that this state represents (i.e. may be a
#: subclass of _base).
self._instance_type = dataclass
self._instance_type: typing.Type[Datamodel_T] = dataclass
#: The instance of dataclass which this state represents. Undefined until
#: sufficient data is provided.
self._instance = None
self._data = {}
self._observers: typing.List[callable] = []
self._instance: typing.Optional[Datamodel_T] = None
#: The underlying state data mapping attribute name to object.
self._data: typing.Dict[str, typing.Any] = {}
self._observers: typing.List[typing.Callable] = []
self._state_builder = state_builder
self._held_events = []
self._held_events: typing.List[bool] = []
self._hold_fire = False
self.dcs_set_instance_type(dataclass)
@ -169,7 +170,7 @@ class DataclassInstanceState(DataclassState):
self._held_events.clear()
self._fire_observers()
def dcs_update_from(self, data: dataclass_instance):
def dcs_update_from(self, data: Datamodel_T):
with self.dcs_state_transaction():
self.dcs_set_instance_type(data.__class__)
for field in dataclasses.fields(data):
@ -225,7 +226,7 @@ class DataclassInstanceState(DataclassState):
else:
raise AttributeError(f"No attribute {attr_name} on a {self._instance_type.__name__}")
def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
if not dataclasses.is_dataclass(instance_dataclass):
raise TypeError("The given class is not a valid dataclass")
if not issubclass(instance_dataclass, self._base):
@ -266,7 +267,7 @@ class DataclassInstanceState(DataclassState):
return self._instance
class DataclassStatePredefined(DataclassInstanceState):
class DataclassStatePredefined(DataclassInstanceState[Datamodel_T]):
"""
Only a pre-defined selection of states for the given type are allowed.
Selected by name (the keys in the dictionary).
@ -277,19 +278,21 @@ class DataclassStatePredefined(DataclassInstanceState):
"""
def __init__(self,
dataclass: Datamodel_T,
choices: typing.Dict[typing.Hashable, dataclass_instance],
dataclass: typing.Type[Datamodel_T],
choices: typing.Dict[str, Datamodel_T],
**kwargs,
):
super().__init__(dataclass=dataclass, **kwargs)
with self._object_setattr():
self._choices = choices
self._selected = None
# Pick the first choice until we know otherwise.
self.dcs_select(list(choices.keys())[0])
self._selected: str = None # type: ignore
def dcs_select(self, name: typing.Hashable):
# Pick the first choice until we know otherwise.
default_selection = list(choices.keys())[0]
self.dcs_select(default_selection)
def dcs_select(self, name: str):
if name not in self._choices:
raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._choices)}')
self._selected = name
@ -309,14 +312,14 @@ class DataclassStatePredefined(DataclassInstanceState):
return dataclasses.asdict(self.dcs_instance())
class DataclassStateNamed(DataclassState):
class DataclassStateNamed(DataclassState[Datamodel_T]):
"""
A collection of instances of the given type, switchable by name, but each
instance is still mutable.
"""
def __init__(self,
states: typing.Dict[typing.Hashable, DataclassState],
states: typing.Dict[str, DataclassState[Datamodel_T]],
**kwargs
):
# TODO: This is effectively a container type. We shouldn't use the standard constructor for this.
@ -326,7 +329,7 @@ class DataclassStateNamed(DataclassState):
with self._object_setattr():
self._states = states.copy()
self._selected = None
self._selected: str = None # type: ignore
# Pick the first choice until we know otherwise.
self.dcs_select(enabled)
@ -348,7 +351,7 @@ class DataclassStateNamed(DataclassState):
return object.__setattr__(self, name, value)
setattr(self._selected_state(), name, value)
def dcs_select(self, name: typing.Hashable):
def dcs_select(self, name: str):
if name not in self._states:
raise ValueError(f'The choice {name} is not valid. Possible options are {", ".join(self._states)}')
self._selected = name
@ -369,10 +372,10 @@ class DataclassStateNamed(DataclassState):
for state in self._states.values():
state.dcs_observe(callback)
def dcs_update_from(self, data: dataclass_instance):
def dcs_update_from(self, data: Datamodel_T):
return self._selected_state().dcs_update_from(data)
def dcs_set_instance_type(self, instance_dataclass: Datamodel_T):
def dcs_set_instance_type(self, instance_dataclass: typing.Type[Datamodel_T]):
return self._selected_state().dcs_set_instance_type(instance_dataclass)
@contextmanager

View file

@ -19,3 +19,17 @@ def test_generate_report(baseline_form):
report = report_generator.build_report(model, baseline_form)
assert report != ""
@pytest.mark.parametrize(
["test_input", "expected"],
[
[1, '1 minute'],
[2, '2 minutes'],
[60, '1 hour'],
[120, '2 hours'],
[150, '150 minutes'],
],
)
def test_readable_minutes(test_input, expected):
assert report_generator.readable_minutes(test_input) == expected

View file

@ -25,7 +25,7 @@ class DCSimpleSubclass(DCSimple):
@dataclass
class DCOverrideSubclass(DCSimple):
attr1: float
attr1: float # type: ignore
@dataclass

18
setup.cfg Normal file
View file

@ -0,0 +1,18 @@
[tool:pytest]
addopts = --mypy
[mypy]
no_warn_no_return = True
[mypy-matplotlib.*]
ignore_missing_imports = True
[mypy-ipympl.*]
ignore_missing_imports = True
[mypy-ipywidgets.*]
ignore_missing_imports = True
[mypy-mistune.*]
ignore_missing_imports = True

View file

@ -30,7 +30,9 @@ REQUIREMENTS: dict = {
'app': [],
'test': [
'pytest',
'pytest-tornasync', # Unused, but needed because of a downstream dependency.
'pytest-mypy',
'pytest-tornasync',
'numpy-stubs @ git+https://github.com/numpy/numpy-stubs.git',
],
'dev': [
'jupyterlab',