cara/caimira/apps/calculator/__init__.py

529 lines
22 KiB
Python

# This module is part of CAiMIRA. Please see the repository at
# https://gitlab.cern.ch/caimira/caimira for details of the license and terms of use.
import ast
import logging
import asyncio
import concurrent.futures
import datetime
import base64
import functools
import html
import json
import pandas as pd
from pprint import pformat
from io import StringIO
import os
from pathlib import Path
import traceback
import typing
import uuid
import zlib
import jinja2
import loky
from tornado.web import Application, RequestHandler, StaticFileHandler
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
import tornado.log
from caimira.store.data_registry import DataRegistry
from caimira.store.data_service import DataService
from . import markdown_tools
from . import model_generator, co2_model_generator
from .report_generator import ReportGenerator, calculate_report_data
from .user import AuthenticatedUser, AnonymousUser
# The calculator version is based on a combination of the model version and the
# semantic version of the calculator itself. The version uses the terms
# "{MAJOR}.{MINOR}.{PATCH}" to describe the 3 distinct numbers constituting a version.
# Effectively, if the model increases its MAJOR version then so too should this
# calculator version. If the calculator needs to make breaking changes (e.g. change
# form attributes) then it can also increase its MAJOR version without needing to
# increase the overall CAiMIRA version (found at ``caimira.__version__``).
__version__ = "4.14.3"
LOG = logging.getLogger("APP")
class BaseRequestHandler(RequestHandler):
async def prepare(self):
"""Called at the beginning of a request before `get`/`post`/etc."""
# Read the secure cookie which exists if we are in an authenticated
# context (though not if the caimira webservice is running standalone).
session = json.loads(self.get_secure_cookie('session') or 'null')
if session:
self.current_user = AuthenticatedUser(
username=session['username'],
email=session['email'],
fullname=session['fullname'],
)
else:
self.current_user = AnonymousUser()
def write_error(self, status_code: int, **kwargs) -> None:
template = self.settings["template_environment"].get_template(
"error.html.j2")
error_id = uuid.uuid4()
# Print the error to the log (and not to the browser!)
if "exc_info" in kwargs:
print(f"ERROR UUID {error_id}")
print(traceback.format_exc())
self.finish(template.render(
user=self.current_user,
get_url = template.globals['get_url'],
get_calculator_url = template.globals["get_calculator_url"],
active_page='Error',
error_id=error_id,
status_code=status_code,
datetime=datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S"),
))
class Missing404Handler(BaseRequestHandler):
async def prepare(self):
await super().prepare()
self.set_status(404)
template = self.settings["template_environment"].get_template(
"error.html.j2")
self.finish(template.render(
user=self.current_user,
get_url = template.globals['get_url'],
get_calculator_url = template.globals["get_calculator_url"],
active_page='Error',
status_code=404,
))
class ConcentrationModel(BaseRequestHandler):
async def post(self) -> None:
debug = self.settings.get("debug", False)
data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_registry(data_registry)
requested_model_config = {
name: self.get_argument(name) for name in self.request.arguments
}
LOG.debug(pformat(requested_model_config))
try:
form = model_generator.VirusFormData.from_dict(requested_model_config, data_registry)
except Exception as err:
LOG.exception(err)
response_json = {'code': 400, 'error': f'Your request was invalid {html.escape(str(err))}'}
self.set_status(400)
self.finish(json.dumps(response_json))
return
base_url = self.request.protocol + "://" + self.request.host
report_generator: ReportGenerator = self.settings['report_generator']
executor = loky.get_reusable_executor(
max_workers=self.settings['handler_worker_pool_size'],
timeout=300,
)
# Re-generate the report with the conditional probability of infection plot
if self.get_cookie('conditional_plot'):
form.conditional_probability_plot = True if self.get_cookie('conditional_plot') == '1' else False
self.clear_cookie('conditional_plot') # Clears cookie after changing the form value.
report_task = executor.submit(
report_generator.build_report, base_url, form,
executor_factory=functools.partial(
concurrent.futures.ThreadPoolExecutor,
self.settings['report_generation_parallelism'],
),
)
report: str = await asyncio.wrap_future(report_task)
self.finish(report)
class ConcentrationModelJsonResponse(BaseRequestHandler):
def check_xsrf_cookie(self):
"""
This request handler implements a stateless API that returns report data in JSON format.
Thus, XSRF cookies are disabled by overriding base class implementation of this method with a pass statement.
"""
pass
async def post(self) -> None:
"""
Expects algorithm input in HTTP POST request body in JSON format.
Returns report data (algorithm output) in HTTP POST response body in JSON format.
"""
debug = self.settings.get("debug", False)
data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_registry(data_registry)
requested_model_config = json.loads(self.request.body)
LOG.debug(pformat(requested_model_config))
try:
form = model_generator.VirusFormData.from_dict(requested_model_config, data_registry)
except Exception as err:
LOG.exception(err)
response_json = {'code': 400, 'error': f'Your request was invalid {html.escape(str(err))}'}
self.set_status(400)
await self.finish(json.dumps(response_json))
return
executor = loky.get_reusable_executor(
max_workers=self.settings['handler_worker_pool_size'],
timeout=300,
)
model = form.build_model()
report_data_task = executor.submit(calculate_report_data, form, model)
report_data: dict = await asyncio.wrap_future(report_data_task)
await self.finish(report_data)
class StaticModel(BaseRequestHandler):
async def get(self) -> None:
debug = self.settings.get("debug", False)
data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_registry(data_registry)
form = model_generator.VirusFormData.from_dict(model_generator.baseline_raw_form_data(), data_registry)
base_url = self.request.protocol + "://" + self.request.host
report_generator: ReportGenerator = self.settings['report_generator']
executor = loky.get_reusable_executor(max_workers=self.settings['handler_worker_pool_size'])
report_task = executor.submit(
report_generator.build_report, base_url, form,
executor_factory=functools.partial(
concurrent.futures.ThreadPoolExecutor,
self.settings['report_generation_parallelism'],
),
)
report: str = await asyncio.wrap_future(report_task)
self.finish(report)
class LandingPage(BaseRequestHandler):
def get(self):
template_environment = self.settings["template_environment"]
template = template_environment.get_template(
"index.html.j2")
report = template.render(
user=self.current_user,
get_url = template_environment.globals['get_url'],
get_calculator_url = template_environment.globals['get_calculator_url'],
text_blocks=template_environment.globals["common_text"],
)
self.finish(report)
class CalculatorForm(BaseRequestHandler):
def get(self):
template_environment = self.settings["template_environment"]
template = template_environment.get_template(
"calculator.form.html.j2")
report = template.render(
user=self.current_user,
xsrf_form_html=self.xsrf_form_html(),
get_url = template.globals['get_url'],
get_calculator_url = template.globals["get_calculator_url"],
calculator_version=__version__,
text_blocks=template_environment.globals["common_text"],
)
self.finish(report)
class CompressedCalculatorFormInputs(BaseRequestHandler):
def get(self, compressed_args: str):
# Convert a base64 zlib encoded shortened URL into a non compressed
# URL, and redirect.
try:
args = zlib.decompress(base64.b64decode(compressed_args)).decode()
except Exception as err: # noqa
self.set_status(400)
return self.finish("Invalid calculator data: it seems incomplete. Was there an error copying & pasting the URL?")
template_environment = self.settings["template_environment"]
self.redirect(f'{template_environment.globals["get_calculator_url"]()}?{args}')
class ArveData(BaseRequestHandler):
async def get(self, hotel_id, floor_id):
client_id = self.settings["arve_client_id"]
client_secret = self.settings['arve_client_secret']
arve_api_key = self.settings['arve_api_key']
if (client_id == None or client_secret == None or arve_api_key == None):
# If the credentials are not defined, we skip the ARVE API connection
return self.send_error(401)
http_client = AsyncHTTPClient()
URL = 'https://arveapi.auth.eu-central-1.amazoncognito.com/oauth2/token'
headers = { "Content-Type": "application/x-www-form-urlencoded",
"Authorization": b"Basic " + base64.b64encode(f'{client_id}:{client_secret}'.encode())
}
try:
response = await http_client.fetch(HTTPRequest(
url=URL,
method='POST',
headers=headers,
body="grant_type=client_credentials"
),
raise_error=True)
except Exception as e:
print("Something went wrong: %s" % e)
access_token = json.loads(response.body)['access_token']
URL = f'https://api.arve.swiss/v1/{hotel_id}/{floor_id}'
headers = {
"x-api-key": arve_api_key,
"Authorization": f'Bearer {access_token}'
}
try:
response = await http_client.fetch(HTTPRequest(
url=URL,
method='GET',
headers=headers,
),
raise_error=True)
except Exception as e:
print("Something went wrong: %s" % e)
self.set_header("Content-Type", 'application/json')
return self.finish(response.body)
class CasesData(BaseRequestHandler):
async def get(self, country):
http_client = AsyncHTTPClient()
# First we need the country to fetch the data
URL = f'https://restcountries.com/v3.1/alpha/{country}?fields=name'
try:
response = await http_client.fetch(HTTPRequest(
url=URL,
method='GET',
),
raise_error=True)
except Exception as e:
print("Something went wrong: %s" % e)
country_name = json.loads(response.body)['name']['common']
# Get global incident rates
URL = 'https://covid19.who.int/WHO-COVID-19-global-data.csv'
try:
response = await http_client.fetch(HTTPRequest(
url=URL,
method='GET',
),
raise_error=True)
except Exception as e:
print("Something went wrong: %s" % e)
df = pd.read_csv(StringIO(response.body.decode('utf-8')), index_col=False)
cases = df.loc[df['Country'] == country_name]
# 7-day rolling average
current_date = str(datetime.datetime.now()).split(' ')[0]
eight_days_ago = str(datetime.datetime.now() - datetime.timedelta(days=7)).split(' ')[0]
cases = cases.set_index(['Date_reported'])
# If any of the 'New_cases' is 0, it means the data is not updated.
if (cases.loc[eight_days_ago:current_date]['New_cases'] == 0).any(): return self.finish('')
return self.finish(str(round(cases.loc[eight_days_ago:current_date]['New_cases'].mean())))
class GenericExtraPage(BaseRequestHandler):
def initialize(self, active_page: str, filename: str):
self.active_page = active_page
# The endpoint that will be used as template name
self.filename = filename
def get(self):
template_environment = self.settings["template_environment"]
template = template_environment.get_template(self.filename)
self.finish(template.render(
user=self.current_user,
get_url = template.globals['get_url'],
get_calculator_url = template.globals["get_calculator_url"],
active_page=self.active_page,
text_blocks=template_environment.globals["common_text"]
))
class CO2ModelResponse(BaseRequestHandler):
def check_xsrf_cookie(self):
"""
This request handler implements a stateless API that returns report data in JSON format.
Thus, XSRF cookies are disabled by overriding base class implementation of this method with a pass statement.
"""
pass
async def post(self, endpoint: str) -> None:
data_registry: DataRegistry = self.settings["data_registry"]
data_service: typing.Optional[DataService] = self.settings.get("data_service", None)
if data_service:
data_service.update_registry(data_registry)
requested_model_config = tornado.escape.json_decode(self.request.body)
try:
form = co2_model_generator.CO2FormData.from_dict(requested_model_config, data_registry)
except Exception as err:
if self.settings.get("debug", False):
import traceback
print(traceback.format_exc())
response_json = {'code': 400, 'error': f'Your request was invalid {html.escape(str(err))}'}
self.set_status(400)
self.finish(json.dumps(response_json))
return
if endpoint.rstrip('/') == 'plot':
transition_times = co2_model_generator.CO2FormData.find_change_points_with_pelt(form.CO2_data)
self.finish({'CO2_plot': co2_model_generator.CO2FormData.generate_ventilation_plot(form.CO2_data, transition_times),
'transition_times': [round(el, 2) for el in transition_times]})
else:
executor = loky.get_reusable_executor(
max_workers=self.settings['handler_worker_pool_size'],
timeout=300,
)
report_task = executor.submit(
co2_model_generator.CO2FormData.build_model, form,
)
report = await asyncio.wrap_future(report_task)
result = dict(report.CO2_fit_params())
ventilation_transition_times = report.ventilation_transition_times
result['fitting_ventilation_type'] = form.fitting_ventilation_type
result['transition_times'] = ventilation_transition_times
result['CO2_plot'] = co2_model_generator.CO2FormData.generate_ventilation_plot(CO2_data=form.CO2_data,
transition_times=ventilation_transition_times[:-1],
predictive_CO2=result['predictive_CO2'])
self.finish(result)
def get_url(app_root: str, relative_path: str = '/'):
return app_root.rstrip('/') + relative_path.rstrip('/')
def get_calculator_url(app_root: str, calculator_prefix: str, relative_path: str = '/'):
return app_root.rstrip('/') + calculator_prefix.rstrip('/') + relative_path.rstrip('/')
def make_app(
debug: bool = False,
APPLICATION_ROOT: str = '/',
calculator_prefix: str = '/calculator',
theme_dir: typing.Optional[Path] = None,
) -> Application:
static_dir = Path(__file__).absolute().parent.parent / 'static'
calculator_static_dir = Path(__file__).absolute().parent / 'static'
get_root_url = functools.partial(get_url, APPLICATION_ROOT)
get_root_calculator_url = functools.partial(get_calculator_url, APPLICATION_ROOT, calculator_prefix)
base_urls: typing.List = [
(get_root_url(r'/?'), LandingPage),
(get_root_calculator_url(r'/?'), CalculatorForm),
(get_root_calculator_url(r'/co2-fit/(.*)'), CO2ModelResponse),
(get_root_calculator_url(r'/report'), ConcentrationModel),
(get_root_url(r'/static/(.*)'), StaticFileHandler, {'path': static_dir}),
(get_root_calculator_url(r'/static/(.*)'), StaticFileHandler, {'path': calculator_static_dir}),
]
urls: typing.List = base_urls + [
(get_root_url(r'/_c/(.*)'), CompressedCalculatorFormInputs),
(get_root_calculator_url(r'/report-json'), ConcentrationModelJsonResponse),
(get_root_calculator_url(r'/baseline-model/result'), StaticModel),
(get_root_calculator_url(r'/api/arve/v1/(.*)/(.*)'), ArveData),
(get_root_calculator_url(r'/cases/(.*)'), CasesData),
# Generic Pages
(get_root_url(r'/about'), GenericExtraPage, {
'active_page': 'about',
'filename': 'about.html.j2'}),
(get_root_calculator_url(r'/user-guide'), GenericExtraPage, {
'active_page': 'calculator/user-guide',
'filename': 'userguide.html.j2'}),
]
interface: str = os.environ.get('CAIMIRA_THEME', '<undefined>')
if interface != '<undefined>' and (interface != '<undefined>' and 'cern' not in interface): urls = list(filter(lambda i: i in base_urls, urls))
# Any extra generic page must be declared in the env. variable "EXTRA_PAGES"
extra_pages: typing.Union[str, typing.List] = os.environ.get('EXTRA_PAGES', [])
pages: typing.List = []
try:
pages = ast.literal_eval(extra_pages) # type: ignore
except (SyntaxError, ValueError):
LOG.warning('Warning: There was a problem with the extra pages. Is the "EXTRA_PAGES" environment variable defined?')
pass
for extra in pages:
urls.append((get_root_url(r'%s' % extra['url_path']),
GenericExtraPage, {
'active_page': extra['url_path'].strip('/'),
'filename': extra['filename'], }))
caimira_templates = Path(__file__).parent.parent / "templates"
calculator_templates = Path(__file__).parent / "templates"
templates_directories = [caimira_templates, calculator_templates]
if theme_dir:
templates_directories.insert(0, theme_dir)
loader = jinja2.FileSystemLoader([str(path) for path in templates_directories])
template_environment = jinja2.Environment(
loader=loader,
undefined=jinja2.StrictUndefined, # fail when rendering any undefined template context variable
)
template_environment.globals["common_text"] = markdown_tools.extract_rendered_markdown_blocks(
template_environment.get_template('common_text.md.j2')
)
template_environment.globals['get_url']=get_root_url
template_environment.globals['get_calculator_url']=get_root_calculator_url
if debug:
tornado.log.enable_pretty_logging()
data_registry = DataRegistry()
data_service = None
data_service_enabled = os.environ.get("DATA_SERVICE_ENABLED", "False")
is_enabled = data_service_enabled.lower() == "true"
if is_enabled: data_service = DataService.create()
return Application(
urls,
debug=debug,
data_registry=data_registry,
data_service=data_service,
template_environment=template_environment,
default_handler_class=Missing404Handler,
report_generator=ReportGenerator(loader, get_root_url, get_root_calculator_url),
xsrf_cookies=True,
# COOKIE_SECRET being undefined will result in no login information being
# presented to the user.
cookie_secret=os.environ.get('COOKIE_SECRET', '<undefined>'),
arve_client_id=os.environ.get('ARVE_CLIENT_ID', None),
arve_client_secret=os.environ.get('ARVE_CLIENT_SECRET', None),
arve_api_key=os.environ.get('ARVE_API_KEY', None),
# Process parallelism controls. There is a balance between serving a single report
# requests quickly or serving multiple requests concurrently.
# The defaults are: handle one report at a time, and allow parallelism
# of that report generation. A value of ``None`` will result in the number of
# processes being determined based on the number of CPUs. For some deployments,
# such as on OpenShift this number does *not* reflect the real number of CPUs that
# can be used, and it is recommended to specify these values explicitly (through
# the environment variables).
handler_worker_pool_size=(
int(os.environ.get("HANDLER_WORKER_POOL_SIZE", 1)) or None
),
report_generation_parallelism=(
int(os.environ.get('REPORT_PARALLELISM', 0)) or None
),
)