From 43a6b6e0a6c7af008b618887c09fe81d7244032f Mon Sep 17 00:00:00 2001 From: Luis Date: Tue, 14 Nov 2023 14:09:43 +0000 Subject: [PATCH] moved interface methods from init file to co2_model_generator --- caimira/apps/calculator/__init__.py | 67 ++----------------- .../apps/calculator/co2_model_generator.py | 60 ++++++++++++++++- 2 files changed, 65 insertions(+), 62 deletions(-) diff --git a/caimira/apps/calculator/__init__.py b/caimira/apps/calculator/__init__.py index 38824ced..28a0839d 100644 --- a/caimira/apps/calculator/__init__.py +++ b/caimira/apps/calculator/__init__.py @@ -18,9 +18,6 @@ import traceback import typing import uuid import zlib -import matplotlib.pyplot as plt -import numpy as np -import ruptures as rpt import jinja2 import loky @@ -30,7 +27,7 @@ import tornado.log from . import markdown_tools from . import model_generator, co2_model_generator -from .report_generator import ReportGenerator, calculate_report_data, img2base64, _figure2bytes +from .report_generator import ReportGenerator, calculate_report_data from .user import AuthenticatedUser, AnonymousUser # The calculator version is based on a combination of the model version and the @@ -345,65 +342,13 @@ class GenericExtraPage(BaseRequestHandler): )) -class CO2Data(BaseRequestHandler): +class CO2ModelResponse(BaseRequestHandler): def check_xsrf_cookie(self): """ This request handler implements a stateless API that returns report data in JSON format. Thus, XSRF cookies are disabled by overriding base class implementation of this method with a pass statement. """ pass - - def find_change_points_with_pelt(self, CO2_data: dict): - """ - Perform change point detection using Pelt algorithm from ruptures library with pen=15. - Returns a list of tuples containing (index, X-axis value) for the detected significant changes. - """ - - times: list = CO2_data['times'] - CO2_values: list = CO2_data['CO2'] - - if len(times) != len(CO2_values): - raise ValueError("times and CO2 values must have the same length.") - - # Convert the input list to a numpy array for use with the ruptures library - CO2_np = np.array(CO2_values) - - # Define the model for change point detection (Radial Basis Function kernel) - model = "rbf" - - # Fit the Pelt algorithm to the data with the specified model - algo = rpt.Pelt(model=model).fit(CO2_np) - - # Predict change points using the Pelt algorithm with a penalty value of 15 - result = algo.predict(pen=15) - - # Find local minima and maxima - segments = np.split(np.arange(len(CO2_values)), result) - merged_segments = [np.hstack((segments[i], segments[i + 1])) for i in range(len(segments) - 1)] - result_set = set() - for segment in merged_segments[:-2]: - result_set.add(times[CO2_values.index(min(CO2_np[segment]))]) - result_set.add(times[CO2_values.index(max(CO2_np[segment]))]) - return list(result_set) - - def generate_ventilation_plot(self, CO2_data: dict, - transition_times: typing.Optional[list] = None, - predictive_CO2: typing.Optional[list] = None): - times_values = CO2_data['times'] - CO2_values = CO2_data['CO2'] - - fig = plt.figure(figsize=(7, 4), dpi=110) - plt.plot(times_values, CO2_values, label='Input CO₂') - - if (transition_times): - for time in transition_times: - plt.axvline(x = time, color = 'grey', linewidth=0.5, linestyle='--') - if (predictive_CO2): - plt.plot(times_values, predictive_CO2, label='Predictive CO₂') - plt.xlabel('Time of day') - plt.ylabel('Concentration (ppm)') - plt.legend() - return img2base64(_figure2bytes(fig)) async def post(self, endpoint: str) -> None: requested_model_config = tornado.escape.json_decode(self.request.body) @@ -419,8 +364,8 @@ class CO2Data(BaseRequestHandler): return if endpoint.rstrip('/') == 'plot': - transition_times = self.find_change_points_with_pelt(form.CO2_data) - self.finish({'CO2_plot': self.generate_ventilation_plot(CO2_data=form.CO2_data, transition_times=transition_times), + transition_times = co2_model_generator.CO2FormData.find_change_points_with_pelt(form.CO2_data) + self.finish({'CO2_plot': co2_model_generator.CO2FormData.generate_ventilation_plot(form.CO2_data, transition_times), 'transition_times': [round(el, 2) for el in transition_times]}) else: executor = loky.get_reusable_executor( @@ -437,7 +382,7 @@ class CO2Data(BaseRequestHandler): result['fitting_ventilation_type'] = form.fitting_ventilation_type result['transition_times'] = ventilation_transition_times - result['CO2_plot'] = self.generate_ventilation_plot(CO2_data=form.CO2_data, + result['CO2_plot'] = co2_model_generator.CO2FormData.generate_ventilation_plot(CO2_data=form.CO2_data, transition_times=ventilation_transition_times[:-1], predictive_CO2=result['predictive_CO2']) self.finish(result) @@ -464,7 +409,7 @@ def make_app( base_urls: typing.List = [ (get_root_url(r'/?'), LandingPage), (get_root_calculator_url(r'/?'), CalculatorForm), - (get_root_calculator_url(r'/co2-fit/(.*)'), CO2Data), + (get_root_calculator_url(r'/co2-fit/(.*)'), CO2ModelResponse), (get_root_calculator_url(r'/report'), ConcentrationModel), (get_root_url(r'/static/(.*)'), StaticFileHandler, {'path': static_dir}), (get_root_calculator_url(r'/static/(.*)'), StaticFileHandler, {'path': calculator_static_dir}), diff --git a/caimira/apps/calculator/co2_model_generator.py b/caimira/apps/calculator/co2_model_generator.py index f278bb6c..0ca201c0 100644 --- a/caimira/apps/calculator/co2_model_generator.py +++ b/caimira/apps/calculator/co2_model_generator.py @@ -2,10 +2,14 @@ import dataclasses import html import logging import typing +import numpy as np +import ruptures as rpt +import matplotlib.pyplot as plt from caimira import models from . import model_generator -from .defaults import DEFAULT_MC_SAMPLE_SIZE, NO_DEFAULT, COFFEE_OPTIONS_INT +from .defaults import DEFAULT_MC_SAMPLE_SIZE, NO_DEFAULT +from .report_generator import img2base64, _figure2bytes minutes_since_midnight = typing.NewType('minutes_since_midnight', int) @@ -96,6 +100,60 @@ class CO2FormData(model_generator.FormData): instance = self(**form_data) instance.validate_population_parameters() return instance + + @classmethod + def find_change_points_with_pelt(self, CO2_data: dict): + """ + Perform change point detection using Pelt algorithm from ruptures library with pen=15. + Returns a list of tuples containing (index, X-axis value) for the detected significant changes. + """ + + times: list = CO2_data['times'] + CO2_values: list = CO2_data['CO2'] + + if len(times) != len(CO2_values): + raise ValueError("times and CO2 values must have the same length.") + + # Convert the input list to a numpy array for use with the ruptures library + CO2_np = np.array(CO2_values) + + # Define the model for change point detection (Radial Basis Function kernel) + model = "rbf" + + # Fit the Pelt algorithm to the data with the specified model + algo = rpt.Pelt(model=model).fit(CO2_np) + + # Predict change points using the Pelt algorithm with a penalty value of 15 + result = algo.predict(pen=15) + + # Find local minima and maxima + segments = np.split(np.arange(len(CO2_values)), result) + merged_segments = [np.hstack((segments[i], segments[i + 1])) for i in range(len(segments) - 1)] + result_set = set() + for segment in merged_segments[:-2]: + result_set.add(times[CO2_values.index(min(CO2_np[segment]))]) + result_set.add(times[CO2_values.index(max(CO2_np[segment]))]) + return list(result_set) + + @classmethod + def generate_ventilation_plot(self, CO2_data: dict, + transition_times: typing.Optional[list] = None, + predictive_CO2: typing.Optional[list] = None): + times_values = CO2_data['times'] + CO2_values = CO2_data['CO2'] + + fig = plt.figure(figsize=(7, 4), dpi=110) + plt.plot(times_values, CO2_values, label='Input CO₂') + + if (transition_times): + for time in transition_times: + plt.axvline(x = time, color = 'grey', linewidth=0.5, linestyle='--') + if (predictive_CO2): + plt.plot(times_values, predictive_CO2, label='Predictive CO₂') + plt.xlabel('Time of day') + plt.ylabel('Concentration (ppm)') + plt.legend() + return img2base64(_figure2bytes(fig)) def population_present_changes(self, infected_presence: models.Interval, exposed_presence: models.Interval) -> typing.List[float]: state_change_times = set(infected_presence.transition_times())