Cookie names are now port specific, based on the request associated with a response

* make sure server_port headers are properly set in reverse proxied scenarios
  * overwrite request and response classes to
    * always apply reverse proxy environment changes (so far missing for tornado
      context)
    * strip cookie name suffixes from cookie names on requests and
    * be sure to set cookie name suffixes for cookie names on responses
    * include script root in path used for cookies
  * some minor refactoring in octoprint.server setup routines
  * removed ReverseProxied class (didn't work for tornado context)
  * add unit tests for the whole reverse proxy, request and response customization
This commit is contained in:
Gina Häußge 2016-09-05 12:06:56 +02:00
parent ce28b637ab
commit 9d9eb3390d
7 changed files with 695 additions and 132 deletions

View file

@ -7,7 +7,7 @@ __copyright__ = "Copyright (C) 2014 The OctoPrint Project - Released under terms
import uuid
from sockjs.tornado import SockJSRouter
from flask import Flask, g, request, session, Blueprint
from flask import Flask, g, request, session, Blueprint, Request, Response
from flask.ext.login import LoginManager, current_user
from flask.ext.principal import Principal, Permission, RoleNeed, identity_loaded, UserNeed
from flask.ext.babel import Babel, gettext, ngettext
@ -168,7 +168,7 @@ class Server(object):
util.flask.enable_additional_translations(additional_folders=[s.getBaseFolder("translations")])
# setup app
self._setup_app()
self._setup_app(app)
# setup i18n
self._setup_i18n(app)
@ -199,9 +199,15 @@ class Server(object):
pluginLifecycleManager = LifecycleManager(pluginManager)
preemptiveCache = PreemptiveCache(os.path.join(s.getBaseFolder("data"), "preemptive_cache_config.yaml"))
# ... and initialize all plugins
def octoprint_plugin_inject_factory(name, implementation):
"""Factory for injections for all OctoPrintPlugins"""
if not isinstance(implementation, octoprint.plugin.OctoPrintPlugin):
# we only care about OctoPrintPlugins
return None
return dict(
plugin_manager=pluginManager,
printer_profile_manager=printerProfileManager,
@ -217,8 +223,13 @@ class Server(object):
)
def settings_plugin_inject_factory(name, implementation):
"""Factory for additional injections depending on plugin type"""
if not isinstance(implementation, octoprint.plugin.SettingsPlugin):
# we only care about SettingsPlugins
return None
# SettingsPlugin instnances get a PluginSettings instance injected
default_settings = implementation.get_settings_defaults()
get_preprocessors, set_preprocessors = implementation.get_settings_preprocessors()
plugin_settings = octoprint.plugin.plugin_settings(name,
@ -228,6 +239,8 @@ class Server(object):
return dict(settings=plugin_settings)
def settings_plugin_config_migration_and_cleanup(name, implementation):
"""Take care of migrating and cleaning up any old settings"""
if not isinstance(implementation, octoprint.plugin.SettingsPlugin):
return
@ -269,6 +282,8 @@ class Server(object):
# setup jinja2
self._setup_jinja2()
# make sure plugin lifecycle events relevant for jinja2 are taken care of
def template_enabled(name, plugin):
if plugin.implementation is None or not isinstance(plugin.implementation, octoprint.plugin.TemplatePlugin):
return
@ -296,31 +311,12 @@ class Server(object):
try:
clazz = octoprint.util.get_class(userManagerName)
userManager = clazz()
except AttributeError, e:
except AttributeError as e:
self._logger.exception("Could not instantiate user manager {}, falling back to FilebasedUserManager!".format(userManagerName))
userManager = octoprint.users.FilebasedUserManager()
finally:
userManager.enabled = s.getBoolean(["accessControl", "enabled"])
app.wsgi_app = util.ReverseProxied(
app.wsgi_app,
s.get(["server", "reverseProxy", "prefixHeader"]),
s.get(["server", "reverseProxy", "schemeHeader"]),
s.get(["server", "reverseProxy", "hostHeader"]),
s.get(["server", "reverseProxy", "prefixFallback"]),
s.get(["server", "reverseProxy", "schemeFallback"]),
s.get(["server", "reverseProxy", "hostFallback"])
)
secret_key = s.get(["server", "secretKey"])
if not secret_key:
import string
from random import choice
chars = string.ascii_lowercase + string.ascii_uppercase + string.digits
secret_key = "".join(choice(chars) for _ in xrange(32))
s.set(["server", "secretKey"], secret_key)
s.save()
app.secret_key = secret_key
loginManager = LoginManager()
loginManager.session_protection = "strong"
loginManager.user_callback = load_user
@ -329,18 +325,16 @@ class Server(object):
principals.identity_loaders.appendleft(users.dummy_identity_loader)
loginManager.init_app(app)
if self._host is None:
self._host = s.get(["server", "host"])
if self._port is None:
self._port = s.getInt(["server", "port"])
app.debug = self._debug
# register API blueprint
self._setup_blueprints()
## Tornado initialization starts here
if self._host is None:
self._host = s.get(["server", "host"])
if self._port is None:
self._port = s.getInt(["server", "port"])
ioloop = IOLoop()
ioloop.install()
@ -382,6 +376,8 @@ class Server(object):
(r"/online.txt", util.tornado.StaticDataHandler, dict(data="online\n")),
(r"/online.gif", util.tornado.StaticDataHandler, dict(data=bytes(base64.b64decode("R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7")), content_type="image/gif"))
]
# fetch additional routes from plugins
for name, hook in pluginManager.get_hooks("octoprint.server.http.routes").items():
try:
result = hook(list(server_routes))
@ -435,10 +431,13 @@ class Server(object):
self._stop_intermediary_server()
# initialize and bind the server
self._server = util.tornado.CustomHTTPServer(self._tornado_app, max_body_sizes=max_body_sizes, default_max_body_size=s.getInt(["server", "maxSize"]))
self._server.listen(self._port, address=self._host)
eventManager.fire(events.Events.STARTUP)
# auto connect
if s.getBoolean(["serial", "autoconnect"]):
(port, baudrate) = s.get(["serial", "port"]), s.getInt(["serial", "baudrate"])
printer_profile = printerProfileManager.get_default()
@ -531,7 +530,8 @@ class Server(object):
def _create_socket_connection(self, session):
global printer, fileManager, analysisQueue, userManager, eventManager
return util.sockjs.PrinterStateConnection(printer, fileManager, analysisQueue, userManager, eventManager, pluginManager, session)
return util.sockjs.PrinterStateConnection(printer, fileManager, analysisQueue, userManager,
eventManager, pluginManager, session)
def _check_for_root(self):
if "geteuid" in dir(os) and os.geteuid() == 0:
@ -635,7 +635,41 @@ class Server(object):
logging.getLogger("SERIAL").setLevel(logging.DEBUG)
logging.getLogger("SERIAL").debug("Enabling serial logging")
def _setup_app(self):
def _setup_app(self, app):
from octoprint.server.util.flask import ReverseProxiedEnvironment, OctoPrintFlaskRequest, OctoPrintFlaskResponse
s = settings()
app.debug = self._debug
secret_key = s.get(["server", "secretKey"])
if not secret_key:
import string
from random import choice
chars = string.ascii_lowercase + string.ascii_uppercase + string.digits
secret_key = "".join(choice(chars) for _ in range(32))
s.set(["server", "secretKey"], secret_key)
s.save()
app.secret_key = secret_key
reverse_proxied = ReverseProxiedEnvironment(
header_prefix=s.get(["server", "reverseProxy", "prefixHeader"]),
header_scheme=s.get(["server", "reverseProxy", "schemeHeader"]),
header_host=s.get(["server", "reverseProxy", "hostHeader"]),
header_server=s.get(["server", "reverseProxy", "serverHeader"]),
header_port=s.get(["server", "reverseProxy", "portHeader"]),
prefix=s.get(["server", "reverseProxy", "prefixFallback"]),
scheme=s.get(["server", "reverseProxy", "schemeFallback"]),
host=s.get(["server", "reverseProxy", "hostFallback"]),
server=s.get(["server", "reverseProxy", "serverFallback"]),
port=s.get(["server", "reverseProxy", "portFallback"])
)
OctoPrintFlaskRequest.environment_wrapper = reverse_proxied
app.request_class = OctoPrintFlaskRequest
app.response_class = OctoPrintFlaskResponse
@app.before_request
def before_request():
g.locale = self._get_locale()
@ -932,7 +966,7 @@ class Server(object):
# that might be caused by the user still having the folder open somewhere, let's try again after
# waiting a bit
import time
for n in xrange(3):
for n in range(3):
time.sleep(0.5)
self._logger.debug("Creating {path}: Retry #{retry} after {time}s".format(path=path, retry=n+1, time=(n + 1)*0.5))
try:

View file

@ -155,97 +155,3 @@ def get_plugin_hash():
plugin_hash = hashlib.sha1()
plugin_hash.update(",".join(ui_plugins))
return plugin_hash.hexdigest()
#~~ reverse proxy compatible WSGI middleware
class ReverseProxied(object):
"""
Wrap the application in this middleware and configure the
front-end server to add these headers, to let you quietly bind
this to a URL other than / and to an HTTP scheme that is
different than what is used locally.
In nginx:
.. code-block:: none
location /myprefix {
proxy_pass http://192.168.0.1:5001;
proxy_set_header Host $host;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Scheme $scheme;
proxy_set_header X-Script-Name /myprefix;
}
Alternatively define prefix and scheme via config.yaml:
.. code-block:: yaml
server:
baseUrl: /myprefix
scheme: http
:param app: the WSGI application
:param header_script_name: the HTTP header in the wsgi environment from which to determine the prefix
:param header_scheme: the HTTP header in the wsgi environment from which to determine the scheme
:param header_host: the HTTP header in the wsgi environment from which to determine the host for which to generate external URLs
:param base_url: the prefix to use as fallback if headers are not set
:param scheme: the scheme to use as fallback if headers are not set
:param host: the host to use as fallback if headers are not set
"""
def __init__(self, app, header_prefix="x-script-name", header_scheme="x-scheme", header_host="x-forwarded-host", base_url="", scheme="", host=""):
self.app = app
# headers for prefix & scheme & host, converted to conform to WSGI format
to_wsgi_format = lambda header: "HTTP_" + header.upper().replace("-", "_")
self._header_prefix = to_wsgi_format(header_prefix)
self._header_scheme = to_wsgi_format(header_scheme)
self._header_host = to_wsgi_format(header_host)
# fallback prefix & scheme & host from config
self._fallback_prefix = base_url
self._fallback_scheme = scheme
self._fallback_host = host
def __call__(self, environ, start_response):
# determine prefix
prefix = environ.get(self._header_prefix, "")
if not prefix:
prefix = self._fallback_prefix
# rewrite SCRIPT_NAME and if necessary also PATH_INFO based on prefix
if prefix:
environ["SCRIPT_NAME"] = prefix
path_info = environ["PATH_INFO"]
if path_info.startswith(prefix):
environ["PATH_INFO"] = path_info[len(prefix):]
# determine scheme
scheme = environ.get(self._header_scheme, "")
if scheme and "," in scheme:
# Scheme might be something like "https,https" if doubly-reverse-proxied
# without stripping original scheme header first, make sure to only use
# the first entry in such a case. See #1391.
scheme, _ = map(lambda x: x.strip(), scheme.split(",", 1))
if not scheme:
scheme = self._fallback_scheme
# rewrite wsgi.url_scheme based on scheme
if scheme:
environ["wsgi.url_scheme"] = scheme
# determine host
host = environ.get(self._header_host, "")
if not host:
host = self._fallback_host
# rewrite host header based on host
if host:
environ["HTTP_HOST"] = host
# call wrapped app with rewritten environment
return self.app(environ, start_response)

View file

@ -219,6 +219,193 @@ def fix_webassets_filtertool():
FilterTool._wrap_cache = fixed_wrap_cache
#~~ WSGI environment wrapper for reverse proxying
class ReverseProxiedEnvironment(object):
@staticmethod
def to_header_candidates(values):
if values is None:
return []
if not isinstance(values, (list, tuple)):
values = [values]
to_wsgi_format = lambda header: "HTTP_" + header.upper().replace("-", "_")
return map(to_wsgi_format, values)
def __init__(self,
header_prefix=None,
header_scheme=None,
header_host=None,
header_server=None,
header_port=None,
prefix=None,
scheme=None,
host=None,
server=None,
port=None):
# sensible defaults
if header_prefix is None:
header_prefix = ["x-script-name"]
if header_scheme is None:
header_scheme = ["x-forwarded-proto", "x-scheme"]
if header_host is None:
header_host = ["x-forwarded-host"]
if header_server is None:
header_server = ["x-forwarded-server"]
if header_port is None:
header_port = ["x-forwarded-port"]
# header candidates
self._headers_prefix = self.to_header_candidates(header_prefix)
self._headers_scheme = self.to_header_candidates(header_scheme)
self._headers_host = self.to_header_candidates(header_host)
self._headers_server = self.to_header_candidates(header_server)
self._headers_port = self.to_header_candidates(header_port)
# fallback prefix & scheme & host from config
self._fallback_prefix = prefix
self._fallback_scheme = scheme
self._fallback_host = host
self._fallback_server = server
self._fallback_port = port
def __call__(self, environ):
def retrieve_header(header_type):
candidates = getattr(self, "_headers_" + header_type, [])
fallback = getattr(self, "_fallback_" + header_type, None)
for candidate in candidates:
value = environ.get(candidate, None)
if value is not None:
return value
else:
return fallback
def host_to_server_and_port(host, scheme):
if host is None:
return None, None
if ":" in host:
server, port = host.split(":", 1)
else:
server = host
port = "443" if scheme == "https" else "80"
return server, port
# determine prefix
prefix = retrieve_header("prefix")
if prefix is not None:
environ["SCRIPT_NAME"] = prefix
path_info = environ["PATH_INFO"]
if path_info.startswith(prefix):
environ["PATH_INFO"] = path_info[len(prefix):]
# determine scheme
scheme = retrieve_header("scheme")
if scheme is not None and "," in scheme:
# Scheme might be something like "https,https" if doubly-reverse-proxied
# without stripping original scheme header first, make sure to only use
# the first entry in such a case. See #1391.
scheme, _ = map(lambda x: x.strip(), scheme.split(",", 1))
if scheme is not None:
environ["wsgi.url_scheme"] = scheme
# determine host
url_scheme = environ["wsgi.url_scheme"]
host = retrieve_header("host")
if host is not None:
# if we have a host, we take server_name and server_port from it
server, port = host_to_server_and_port(host, url_scheme)
environ["HTTP_HOST"] = host
environ["SERVER_NAME"] = server
environ["SERVER_PORT"] = port
else:
# else we take a look at the server and port headers and if we have
# something there we derive the host from it
# determine server - should usually not be used
server = retrieve_header("server")
if server is not None:
environ["SERVER_NAME"] = server
# determine port - should usually not be used
port = retrieve_header("port")
if port is not None:
environ["SERVER_PORT"] = port
# make sure HTTP_HOST matches SERVER_NAME and SERVER_PORT
expected_server, expected_port = host_to_server_and_port(environ.get("HTTP_HOST", None), url_scheme)
if expected_server != environ["SERVER_NAME"] or expected_port != environ["SERVER_PORT"]:
# there's a difference, fix it!
if url_scheme == "http" and environ["SERVER_PORT"] == "80" or url_scheme == "https" and environ["SERVER_PORT"] == "443":
# default port for scheme, can be skipped
environ["HTTP_HOST"] = environ["SERVER_NAME"]
else:
environ["HTTP_HOST"] = environ["SERVER_NAME"] + ":" + environ["SERVER_PORT"]
# call wrapped app with rewritten environment
return environ
#~~ request and response versions
from werkzeug.wrappers import cached_property
class OctoPrintFlaskRequest(flask.Request):
environment_wrapper = staticmethod(lambda x: x)
def __init__(self, environ, *args, **kwargs):
# apply environment wrapper to provided WSGI environment
flask.Request.__init__(self, self.environment_wrapper(environ), *args, **kwargs)
@cached_property
def cookies(self):
# strip cookie_suffix from all cookies in the request, return result
cookies = flask.Request.cookies.__get__(self)
def cookie_name_converter(key):
return key[:-len(self.cookie_suffix)] if key.endswith(self.cookie_suffix) else key
return dict((cookie_name_converter(key), value) for key, value in cookies.items())
@cached_property
def server_name(self):
"""Short cut to the request's server name header"""
return self.environ.get("SERVER_NAME")
@cached_property
def server_port(self):
"""Short cut to the request's server port header"""
return self.environ.get("SERVER_PORT")
@cached_property
def cookie_suffix(self):
"""
Request specific suffix for set and read cookies
We need this because cookies are not port-specific and we don't want to overwrite our
session and other cookies from one OctoPrint instance on our machine with those of another
one who happens to listen on the same address albeit a different port.
"""
return "_P" + self.server_port
class OctoPrintFlaskResponse(flask.Response):
def set_cookie(self, key, *args, **kwargs):
# restrict cookie path to script root
kwargs["path"] = flask.request.script_root + kwargs.get("path", "/")
# add request specific cookie suffix to name
flask.Response.set_cookie(self, key + flask.request.cookie_suffix, *args, **kwargs)
def delete_cookie(self, key, *args, **kwargs):
# restrict cookie path to script root
kwargs["path"] = flask.request.script_root + kwargs.get("path", "/")
# add request specific cookie suffix to name
flask.Response.delete_cookie(self, key + flask.request.cookie_suffix, *args, **kwargs)
#~~ passive login helper
def passive_login():

View file

@ -108,12 +108,16 @@ default_settings = {
"firstRun": True,
"secretKey": None,
"reverseProxy": {
"prefixHeader": "X-Script-Name",
"schemeHeader": "X-Scheme",
"hostHeader": "X-Forwarded-Host",
"prefixFallback": "",
"schemeFallback": "",
"hostFallback": ""
"prefixHeader": None,
"schemeHeader": None,
"hostHeader": None,
"serverHeader": None,
"portHeader": None,
"prefixFallback": None,
"schemeFallback": None,
"hostFallback": None,
"serverFallback": None,
"portFallback": None
},
"uploads": {
"maxSize": 1 * 1024 * 1024 * 1024, # 1GB

11
tests/server/__init__.py Normal file
View file

@ -0,0 +1,11 @@
# coding=utf-8
"""
Unit tests for ``octoprint.server``.
"""
from __future__ import absolute_import
__author__ = "Gina Häußge <osd@foosel.net>"
__license__ = 'GNU Affero General Public License http://www.gnu.org/licenses/agpl.html'
__copyright__ = "Copyright (C) 2016 The OctoPrint Project - Released under terms of the AGPLv3 License"

View file

@ -0,0 +1,10 @@
# coding=utf-8
"""
Unit tests for ``octoprint.server.util``.
"""
from __future__ import absolute_import
__author__ = "Gina Häußge <osd@foosel.net>"
__license__ = 'GNU Affero General Public License http://www.gnu.org/licenses/agpl.html'
__copyright__ = "Copyright (C) 2016 The OctoPrint Project - Released under terms of the AGPLv3 License"

411
tests/server/util/flask.py Normal file
View file

@ -0,0 +1,411 @@
# coding=utf-8
"""
Unit tests for ``octoprint.server.util.flask``.
"""
from __future__ import absolute_import
__author__ = "Gina Häußge <osd@foosel.net>"
__license__ = 'GNU Affero General Public License http://www.gnu.org/licenses/agpl.html'
__copyright__ = "Copyright (C) 2016 The OctoPrint Project - Released under terms of the AGPLv3 License"
import unittest
import mock
from ddt import ddt, data, unpack
from octoprint.server.util.flask import ReverseProxiedEnvironment, OctoPrintFlaskRequest, OctoPrintFlaskResponse
standard_environ = {
"HTTP_HOST": "localhost:5000",
"SERVER_NAME": "localhost",
"SERVER_PORT": "5000",
"SCRIPT_NAME": "",
"PATH_INFO": "/",
"wsgi.url_scheme": "http"
}
@ddt
class ReverseProxiedEnvironmentTest(unittest.TestCase):
@data(
# defaults
({},
{}),
# prefix set, path info not prefixed
({
"HTTP_X_SCRIPT_NAME": "/octoprint",
"PATH_INFO": "/static/online.gif"
}, {
"SCRIPT_NAME": "/octoprint"
}),
# prefix set, path info prefixed
({
"HTTP_X_SCRIPT_NAME": "/octoprint",
"PATH_INFO": "/octoprint/static/online.gif",
}, {
"SCRIPT_NAME": "/octoprint",
"PATH_INFO": "/static/online.gif"
}),
# host set
({
"HTTP_X_FORWARDED_HOST": "example.com"
}, {
"HTTP_HOST": "example.com",
"SERVER_NAME": "example.com",
"SERVER_PORT": "80"
}),
# host set with port
({
"HTTP_X_FORWARDED_HOST": "example.com:1234"
}, {
"HTTP_HOST": "example.com:1234",
"SERVER_NAME": "example.com",
"SERVER_PORT": "1234"
}),
# host and scheme set
({
"HTTP_X_FORWARDED_HOST": "example.com",
"HTTP_X_FORWARDED_PROTO": "https"
}, {
"HTTP_HOST": "example.com",
"SERVER_NAME": "example.com",
"SERVER_PORT": "443",
"wsgi.url_scheme": "https"
}),
# host and scheme 2 set
({
"HTTP_X_FORWARDED_HOST": "example.com",
"HTTP_X_SCHEME": "https"
}, {
"HTTP_HOST": "example.com",
"SERVER_NAME": "example.com",
"SERVER_PORT": "443",
"wsgi.url_scheme": "https"
}),
# host, server and port headers set -> only host wins
({
"HTTP_X_FORWARDED_HOST": "example.com",
"HTTP_X_FORWARDED_SERVER": "example2.com",
"HTTP_X_FORWARDED_PORT": "444",
"HTTP_X_FORWARDED_PROTO": "https"
}, {
"HTTP_HOST": "example.com",
"SERVER_NAME": "example.com",
"SERVER_PORT": "443",
"wsgi.url_scheme": "https"
}),
# server and port headers set -> host derived with port
({
"HTTP_X_FORWARDED_SERVER": "example2.com",
"HTTP_X_FORWARDED_PORT": "444",
"HTTP_X_FORWARDED_PROTO": "https"
}, {
"HTTP_HOST": "example2.com:444",
"SERVER_NAME": "example2.com",
"SERVER_PORT": "444",
"wsgi.url_scheme": "https"
}),
# server and port headers set, standard port -> host derived, no port
({
"HTTP_X_FORWARDED_SERVER": "example.com",
"HTTP_X_FORWARDED_PORT": "80",
}, {
"HTTP_HOST": "example.com",
"SERVER_NAME": "example.com",
"SERVER_PORT": "80",
}),
# multiple scheme entries -> only use first one
({
"HTTP_X_FORWARDED_PROTO": "https,http",
}, {
"wsgi.url_scheme": "https"
}),
# host = none -> should never happen but you never know...
({
"HTTP_HOST": None,
"HTTP_X_FORWARDED_SERVER": "example.com",
"HTTP_X_FORWARDED_PORT": "80"
}, {
"HTTP_HOST": "example.com",
"SERVER_NAME": "example.com",
"SERVER_PORT": "80"
})
)
@unpack
def test_stock(self, environ, expected):
reverse_proxied = ReverseProxiedEnvironment()
merged_environ = dict(standard_environ)
merged_environ.update(environ)
actual = reverse_proxied(merged_environ)
merged_expected = dict(standard_environ)
merged_expected.update(environ)
merged_expected.update(expected)
self.assertDictEqual(merged_expected, actual)
@data(
# prefix overridden
({
"prefix": "fallback_prefix"
}, {
}, {
"SCRIPT_NAME": "fallback_prefix",
}),
# scheme overridden
({
"scheme": "https"
}, {
}, {
"wsgi.url_scheme": "https"
}),
# host overridden, default port
({
"host": "example.com"
}, {
}, {
"HTTP_HOST": "example.com",
"SERVER_NAME": "example.com",
"SERVER_PORT": "80"
}),
# host overridden, included port
({
"host": "example.com:81"
}, {
}, {
"HTTP_HOST": "example.com:81",
"SERVER_NAME": "example.com",
"SERVER_PORT": "81"
}),
# server overridden
({
"server": "example.com"
}, {
}, {
"HTTP_HOST": "example.com:5000",
"SERVER_NAME": "example.com",
"SERVER_PORT": "5000"
}),
# port overridden, standard port
({
"port": "80"
}, {
}, {
"HTTP_HOST": "localhost",
"SERVER_PORT": "80"
}),
# port overridden, non standard port
({
"port": "81"
}, {
}, {
"HTTP_HOST": "localhost:81",
"SERVER_PORT": "81"
}),
# server and port overridden, default port
({
"server": "example.com",
"port": "80"
}, {
}, {
"HTTP_HOST": "example.com",
"SERVER_NAME": "example.com",
"SERVER_PORT": "80"
}),
# server and port overridden, non default port
({
"server": "example.com",
"port": "81"
}, {
}, {
"HTTP_HOST": "example.com:81",
"SERVER_NAME": "example.com",
"SERVER_PORT": "81"
}),
# prefix not really overridden
({
"prefix": "/octoprint"
}, {
"HTTP_X_SCRIPT_NAME": ""
}, {
}),
# scheme not really overridden
({
"scheme": "https"
}, {
"HTTP_X_FORWARDED_PROTO": "http"
}, {
}),
# scheme 2 not really overridden
({
"scheme": "https"
}, {
"HTTP_X_SCHEME": "http"
}, {
}),
# host not really overridden
({
"host": "example.com:444"
}, {
"HTTP_X_FORWARDED_HOST": "localhost:5000"
}, {
}),
# server not really overridden
({
"server": "example.com"
}, {
"HTTP_X_FORWARDED_SERVER": "localhost"
}, {
}),
# port not really overridden
({
"port": "444"
}, {
"HTTP_X_FORWARDED_PORT": "5000"
}, {
})
)
@unpack
def test_fallbacks(self, fallbacks, environ, expected):
reverse_proxied = ReverseProxiedEnvironment(**fallbacks)
merged_environ = dict(standard_environ)
merged_environ.update(environ)
actual = reverse_proxied(merged_environ)
merged_expected = dict(standard_environ)
merged_expected.update(environ)
merged_expected.update(expected)
self.assertDictEqual(merged_expected, actual)
def test_header_config_ok(self):
result = ReverseProxiedEnvironment.to_header_candidates(["prefix-header1", "prefix-header2"])
self.assertEquals(result, ["HTTP_PREFIX_HEADER1", "HTTP_PREFIX_HEADER2"])
def test_header_config_string(self):
result = ReverseProxiedEnvironment.to_header_candidates("prefix-header")
self.assertEquals(result, ["HTTP_PREFIX_HEADER"])
def test_header_config_none(self):
result = ReverseProxiedEnvironment.to_header_candidates(None)
self.assertEquals(result, [])
##~~
class OctoPrintFlaskRequestTest(unittest.TestCase):
def setUp(self):
self.orig_environment_wrapper = OctoPrintFlaskRequest.environment_wrapper
def tearDown(self):
OctoPrintFlaskRequest.environment_wrapper = staticmethod(self.orig_environment_wrapper)
def test_environment_wrapper(self):
def environment_wrapper(environ):
environ.update({
"TEST": "yes"
})
return environ
OctoPrintFlaskRequest.environment_wrapper = staticmethod(environment_wrapper)
request = OctoPrintFlaskRequest(standard_environ)
self.assertTrue("TEST" in request.environ)
def test_server_name(self):
request = OctoPrintFlaskRequest(standard_environ)
self.assertEquals(request.server_name, "localhost")
def test_server_port(self):
request = OctoPrintFlaskRequest(standard_environ)
self.assertEquals(request.server_port, "5000")
def test_cookie_suffix(self):
request = OctoPrintFlaskRequest(standard_environ)
self.assertEquals(request.cookie_suffix, "_P5000")
def test_cookies(self):
environ = dict(standard_environ)
environ["HTTP_COOKIE"] = "postfixed_P5000=postfixed_value; " \
"postfixed_wrong_P5001=postfixed_wrong_value; " \
"unpostfixed=unpostfixed_value"
request = OctoPrintFlaskRequest(environ)
cookies = request.cookies
self.assertDictEqual(cookies, {"postfixed": "postfixed_value",
"postfixed_wrong_P5001": "postfixed_wrong_value",
"unpostfixed": "unpostfixed_value"})
##~~
@ddt
class OctoPrintFlaskResponseTest(unittest.TestCase):
@data([None, None],
["/subfolder/", None],
[None, "/some/other/script/root"],
["/subfolder/", "/some/other/script/root"])
@unpack
def test_cookie_set_and_delete(self, path, scriptroot):
environ = dict(standard_environ)
if scriptroot is not None:
environ.update(dict(SCRIPT_NAME=scriptroot))
request = OctoPrintFlaskRequest(environ)
if path:
expected_path = path
else:
expected_path = "/"
if scriptroot:
expected_path = scriptroot + expected_path
if path is not None:
kwargs = dict(path=path)
else:
kwargs = dict()
with mock.patch("flask.request", new=request):
response = OctoPrintFlaskResponse()
# test set_cookie
with mock.patch("flask.Response.set_cookie") as set_cookie_mock:
response.set_cookie("some_key", "some_value", **kwargs)
set_cookie_mock.assert_called_once_with(response, "some_key_P5000", "some_value", path=expected_path)
# test delete_cookie
with mock.patch("flask.Response.delete_cookie") as delete_cookie_mock:
response.delete_cookie("some_key", "some_value", **kwargs)
delete_cookie_mock.assert_called_once_with(response, "some_key_P5000", "some_value", path=expected_path)