diff --git a/src/octoprint/server/__init__.py b/src/octoprint/server/__init__.py index 476000dd..3226d02f 100644 --- a/src/octoprint/server/__init__.py +++ b/src/octoprint/server/__init__.py @@ -7,7 +7,7 @@ __copyright__ = "Copyright (C) 2014 The OctoPrint Project - Released under terms import uuid from sockjs.tornado import SockJSRouter -from flask import Flask, g, request, session, Blueprint +from flask import Flask, g, request, session, Blueprint, Request, Response from flask.ext.login import LoginManager, current_user from flask.ext.principal import Principal, Permission, RoleNeed, identity_loaded, UserNeed from flask.ext.babel import Babel, gettext, ngettext @@ -168,7 +168,7 @@ class Server(object): util.flask.enable_additional_translations(additional_folders=[s.getBaseFolder("translations")]) # setup app - self._setup_app() + self._setup_app(app) # setup i18n self._setup_i18n(app) @@ -199,9 +199,15 @@ class Server(object): pluginLifecycleManager = LifecycleManager(pluginManager) preemptiveCache = PreemptiveCache(os.path.join(s.getBaseFolder("data"), "preemptive_cache_config.yaml")) + # ... and initialize all plugins + def octoprint_plugin_inject_factory(name, implementation): + """Factory for injections for all OctoPrintPlugins""" + if not isinstance(implementation, octoprint.plugin.OctoPrintPlugin): + # we only care about OctoPrintPlugins return None + return dict( plugin_manager=pluginManager, printer_profile_manager=printerProfileManager, @@ -217,8 +223,13 @@ class Server(object): ) def settings_plugin_inject_factory(name, implementation): + """Factory for additional injections depending on plugin type""" + if not isinstance(implementation, octoprint.plugin.SettingsPlugin): + # we only care about SettingsPlugins return None + + # SettingsPlugin instnances get a PluginSettings instance injected default_settings = implementation.get_settings_defaults() get_preprocessors, set_preprocessors = implementation.get_settings_preprocessors() plugin_settings = octoprint.plugin.plugin_settings(name, @@ -228,6 +239,8 @@ class Server(object): return dict(settings=plugin_settings) def settings_plugin_config_migration_and_cleanup(name, implementation): + """Take care of migrating and cleaning up any old settings""" + if not isinstance(implementation, octoprint.plugin.SettingsPlugin): return @@ -269,6 +282,8 @@ class Server(object): # setup jinja2 self._setup_jinja2() + + # make sure plugin lifecycle events relevant for jinja2 are taken care of def template_enabled(name, plugin): if plugin.implementation is None or not isinstance(plugin.implementation, octoprint.plugin.TemplatePlugin): return @@ -296,31 +311,12 @@ class Server(object): try: clazz = octoprint.util.get_class(userManagerName) userManager = clazz() - except AttributeError, e: + except AttributeError as e: self._logger.exception("Could not instantiate user manager {}, falling back to FilebasedUserManager!".format(userManagerName)) userManager = octoprint.users.FilebasedUserManager() finally: userManager.enabled = s.getBoolean(["accessControl", "enabled"]) - app.wsgi_app = util.ReverseProxied( - app.wsgi_app, - s.get(["server", "reverseProxy", "prefixHeader"]), - s.get(["server", "reverseProxy", "schemeHeader"]), - s.get(["server", "reverseProxy", "hostHeader"]), - s.get(["server", "reverseProxy", "prefixFallback"]), - s.get(["server", "reverseProxy", "schemeFallback"]), - s.get(["server", "reverseProxy", "hostFallback"]) - ) - - secret_key = s.get(["server", "secretKey"]) - if not secret_key: - import string - from random import choice - chars = string.ascii_lowercase + string.ascii_uppercase + string.digits - secret_key = "".join(choice(chars) for _ in xrange(32)) - s.set(["server", "secretKey"], secret_key) - s.save() - app.secret_key = secret_key loginManager = LoginManager() loginManager.session_protection = "strong" loginManager.user_callback = load_user @@ -329,18 +325,16 @@ class Server(object): principals.identity_loaders.appendleft(users.dummy_identity_loader) loginManager.init_app(app) - if self._host is None: - self._host = s.get(["server", "host"]) - if self._port is None: - self._port = s.getInt(["server", "port"]) - - app.debug = self._debug - # register API blueprint self._setup_blueprints() ## Tornado initialization starts here + if self._host is None: + self._host = s.get(["server", "host"]) + if self._port is None: + self._port = s.getInt(["server", "port"]) + ioloop = IOLoop() ioloop.install() @@ -382,6 +376,8 @@ class Server(object): (r"/online.txt", util.tornado.StaticDataHandler, dict(data="online\n")), (r"/online.gif", util.tornado.StaticDataHandler, dict(data=bytes(base64.b64decode("R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7")), content_type="image/gif")) ] + + # fetch additional routes from plugins for name, hook in pluginManager.get_hooks("octoprint.server.http.routes").items(): try: result = hook(list(server_routes)) @@ -435,10 +431,13 @@ class Server(object): self._stop_intermediary_server() + # initialize and bind the server self._server = util.tornado.CustomHTTPServer(self._tornado_app, max_body_sizes=max_body_sizes, default_max_body_size=s.getInt(["server", "maxSize"])) self._server.listen(self._port, address=self._host) eventManager.fire(events.Events.STARTUP) + + # auto connect if s.getBoolean(["serial", "autoconnect"]): (port, baudrate) = s.get(["serial", "port"]), s.getInt(["serial", "baudrate"]) printer_profile = printerProfileManager.get_default() @@ -531,7 +530,8 @@ class Server(object): def _create_socket_connection(self, session): global printer, fileManager, analysisQueue, userManager, eventManager - return util.sockjs.PrinterStateConnection(printer, fileManager, analysisQueue, userManager, eventManager, pluginManager, session) + return util.sockjs.PrinterStateConnection(printer, fileManager, analysisQueue, userManager, + eventManager, pluginManager, session) def _check_for_root(self): if "geteuid" in dir(os) and os.geteuid() == 0: @@ -635,7 +635,41 @@ class Server(object): logging.getLogger("SERIAL").setLevel(logging.DEBUG) logging.getLogger("SERIAL").debug("Enabling serial logging") - def _setup_app(self): + def _setup_app(self, app): + from octoprint.server.util.flask import ReverseProxiedEnvironment, OctoPrintFlaskRequest, OctoPrintFlaskResponse + + s = settings() + + app.debug = self._debug + + secret_key = s.get(["server", "secretKey"]) + if not secret_key: + import string + from random import choice + chars = string.ascii_lowercase + string.ascii_uppercase + string.digits + secret_key = "".join(choice(chars) for _ in range(32)) + s.set(["server", "secretKey"], secret_key) + s.save() + + app.secret_key = secret_key + + reverse_proxied = ReverseProxiedEnvironment( + header_prefix=s.get(["server", "reverseProxy", "prefixHeader"]), + header_scheme=s.get(["server", "reverseProxy", "schemeHeader"]), + header_host=s.get(["server", "reverseProxy", "hostHeader"]), + header_server=s.get(["server", "reverseProxy", "serverHeader"]), + header_port=s.get(["server", "reverseProxy", "portHeader"]), + prefix=s.get(["server", "reverseProxy", "prefixFallback"]), + scheme=s.get(["server", "reverseProxy", "schemeFallback"]), + host=s.get(["server", "reverseProxy", "hostFallback"]), + server=s.get(["server", "reverseProxy", "serverFallback"]), + port=s.get(["server", "reverseProxy", "portFallback"]) + ) + + OctoPrintFlaskRequest.environment_wrapper = reverse_proxied + app.request_class = OctoPrintFlaskRequest + app.response_class = OctoPrintFlaskResponse + @app.before_request def before_request(): g.locale = self._get_locale() @@ -932,7 +966,7 @@ class Server(object): # that might be caused by the user still having the folder open somewhere, let's try again after # waiting a bit import time - for n in xrange(3): + for n in range(3): time.sleep(0.5) self._logger.debug("Creating {path}: Retry #{retry} after {time}s".format(path=path, retry=n+1, time=(n + 1)*0.5)) try: diff --git a/src/octoprint/server/util/__init__.py b/src/octoprint/server/util/__init__.py index e75d45fd..dc4c4c83 100644 --- a/src/octoprint/server/util/__init__.py +++ b/src/octoprint/server/util/__init__.py @@ -155,97 +155,3 @@ def get_plugin_hash(): plugin_hash = hashlib.sha1() plugin_hash.update(",".join(ui_plugins)) return plugin_hash.hexdigest() - - -#~~ reverse proxy compatible WSGI middleware - - -class ReverseProxied(object): - """ - Wrap the application in this middleware and configure the - front-end server to add these headers, to let you quietly bind - this to a URL other than / and to an HTTP scheme that is - different than what is used locally. - - In nginx: - - .. code-block:: none - - location /myprefix { - proxy_pass http://192.168.0.1:5001; - proxy_set_header Host $host; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Scheme $scheme; - proxy_set_header X-Script-Name /myprefix; - } - - Alternatively define prefix and scheme via config.yaml: - - .. code-block:: yaml - - server: - baseUrl: /myprefix - scheme: http - - :param app: the WSGI application - :param header_script_name: the HTTP header in the wsgi environment from which to determine the prefix - :param header_scheme: the HTTP header in the wsgi environment from which to determine the scheme - :param header_host: the HTTP header in the wsgi environment from which to determine the host for which to generate external URLs - :param base_url: the prefix to use as fallback if headers are not set - :param scheme: the scheme to use as fallback if headers are not set - :param host: the host to use as fallback if headers are not set - """ - - def __init__(self, app, header_prefix="x-script-name", header_scheme="x-scheme", header_host="x-forwarded-host", base_url="", scheme="", host=""): - self.app = app - - # headers for prefix & scheme & host, converted to conform to WSGI format - to_wsgi_format = lambda header: "HTTP_" + header.upper().replace("-", "_") - self._header_prefix = to_wsgi_format(header_prefix) - self._header_scheme = to_wsgi_format(header_scheme) - self._header_host = to_wsgi_format(header_host) - - # fallback prefix & scheme & host from config - self._fallback_prefix = base_url - self._fallback_scheme = scheme - self._fallback_host = host - - def __call__(self, environ, start_response): - # determine prefix - prefix = environ.get(self._header_prefix, "") - if not prefix: - prefix = self._fallback_prefix - - # rewrite SCRIPT_NAME and if necessary also PATH_INFO based on prefix - if prefix: - environ["SCRIPT_NAME"] = prefix - path_info = environ["PATH_INFO"] - if path_info.startswith(prefix): - environ["PATH_INFO"] = path_info[len(prefix):] - - # determine scheme - scheme = environ.get(self._header_scheme, "") - if scheme and "," in scheme: - # Scheme might be something like "https,https" if doubly-reverse-proxied - # without stripping original scheme header first, make sure to only use - # the first entry in such a case. See #1391. - scheme, _ = map(lambda x: x.strip(), scheme.split(",", 1)) - if not scheme: - scheme = self._fallback_scheme - - # rewrite wsgi.url_scheme based on scheme - if scheme: - environ["wsgi.url_scheme"] = scheme - - # determine host - host = environ.get(self._header_host, "") - if not host: - host = self._fallback_host - - # rewrite host header based on host - if host: - environ["HTTP_HOST"] = host - - # call wrapped app with rewritten environment - return self.app(environ, start_response) - diff --git a/src/octoprint/server/util/flask.py b/src/octoprint/server/util/flask.py index 087f424d..fe48c7f7 100644 --- a/src/octoprint/server/util/flask.py +++ b/src/octoprint/server/util/flask.py @@ -219,6 +219,193 @@ def fix_webassets_filtertool(): FilterTool._wrap_cache = fixed_wrap_cache +#~~ WSGI environment wrapper for reverse proxying + +class ReverseProxiedEnvironment(object): + + @staticmethod + def to_header_candidates(values): + if values is None: + return [] + if not isinstance(values, (list, tuple)): + values = [values] + to_wsgi_format = lambda header: "HTTP_" + header.upper().replace("-", "_") + return map(to_wsgi_format, values) + + def __init__(self, + header_prefix=None, + header_scheme=None, + header_host=None, + header_server=None, + header_port=None, + prefix=None, + scheme=None, + host=None, + server=None, + port=None): + + # sensible defaults + if header_prefix is None: + header_prefix = ["x-script-name"] + if header_scheme is None: + header_scheme = ["x-forwarded-proto", "x-scheme"] + if header_host is None: + header_host = ["x-forwarded-host"] + if header_server is None: + header_server = ["x-forwarded-server"] + if header_port is None: + header_port = ["x-forwarded-port"] + + # header candidates + self._headers_prefix = self.to_header_candidates(header_prefix) + self._headers_scheme = self.to_header_candidates(header_scheme) + self._headers_host = self.to_header_candidates(header_host) + self._headers_server = self.to_header_candidates(header_server) + self._headers_port = self.to_header_candidates(header_port) + + # fallback prefix & scheme & host from config + self._fallback_prefix = prefix + self._fallback_scheme = scheme + self._fallback_host = host + self._fallback_server = server + self._fallback_port = port + + def __call__(self, environ): + def retrieve_header(header_type): + candidates = getattr(self, "_headers_" + header_type, []) + fallback = getattr(self, "_fallback_" + header_type, None) + + for candidate in candidates: + value = environ.get(candidate, None) + if value is not None: + return value + else: + return fallback + + def host_to_server_and_port(host, scheme): + if host is None: + return None, None + + if ":" in host: + server, port = host.split(":", 1) + else: + server = host + port = "443" if scheme == "https" else "80" + + return server, port + + # determine prefix + prefix = retrieve_header("prefix") + if prefix is not None: + environ["SCRIPT_NAME"] = prefix + path_info = environ["PATH_INFO"] + if path_info.startswith(prefix): + environ["PATH_INFO"] = path_info[len(prefix):] + + # determine scheme + scheme = retrieve_header("scheme") + if scheme is not None and "," in scheme: + # Scheme might be something like "https,https" if doubly-reverse-proxied + # without stripping original scheme header first, make sure to only use + # the first entry in such a case. See #1391. + scheme, _ = map(lambda x: x.strip(), scheme.split(",", 1)) + if scheme is not None: + environ["wsgi.url_scheme"] = scheme + + # determine host + url_scheme = environ["wsgi.url_scheme"] + host = retrieve_header("host") + if host is not None: + # if we have a host, we take server_name and server_port from it + server, port = host_to_server_and_port(host, url_scheme) + environ["HTTP_HOST"] = host + environ["SERVER_NAME"] = server + environ["SERVER_PORT"] = port + else: + # else we take a look at the server and port headers and if we have + # something there we derive the host from it + + # determine server - should usually not be used + server = retrieve_header("server") + if server is not None: + environ["SERVER_NAME"] = server + + # determine port - should usually not be used + port = retrieve_header("port") + if port is not None: + environ["SERVER_PORT"] = port + + # make sure HTTP_HOST matches SERVER_NAME and SERVER_PORT + expected_server, expected_port = host_to_server_and_port(environ.get("HTTP_HOST", None), url_scheme) + if expected_server != environ["SERVER_NAME"] or expected_port != environ["SERVER_PORT"]: + # there's a difference, fix it! + if url_scheme == "http" and environ["SERVER_PORT"] == "80" or url_scheme == "https" and environ["SERVER_PORT"] == "443": + # default port for scheme, can be skipped + environ["HTTP_HOST"] = environ["SERVER_NAME"] + else: + environ["HTTP_HOST"] = environ["SERVER_NAME"] + ":" + environ["SERVER_PORT"] + + # call wrapped app with rewritten environment + return environ + +#~~ request and response versions + +from werkzeug.wrappers import cached_property + +class OctoPrintFlaskRequest(flask.Request): + environment_wrapper = staticmethod(lambda x: x) + + def __init__(self, environ, *args, **kwargs): + # apply environment wrapper to provided WSGI environment + flask.Request.__init__(self, self.environment_wrapper(environ), *args, **kwargs) + + @cached_property + def cookies(self): + # strip cookie_suffix from all cookies in the request, return result + cookies = flask.Request.cookies.__get__(self) + + def cookie_name_converter(key): + return key[:-len(self.cookie_suffix)] if key.endswith(self.cookie_suffix) else key + + return dict((cookie_name_converter(key), value) for key, value in cookies.items()) + + @cached_property + def server_name(self): + """Short cut to the request's server name header""" + return self.environ.get("SERVER_NAME") + + @cached_property + def server_port(self): + """Short cut to the request's server port header""" + return self.environ.get("SERVER_PORT") + + @cached_property + def cookie_suffix(self): + """ + Request specific suffix for set and read cookies + + We need this because cookies are not port-specific and we don't want to overwrite our + session and other cookies from one OctoPrint instance on our machine with those of another + one who happens to listen on the same address albeit a different port. + """ + return "_P" + self.server_port + + +class OctoPrintFlaskResponse(flask.Response): + def set_cookie(self, key, *args, **kwargs): + # restrict cookie path to script root + kwargs["path"] = flask.request.script_root + kwargs.get("path", "/") + + # add request specific cookie suffix to name + flask.Response.set_cookie(self, key + flask.request.cookie_suffix, *args, **kwargs) + + def delete_cookie(self, key, *args, **kwargs): + # restrict cookie path to script root + kwargs["path"] = flask.request.script_root + kwargs.get("path", "/") + + # add request specific cookie suffix to name + flask.Response.delete_cookie(self, key + flask.request.cookie_suffix, *args, **kwargs) + #~~ passive login helper def passive_login(): diff --git a/src/octoprint/settings.py b/src/octoprint/settings.py index c99fba5f..63c5cab5 100644 --- a/src/octoprint/settings.py +++ b/src/octoprint/settings.py @@ -108,12 +108,16 @@ default_settings = { "firstRun": True, "secretKey": None, "reverseProxy": { - "prefixHeader": "X-Script-Name", - "schemeHeader": "X-Scheme", - "hostHeader": "X-Forwarded-Host", - "prefixFallback": "", - "schemeFallback": "", - "hostFallback": "" + "prefixHeader": None, + "schemeHeader": None, + "hostHeader": None, + "serverHeader": None, + "portHeader": None, + "prefixFallback": None, + "schemeFallback": None, + "hostFallback": None, + "serverFallback": None, + "portFallback": None }, "uploads": { "maxSize": 1 * 1024 * 1024 * 1024, # 1GB diff --git a/tests/server/__init__.py b/tests/server/__init__.py new file mode 100644 index 00000000..1010c694 --- /dev/null +++ b/tests/server/__init__.py @@ -0,0 +1,11 @@ +# coding=utf-8 +""" +Unit tests for ``octoprint.server``. +""" + +from __future__ import absolute_import + +__author__ = "Gina Häußge " +__license__ = 'GNU Affero General Public License http://www.gnu.org/licenses/agpl.html' +__copyright__ = "Copyright (C) 2016 The OctoPrint Project - Released under terms of the AGPLv3 License" + diff --git a/tests/server/util/__init__.py b/tests/server/util/__init__.py new file mode 100644 index 00000000..be5059e3 --- /dev/null +++ b/tests/server/util/__init__.py @@ -0,0 +1,10 @@ +# coding=utf-8 +""" +Unit tests for ``octoprint.server.util``. +""" + +from __future__ import absolute_import + +__author__ = "Gina Häußge " +__license__ = 'GNU Affero General Public License http://www.gnu.org/licenses/agpl.html' +__copyright__ = "Copyright (C) 2016 The OctoPrint Project - Released under terms of the AGPLv3 License" diff --git a/tests/server/util/flask.py b/tests/server/util/flask.py new file mode 100644 index 00000000..e18e2dfb --- /dev/null +++ b/tests/server/util/flask.py @@ -0,0 +1,411 @@ +# coding=utf-8 +""" +Unit tests for ``octoprint.server.util.flask``. +""" + +from __future__ import absolute_import + +__author__ = "Gina Häußge " +__license__ = 'GNU Affero General Public License http://www.gnu.org/licenses/agpl.html' +__copyright__ = "Copyright (C) 2016 The OctoPrint Project - Released under terms of the AGPLv3 License" + + +import unittest +import mock +from ddt import ddt, data, unpack + +from octoprint.server.util.flask import ReverseProxiedEnvironment, OctoPrintFlaskRequest, OctoPrintFlaskResponse + +standard_environ = { + "HTTP_HOST": "localhost:5000", + "SERVER_NAME": "localhost", + "SERVER_PORT": "5000", + "SCRIPT_NAME": "", + "PATH_INFO": "/", + "wsgi.url_scheme": "http" +} + +@ddt +class ReverseProxiedEnvironmentTest(unittest.TestCase): + + @data( + # defaults + ({}, + {}), + + # prefix set, path info not prefixed + ({ + "HTTP_X_SCRIPT_NAME": "/octoprint", + "PATH_INFO": "/static/online.gif" + }, { + "SCRIPT_NAME": "/octoprint" + }), + + # prefix set, path info prefixed + ({ + "HTTP_X_SCRIPT_NAME": "/octoprint", + "PATH_INFO": "/octoprint/static/online.gif", + }, { + "SCRIPT_NAME": "/octoprint", + "PATH_INFO": "/static/online.gif" + }), + + # host set + ({ + "HTTP_X_FORWARDED_HOST": "example.com" + }, { + "HTTP_HOST": "example.com", + "SERVER_NAME": "example.com", + "SERVER_PORT": "80" + }), + + # host set with port + ({ + "HTTP_X_FORWARDED_HOST": "example.com:1234" + }, { + "HTTP_HOST": "example.com:1234", + "SERVER_NAME": "example.com", + "SERVER_PORT": "1234" + }), + + # host and scheme set + ({ + "HTTP_X_FORWARDED_HOST": "example.com", + "HTTP_X_FORWARDED_PROTO": "https" + }, { + "HTTP_HOST": "example.com", + "SERVER_NAME": "example.com", + "SERVER_PORT": "443", + "wsgi.url_scheme": "https" + }), + + # host and scheme 2 set + ({ + "HTTP_X_FORWARDED_HOST": "example.com", + "HTTP_X_SCHEME": "https" + }, { + "HTTP_HOST": "example.com", + "SERVER_NAME": "example.com", + "SERVER_PORT": "443", + "wsgi.url_scheme": "https" + }), + + # host, server and port headers set -> only host wins + ({ + "HTTP_X_FORWARDED_HOST": "example.com", + "HTTP_X_FORWARDED_SERVER": "example2.com", + "HTTP_X_FORWARDED_PORT": "444", + "HTTP_X_FORWARDED_PROTO": "https" + }, { + "HTTP_HOST": "example.com", + "SERVER_NAME": "example.com", + "SERVER_PORT": "443", + "wsgi.url_scheme": "https" + }), + + # server and port headers set -> host derived with port + ({ + "HTTP_X_FORWARDED_SERVER": "example2.com", + "HTTP_X_FORWARDED_PORT": "444", + "HTTP_X_FORWARDED_PROTO": "https" + }, { + "HTTP_HOST": "example2.com:444", + "SERVER_NAME": "example2.com", + "SERVER_PORT": "444", + "wsgi.url_scheme": "https" + }), + + # server and port headers set, standard port -> host derived, no port + ({ + "HTTP_X_FORWARDED_SERVER": "example.com", + "HTTP_X_FORWARDED_PORT": "80", + }, { + "HTTP_HOST": "example.com", + "SERVER_NAME": "example.com", + "SERVER_PORT": "80", + }), + + # multiple scheme entries -> only use first one + ({ + "HTTP_X_FORWARDED_PROTO": "https,http", + }, { + "wsgi.url_scheme": "https" + }), + + # host = none -> should never happen but you never know... + ({ + "HTTP_HOST": None, + "HTTP_X_FORWARDED_SERVER": "example.com", + "HTTP_X_FORWARDED_PORT": "80" + }, { + "HTTP_HOST": "example.com", + "SERVER_NAME": "example.com", + "SERVER_PORT": "80" + }) + ) + @unpack + def test_stock(self, environ, expected): + reverse_proxied = ReverseProxiedEnvironment() + + merged_environ = dict(standard_environ) + merged_environ.update(environ) + + actual = reverse_proxied(merged_environ) + + merged_expected = dict(standard_environ) + merged_expected.update(environ) + merged_expected.update(expected) + + self.assertDictEqual(merged_expected, actual) + + @data( + # prefix overridden + ({ + "prefix": "fallback_prefix" + }, { + }, { + "SCRIPT_NAME": "fallback_prefix", + }), + + # scheme overridden + ({ + "scheme": "https" + }, { + }, { + "wsgi.url_scheme": "https" + }), + + # host overridden, default port + ({ + "host": "example.com" + }, { + }, { + "HTTP_HOST": "example.com", + "SERVER_NAME": "example.com", + "SERVER_PORT": "80" + }), + + # host overridden, included port + ({ + "host": "example.com:81" + }, { + }, { + "HTTP_HOST": "example.com:81", + "SERVER_NAME": "example.com", + "SERVER_PORT": "81" + }), + + # server overridden + ({ + "server": "example.com" + }, { + }, { + "HTTP_HOST": "example.com:5000", + "SERVER_NAME": "example.com", + "SERVER_PORT": "5000" + }), + + # port overridden, standard port + ({ + "port": "80" + }, { + }, { + "HTTP_HOST": "localhost", + "SERVER_PORT": "80" + }), + + # port overridden, non standard port + ({ + "port": "81" + }, { + }, { + "HTTP_HOST": "localhost:81", + "SERVER_PORT": "81" + }), + + # server and port overridden, default port + ({ + "server": "example.com", + "port": "80" + }, { + }, { + "HTTP_HOST": "example.com", + "SERVER_NAME": "example.com", + "SERVER_PORT": "80" + }), + + # server and port overridden, non default port + ({ + "server": "example.com", + "port": "81" + }, { + }, { + "HTTP_HOST": "example.com:81", + "SERVER_NAME": "example.com", + "SERVER_PORT": "81" + }), + + # prefix not really overridden + ({ + "prefix": "/octoprint" + }, { + "HTTP_X_SCRIPT_NAME": "" + }, { + }), + + # scheme not really overridden + ({ + "scheme": "https" + }, { + "HTTP_X_FORWARDED_PROTO": "http" + }, { + }), + + # scheme 2 not really overridden + ({ + "scheme": "https" + }, { + "HTTP_X_SCHEME": "http" + }, { + }), + + # host not really overridden + ({ + "host": "example.com:444" + }, { + "HTTP_X_FORWARDED_HOST": "localhost:5000" + }, { + }), + + # server not really overridden + ({ + "server": "example.com" + }, { + "HTTP_X_FORWARDED_SERVER": "localhost" + }, { + }), + + # port not really overridden + ({ + "port": "444" + }, { + "HTTP_X_FORWARDED_PORT": "5000" + }, { + }) + ) + @unpack + def test_fallbacks(self, fallbacks, environ, expected): + reverse_proxied = ReverseProxiedEnvironment(**fallbacks) + + merged_environ = dict(standard_environ) + merged_environ.update(environ) + + actual = reverse_proxied(merged_environ) + + merged_expected = dict(standard_environ) + merged_expected.update(environ) + merged_expected.update(expected) + + self.assertDictEqual(merged_expected, actual) + + def test_header_config_ok(self): + result = ReverseProxiedEnvironment.to_header_candidates(["prefix-header1", "prefix-header2"]) + self.assertEquals(result, ["HTTP_PREFIX_HEADER1", "HTTP_PREFIX_HEADER2"]) + + def test_header_config_string(self): + result = ReverseProxiedEnvironment.to_header_candidates("prefix-header") + self.assertEquals(result, ["HTTP_PREFIX_HEADER"]) + + def test_header_config_none(self): + result = ReverseProxiedEnvironment.to_header_candidates(None) + self.assertEquals(result, []) + +##~~ + +class OctoPrintFlaskRequestTest(unittest.TestCase): + + def setUp(self): + self.orig_environment_wrapper = OctoPrintFlaskRequest.environment_wrapper + + def tearDown(self): + OctoPrintFlaskRequest.environment_wrapper = staticmethod(self.orig_environment_wrapper) + + def test_environment_wrapper(self): + def environment_wrapper(environ): + environ.update({ + "TEST": "yes" + }) + return environ + + OctoPrintFlaskRequest.environment_wrapper = staticmethod(environment_wrapper) + request = OctoPrintFlaskRequest(standard_environ) + + self.assertTrue("TEST" in request.environ) + + def test_server_name(self): + request = OctoPrintFlaskRequest(standard_environ) + self.assertEquals(request.server_name, "localhost") + + def test_server_port(self): + request = OctoPrintFlaskRequest(standard_environ) + self.assertEquals(request.server_port, "5000") + + def test_cookie_suffix(self): + request = OctoPrintFlaskRequest(standard_environ) + self.assertEquals(request.cookie_suffix, "_P5000") + + def test_cookies(self): + environ = dict(standard_environ) + environ["HTTP_COOKIE"] = "postfixed_P5000=postfixed_value; " \ + "postfixed_wrong_P5001=postfixed_wrong_value; " \ + "unpostfixed=unpostfixed_value" + + request = OctoPrintFlaskRequest(environ) + + cookies = request.cookies + self.assertDictEqual(cookies, {"postfixed": "postfixed_value", + "postfixed_wrong_P5001": "postfixed_wrong_value", + "unpostfixed": "unpostfixed_value"}) + +##~~ + +@ddt +class OctoPrintFlaskResponseTest(unittest.TestCase): + + @data([None, None], + ["/subfolder/", None], + [None, "/some/other/script/root"], + ["/subfolder/", "/some/other/script/root"]) + @unpack + def test_cookie_set_and_delete(self, path, scriptroot): + environ = dict(standard_environ) + + if scriptroot is not None: + environ.update(dict(SCRIPT_NAME=scriptroot)) + + request = OctoPrintFlaskRequest(environ) + + if path: + expected_path = path + else: + expected_path = "/" + if scriptroot: + expected_path = scriptroot + expected_path + + if path is not None: + kwargs = dict(path=path) + else: + kwargs = dict() + + with mock.patch("flask.request", new=request): + response = OctoPrintFlaskResponse() + + # test set_cookie + with mock.patch("flask.Response.set_cookie") as set_cookie_mock: + response.set_cookie("some_key", "some_value", **kwargs) + set_cookie_mock.assert_called_once_with(response, "some_key_P5000", "some_value", path=expected_path) + + # test delete_cookie + with mock.patch("flask.Response.delete_cookie") as delete_cookie_mock: + response.delete_cookie("some_key", "some_value", **kwargs) + delete_cookie_mock.assert_called_once_with(response, "some_key_P5000", "some_value", path=expected_path)