diff --git a/src/octoprint/server/__init__.py b/src/octoprint/server/__init__.py index acbd0345..7c8577a5 100644 --- a/src/octoprint/server/__init__.py +++ b/src/octoprint/server/__init__.py @@ -26,6 +26,7 @@ import base64 SUCCESS = {} NO_CONTENT = ("", 204) +NOT_MODIFIED = ("Not Modified", 304) app = Flask("octoprint") assets = None @@ -43,6 +44,7 @@ loginManager = None pluginManager = None appSessionManager = None pluginLifecycleManager = None +preemptiveCache = None principals = Principal(app) admin_permission = Permission(RoleNeed("admin")) @@ -62,6 +64,7 @@ import octoprint.util import octoprint.filemanager.storage import octoprint.filemanager.analysis import octoprint.slicing +from octoprint.server.util.flask import PreemptiveCache from . import util @@ -145,6 +148,7 @@ class Server(): global pluginManager global appSessionManager global pluginLifecycleManager + global preemptiveCache global debug from tornado.ioloop import IOLoop @@ -191,6 +195,7 @@ class Server(): printer = Printer(fileManager, analysisQueue, printerProfileManager) appSessionManager = util.flask.AppSessionManager() pluginLifecycleManager = LifecycleManager(pluginManager) + preemptiveCache = PreemptiveCache(os.path.join(s.getBaseFolder("data"), "preemptive_cache_config.yaml")) def octoprint_plugin_inject_factory(name, implementation): if not isinstance(implementation, octoprint.plugin.OctoPrintPlugin): @@ -205,7 +210,8 @@ class Server(): printer=printer, app_session_manager=appSessionManager, plugin_lifecycle_manager=pluginLifecycleManager, - data_folder=os.path.join(settings().getBaseFolder("data"), name) + data_folder=os.path.join(settings().getBaseFolder("data"), name), + preemptive_cache=preemptiveCache ) def settings_plugin_inject_factory(name, implementation): @@ -478,6 +484,10 @@ class Server(): implementation.on_after_startup() pluginLifecycleManager.add_callback("enabled", call_on_after_startup) + # when we are through with that we also run our preemptive cache + if settings().getBoolean(["devel", "cache", "preemptive"]): + self._execute_preemptive_flask_caching(preemptiveCache) + import threading threading.Thread(target=work).start() ioloop.add_callback(on_after_startup) @@ -537,7 +547,7 @@ class Server(): if default_language is not None and not default_language == "_default" and default_language in LANGUAGES: return Locale.negotiate([default_language], LANGUAGES) - return request.accept_languages.best_match(LANGUAGES) + return Locale.parse(request.accept_languages.best_match(LANGUAGES)) def _setup_logging(self, debug, logConf=None): defaultConfig = { @@ -674,6 +684,40 @@ class Server(): self._register_template_plugins() + def _execute_preemptive_flask_caching(self, preemptive_cache): + from werkzeug.test import EnvironBuilder + import time + + # we clean up entries from our preemptive cache settings that haven't been + # accessed longer than server.preemptiveCache.until days + preemptive_cache_timeout = settings().getInt(["server", "preemptiveCache", "until"]) + cutoff_timestamp = time.time() + preemptive_cache_timeout * 24 * 60 * 60 + + cache_data = preemptive_cache.clean_all_data(lambda root, entries: filter(lambda entry: "_timestamp" in entry and entry["_timestamp"] <= cutoff_timestamp, entries)) + if not cache_data: + return + + def execute_caching(): + for route in sorted(cache_data.keys(), key=lambda x: (x.count("/"), x)): + entries = cache_data[route] + for kwargs in entries: + additional_request_data = kwargs.get("_additional_request_data", dict()) + kwargs = dict((k, v) for k, v in kwargs.items() if not k.startswith("_")) + kwargs.update(additional_request_data) + try: + + self._logger.info("Preemptively caching {} for {!r}".format(route, kwargs)) + builder = EnvironBuilder(**kwargs) + with preemptive_cache.disable_timestamp_update(): + app(builder.get_environ(), lambda *a, **kw: None) + except: + self._logger.exception("Error while trying to preemptively cache {} for {!r}".format(route, kwargs)) + + import threading + cache_thread = threading.Thread(target=execute_caching, name="Preemptive Cache Worker") + cache_thread.daemon = True + cache_thread.start() + def _register_template_plugins(self): template_plugins = pluginManager.get_implementations(octoprint.plugin.TemplatePlugin) for plugin in template_plugins: diff --git a/src/octoprint/server/util/__init__.py b/src/octoprint/server/util/__init__.py index d128e8ef..20640e64 100644 --- a/src/octoprint/server/util/__init__.py +++ b/src/octoprint/server/util/__init__.py @@ -143,6 +143,20 @@ def get_api_key(request): return None +def get_plugin_hash(): + from octoprint.plugin import plugin_manager + + plugin_signature = lambda impl: "{}:{}".format(impl._identifier, impl._plugin_version) + template_plugins = map(plugin_signature, plugin_manager().get_implementations(octoprint.plugin.TemplatePlugin)) + asset_plugins = map(plugin_signature, plugin_manager().get_implementations(octoprint.plugin.AssetPlugin)) + ui_plugins = sorted(set(template_plugins + asset_plugins)) + + import hashlib + plugin_hash = hashlib.sha1() + plugin_hash.update(",".join(ui_plugins)) + return plugin_hash.hexdigest() + + #~~ reverse proxy compatible WSGI middleware diff --git a/src/octoprint/server/util/flask.py b/src/octoprint/server/util/flask.py index ce59f6d6..d621671a 100644 --- a/src/octoprint/server/util/flask.py +++ b/src/octoprint/server/util/flask.py @@ -14,6 +14,7 @@ import flask.ext.assets import webassets.updater import webassets.utils import functools +import contextlib import time import uuid import threading @@ -330,13 +331,14 @@ def cached(timeout=5 * 60, key=lambda: "view:%s" % flask.request.path, unless=No return f(*args, **kwargs) cache_key = key() + rv = _cache.get(cache_key) # only take the value from the cache if we are not required to refresh it from the wrapped function - if not callable(refreshif) or not refreshif(): - rv = _cache.get(cache_key) - if rv is not None: - logger.debug("Serving entry for {path} from cache".format(path=flask.request.path)) - return rv + if rv is not None and (not callable(refreshif) or not refreshif(rv)): + logger.debug("Serving entry for {path} from cache".format(path=flask.request.path)) + if not "X-From-Cache" in rv.headers: + rv.headers["X-From-Cache"] = "true" + return rv # get value from wrapped function logger.debug("No cache entry or refreshing cache for {path} (key: {key}), calling wrapped function".format(path=flask.request.path, key=cache_key)) @@ -377,6 +379,220 @@ def cache_check_response_headers(response): return False +class PreemptiveCache(object): + + def __init__(self, cachefile): + self.cachefile = cachefile + + self._lock = threading.RLock() + self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__) + self._update_timestamp = True + + def record(self, data, unless=None): + if callable(unless) and unless(): + return + + entry_data = data + if callable(entry_data): + entry_data = entry_data() + + if entry_data is not None: + from flask import request + self.add_data(request.path, entry_data) + + @contextlib.contextmanager + def disable_timestamp_update(self): + with self._lock: + self._update_timestamp = False + yield + self._update_timestamp = True + + def clean_all_data(self, cleanup_function): + assert callable(cleanup_function) + + with self._lock: + all_data = self.get_all_data() + for root, entries in all_data.items(): + old_count = len(entries) + entries = cleanup_function(root, entries) + if not entries: + del all_data[root] + self._logger.debug("Removed root {} from preemptive cache".format(root)) + elif len(entries) < old_count: + all_data[root] = entries + self._logger.debug("Removed {} from preemptive cache for root {}".format(old_count - len(entries), root)) + self.set_all_data(all_data) + + return all_data + + def get_all_data(self): + import yaml + + cache_data = None + with self._lock: + try: + with open(self.cachefile, "r") as f: + cache_data = yaml.safe_load(f) + except IOError as e: + import errno + if e.errno != errno.ENOENT: + raise + except: + self._logger.exception("Error while reading {}".format(self.cachefile)) + + if cache_data is None: + cache_data = dict() + + return cache_data + + def get_data(self, root): + cache_data = self.get_all_data() + return cache_data.get(root, dict()) + + def set_all_data(self, data): + from octoprint.util import atomic_write + import yaml + + with self._lock: + try: + with atomic_write(self.cachefile, "wb") as handle: + yaml.safe_dump(data, handle,default_flow_style=False, indent=" ", allow_unicode=True) + except: + self._logger.exception("Error while writing {}".format(self.cachefile)) + + def set_data(self, root, data): + with self._lock: + all_data = self.get_all_data() + all_data[root] = data + self.set_all_data(all_data) + + def add_data(self, root, data): + from octoprint.util import dict_filter + + def strip_ignored(d): + return dict_filter(d, lambda k, v: not k.startswith("_")) + + def compare(a, b): + return set(strip_ignored(a).items()) == set(strip_ignored(b).items()) + + def split_matched_and_unmatched(entry, entries): + matched = [] + unmatched = [] + + for e in entries: + if compare(e, entry): + matched.append(e) + else: + unmatched.append(e) + + return matched, unmatched + + with self._lock: + cache_data = self.get_all_data() + + if not root in cache_data: + cache_data[root] = [] + + existing, other = split_matched_and_unmatched(data, cache_data[root]) + + def get_newest(entries): + result = None + for entry in entries: + if "_timestamp" in entry and (result is None or ("_timestamp" in entry and result["_timestamp"] < entry["_timestamp"])): + result = entry + return result + + to_persist = get_newest(existing) + if not to_persist: + import copy + to_persist = copy.deepcopy(data) + to_persist["_timestamp"] = time.time() + self._logger.info("Adding entry for {} and {!r}".format(root, to_persist)) + elif self._update_timestamp: + to_persist["_timestamp"] = time.time() + self._logger.debug("Updating timestamp for {} and {!r}".format(root, data)) + else: + self._logger.debug("Not updating timestamp for {} and {!r}, currently flagged as disabled".format(root, data)) + + self.set_data(root, [to_persist] + other) + + +def preemptively_cached(cache, data, unless=None): + def decorator(f): + @functools.wraps(f) + def decorated_function(*args, **kwargs): + cache.record(data, unless=unless) + return f(*args, **kwargs) + return decorated_function + return decorator + + +def etagged(etag): + def decorator(f): + @functools.wraps(f) + def decorated_function(*args, **kwargs): + rv = f(*args, **kwargs) + if isinstance(rv, flask.Response): + result = etag + if callable(result): + result = result(rv) + if result: + rv.set_etag(result) + return rv + return decorated_function + return decorator + + +def lastmodified(date): + def decorator(f): + @functools.wraps(f) + def decorated_function(*args, **kwargs): + rv = f(*args, **kwargs) + if not "Last-Modified" in rv.headers: + result = date + if callable(result): + result = result(rv) + + if not isinstance(result, basestring): + from werkzeug.http import http_date + result = http_date(result) + + if result: + rv.headers["Last-Modified"] = result + return rv + return decorated_function + return decorator + + +def conditional(condition, met): + def decorator(f): + @functools.wraps(f) + def decorated_function(*args, **kwargs): + if callable(condition) and condition(): + # condition has been met, return met-response + rv = met + if callable(met): + rv = met() + return rv + + # condition hasn't been met, call decorated function + return f(*args, **kwargs) + return decorated_function + return decorator + + +def check_etag(etag): + return flask.request.method in ("GET", "HEAD") and \ + flask.request.if_none_match and \ + etag in flask.request.if_none_match + + +def check_lastmodified(lastmodified): + return flask.request.method in ("GET", "HEAD") and \ + flask.request.if_modified_since and \ + lastmodified >= flask.request.if_modified_since + + def add_non_caching_response_headers(response): response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, post-check=0, pre-check=0, max-age=0" response.headers["Pragma"] = "no-cache" diff --git a/src/octoprint/server/views.py b/src/octoprint/server/views.py index 3399a740..2632323b 100644 --- a/src/octoprint/server/views.py +++ b/src/octoprint/server/views.py @@ -13,7 +13,8 @@ from flask import request, g, url_for, make_response, render_template, send_from import octoprint.plugin from octoprint.server import app, userManager, pluginManager, gettext, \ - debug, LOCALES, VERSION, DISPLAY_VERSION, UI_API_KEY, BRANCH + debug, LOCALES, VERSION, DISPLAY_VERSION, UI_API_KEY, BRANCH, preemptiveCache, \ + NOT_MODIFIED from octoprint.settings import settings import re @@ -27,12 +28,17 @@ _valid_id_re = re.compile("[a-z_]+") _valid_div_re = re.compile("[a-zA-Z_-]+") @app.route("/") +@util.flask.preemptively_cached(cache=preemptiveCache, + data=lambda: dict(path=request.path, base_url=request.url_root, query_string="l10n={}".format(g.locale.language)) if g.locale else "en", + unless=lambda: request.url_root in settings().get(["server", "preemptiveCache", "exceptions"])) +@util.flask.conditional(lambda: _check_etag_and_lastmodified_for_index(), NOT_MODIFIED) @util.flask.cached(timeout=-1, - refreshif=lambda: util.flask.cache_check_headers() or "_refresh" in request.values, - key=lambda: "view:{}:{}".format(request.base_url, g.locale), - unless_response=util.flask.cache_check_response_headers) + refreshif=lambda cached: _validate_cache_for_index(cached), + key=lambda: "view:{}:{}".format(request.base_url, g.locale.language if g.locale else "en"), + unless_response=lambda response: util.flask.cache_check_response_headers(response)) +@util.flask.etagged(lambda _: _compute_etag_for_index()) +@util.flask.lastmodified(lambda _: _compute_date_for_index()) def index(): - #~~ a bunch of settings enable_gcodeviewer = settings().getBoolean(["gcodeViewer", "enabled"]) @@ -297,13 +303,10 @@ def index(): #~~ render! - import datetime - response = make_response(render_template( "index.jinja2", **render_kwargs )) - response.headers["Last-Modified"] = datetime.datetime.now() if first_run: response = util.flask.add_non_caching_response_headers(response) @@ -354,6 +357,7 @@ def _process_template_configs(name, implementation, configs, rules): return includes + def _process_template_config(name, implementation, rule, config=None, counter=1): if "mandatory" in rule: for mandatory in rule["mandatory"]: @@ -396,78 +400,23 @@ def _process_template_config(name, implementation, rule, config=None, counter=1) return data + @app.route("/robots.txt") +@util.flask.cached(timeout=-1) def robotsTxt(): return send_from_directory(app.static_folder, "robots.txt") @app.route("/i18n//.js") -@util.flask.cached(timeout=-1, - refreshif=lambda: util.flask.cache_check_headers() or "_refresh" in request.values, - key=lambda: "view:{}:{}".format(request.base_url, g.locale)) +@util.flask.conditional(lambda: _check_etag_and_lastmodified_for_i18n(), NOT_MODIFIED) +@util.flask.etagged(lambda _: _compute_etag_for_i18n(request.view_args["locale"], request.view_args["domain"])) +@util.flask.lastmodified(lambda _: _compute_date_for_i18n(request.view_args["locale"], request.view_args["domain"])) def localeJs(locale, domain): messages = dict() plural_expr = None if locale != "en": - from flask import _request_ctx_stack - from babel.messages.pofile import read_po - - def messages_from_po(base_path, locale, domain): - path = os.path.join(base_path, locale) - if not os.path.isdir(path): - return None, None - - path = os.path.join(path, "LC_MESSAGES", "{domain}.po".format(**locals())) - if not os.path.isfile(path): - return None, None - - messages = dict() - with file(path) as f: - catalog = read_po(f, locale=locale, domain=domain) - - for message in catalog: - message_id = message.id - if isinstance(message_id, (list, tuple)): - message_id = message_id[0] - messages[message_id] = message.string - - return messages, catalog.plural_expr - - user_base_path = os.path.join(settings().getBaseFolder("translations")) - user_plugin_path = os.path.join(user_base_path, "_plugins") - - # plugin translations - plugins = octoprint.plugin.plugin_manager().enabled_plugins - for name, plugin in plugins.items(): - dirs = [os.path.join(user_plugin_path, name), os.path.join(plugin.location, 'translations')] - for dirname in dirs: - if not os.path.isdir(dirname): - continue - - plugin_messages, _ = messages_from_po(dirname, locale, domain) - - if plugin_messages is not None: - messages = octoprint.util.dict_merge(messages, plugin_messages) - _logger.debug("Using translation folder {dirname} for locale {locale} of plugin {name}".format(**locals())) - break - else: - _logger.debug("No translations for locale {locale} for plugin {name}".format(**locals())) - - # core translations - ctx = _request_ctx_stack.top - base_path = os.path.join(ctx.app.root_path, "translations") - - dirs = [user_base_path, base_path] - for dirname in dirs: - core_messages, plural_expr = messages_from_po(dirname, locale, domain) - - if core_messages is not None: - messages = octoprint.util.dict_merge(messages, core_messages) - _logger.debug("Using translation folder {dirname} for locale {locale} of core translations".format(**locals())) - break - else: - _logger.debug("No core translations for locale {locale}".format(**locals())) + messages, plural_expr = _get_translations(locale, domain) catalog = dict( messages=messages, @@ -485,3 +434,192 @@ def plugin_assets(name, filename): return redirect(url_for("plugin." + name + ".static", filename=filename)) +def _compute_etag_for_index(files=None, lastmodified=None): + if files is None: + files = _files_for_index() + if lastmodified is None: + lastmodified = _compute_date(files) + if lastmodified and not isinstance(lastmodified, basestring): + from werkzeug.http import http_date + lastmodified = http_date(lastmodified) + + from octoprint import __version__ + from octoprint.server import UI_API_KEY + + import hashlib + hash = hashlib.sha1() + hash.update(__version__) + hash.update(UI_API_KEY) + hash.update(",".join(sorted(files))) + if lastmodified: + hash.update(lastmodified) + return hash.hexdigest() + + +def _compute_etag_for_i18n(locale, domain, files=None, lastmodified=None): + if files is None: + files = _get_all_translationfiles(locale, domain) + if lastmodified is None: + lastmodified = _compute_date(files) + if lastmodified and not isinstance(lastmodified, basestring): + from werkzeug.http import http_date + lastmodified = http_date(lastmodified) + + import hashlib + hash = hashlib.sha1() + hash.update(",".join(sorted(files))) + if lastmodified: + hash.update(lastmodified) + return hash.hexdigest() + + +def _compute_date_for_i18n(locale, domain): + return _compute_date(_get_all_translationfiles(locale, domain)) + + +def _compute_date_for_index(): + return _compute_date(_files_for_index()) + + +def _validate_cache_for_index(cached): + no_cache_headers = util.flask.cache_check_headers() + refresh_flag = "_refresh" in request.values + etag_different = _compute_etag_for_index() != cached.get_etag()[0] + + return no_cache_headers or refresh_flag or etag_different + + +def _files_for_index(): + """ + Collects all paths of files that the index page depends on. + + The relevant files are: + + * all jinja2 templates: they might be used within the index page, so + any changes here change the rendering outcome + * all defined assets: if one of them changes, the webassets bundle will + be regenerated and hence the URL included in the cached page won't be + valid anymore + * all translation files used for our current locale: if any of those change + we also need to re-render + """ + + templates = _get_all_templates() + assets = _get_all_assets() + translations = _get_all_translationfiles(g.locale.language if g.locale else "en", "messages") + return sorted(set(templates + assets + translations)) + + +def _compute_date(files): + from datetime import datetime + timestamps = map(lambda path: os.stat(path).st_mtime, files) + max_timestamp = max(*timestamps) if timestamps else None + if max_timestamp: + # we set the micros to 0 since microseconds are not speced for HTTP + max_timestamp = datetime.fromtimestamp(max_timestamp).replace(microsecond=0) + return max_timestamp + + +def _check_etag_and_lastmodified_for_index(): + files = _files_for_index() + lastmodified = _compute_date(files) + lastmodified_ok = util.flask.check_lastmodified(lastmodified) + etag_ok = util.flask.check_etag(_compute_etag_for_index(files, lastmodified)) + return etag_ok and lastmodified_ok + + +def _check_etag_and_lastmodified_for_i18n(): + locale = request.view_args["locale"] + domain = request.view_args["domain"] + + etag_ok = util.flask.check_etag(_compute_etag_for_i18n(request.view_args["locale"], request.view_args["domain"])) + + lastmodified = _compute_date_for_i18n(locale, domain) + lastmodified_ok = lastmodified is None or util.flask.check_lastmodified(lastmodified) + + return etag_ok and lastmodified_ok + + +def _get_all_templates(): + from octoprint.util.jinja import get_all_template_paths + return get_all_template_paths(app.jinja_loader, lambda path: not octoprint.util.is_hidden_path(path)) + + +def _get_all_assets(): + from octoprint.util.jinja import get_all_asset_paths + return get_all_asset_paths(app.jinja_env.assets_environment) + + +def _get_all_translationfiles(locale, domain): + from flask import _request_ctx_stack + + def get_po_path(basedir, locale, domain): + path = os.path.join(basedir, locale) + if not os.path.isdir(path): + return None + + path = os.path.join(path, "LC_MESSAGES", "{domain}.po".format(**locals())) + if not os.path.isfile(path): + return None + + return path + + po_files = [] + + user_base_path = os.path.join(settings().getBaseFolder("translations")) + user_plugin_path = os.path.join(user_base_path, "_plugins") + + # plugin translations + plugins = octoprint.plugin.plugin_manager().enabled_plugins + for name, plugin in plugins.items(): + dirs = [os.path.join(user_plugin_path, name), os.path.join(plugin.location, 'translations')] + for dirname in dirs: + if not os.path.isdir(dirname): + continue + + po_file = get_po_path(dirname, locale, domain) + if po_file: + po_files.append(po_file) + break + + # core translations + ctx = _request_ctx_stack.top + base_path = os.path.join(ctx.app.root_path, "translations") + + dirs = [user_base_path, base_path] + for dirname in dirs: + po_file = get_po_path(dirname, locale, domain) + if po_file: + po_files.append(po_file) + break + + return po_files + + +def _get_translations(locale, domain): + from babel.messages.pofile import read_po + from octoprint.util import dict_merge + + messages = dict() + plural_expr = None + + def messages_from_po(path, locale, domain): + messages = dict() + with file(path) as f: + catalog = read_po(f, locale=locale, domain=domain) + + for message in catalog: + message_id = message.id + if isinstance(message_id, (list, tuple)): + message_id = message_id[0] + messages[message_id] = message.string + + return messages, catalog.plural_expr + + po_files = _get_all_translationfiles(locale, domain) + for po_file in po_files: + po_messages, plural_expr = messages_from_po(po_file, locale, domain) + if po_messages is not None: + messages = dict_merge(messages, po_messages) + + return messages, plural_expr diff --git a/src/octoprint/settings.py b/src/octoprint/settings.py index fe608ad7..25d1469e 100644 --- a/src/octoprint/settings.py +++ b/src/octoprint/settings.py @@ -114,6 +114,10 @@ default_settings = { "diskspace": { "warning": 500 * 1024 * 1024, # 500 MB "critical": 200 * 1024 * 1024, # 200 MB + }, + "preemptiveCache": { + "exceptions": [], + "until": 7 } }, "webcam": { @@ -263,7 +267,8 @@ default_settings = { "devel": { "stylesheet": "css", "cache": { - "enabled": True + "enabled": True, + "preemptive": True }, "webassets": { "minify": False, diff --git a/src/octoprint/static/js/app/dataupdater.js b/src/octoprint/static/js/app/dataupdater.js index 55eb281e..4caf44ac 100644 --- a/src/octoprint/static/js/app/dataupdater.js +++ b/src/octoprint/static/js/app/dataupdater.js @@ -12,7 +12,7 @@ function DataUpdater(allViewModels) { self._pluginHash = undefined; self.reloadOverlay = $("#reloadui_overlay"); - $("#reloadui_overlay_reload").click(function() { location.reload(true); }); + $("#reloadui_overlay_reload").click(function() { location.reload(); }); self.connect = function() { var options = {}; diff --git a/src/octoprint/util/__init__.py b/src/octoprint/util/__init__.py index 2174f4cf..3426ed67 100644 --- a/src/octoprint/util/__init__.py +++ b/src/octoprint/util/__init__.py @@ -542,6 +542,48 @@ def dict_contains_keys(keys, dictionary): return True + +def dict_filter(dictionary, filter_function): + """ + Filters a dictionary with the provided filter_function + + Example:: + + >>> data = dict(key1="value1", key2="value2", other_key="other_value", foo="bar", bar="foo") + >>> dict_filter(data, lambda k, v: k.startswith("key")) == dict(key1="value1", key2="value2") + True + >>> dict_filter(data, lambda k, v: v.startswith("value")) == dict(key1="value1", key2="value2") + True + >>> dict_filter(data, lambda k, v: k == "foo" or v == "foo") == dict(foo="bar", bar="foo") + True + >>> dict_filter(data, lambda k, v: False) == dict() + True + >>> dict_filter(data, lambda k, v: True) == data + True + >>> dict_filter(None, lambda k, v: True) + Traceback (most recent call last): + ... + AssertionError + >>> dict_filter(data, None) + Traceback (most recent call last): + ... + AssertionError + + Arguments: + dictionary (dict): The dictionary to filter + filter_function (callable): The filter function to apply, called with key and + value of an entry in the dictionary, must return ``True`` for values to + keep and ``False`` for values to strip + + Returns: + dict: A shallow copy of the provided dictionary, stripped of the key-value-pairs + for which the ``filter_function`` returned ``False`` + """ + assert isinstance(dictionary, dict) + assert callable(filter_function) + return dict((k, v) for k, v in dictionary.items() if filter_function(k, v)) + + class Object(object): pass diff --git a/src/octoprint/util/jinja.py b/src/octoprint/util/jinja.py index bf3f8b30..ac2d01ba 100644 --- a/src/octoprint/util/jinja.py +++ b/src/octoprint/util/jinja.py @@ -6,7 +6,8 @@ __copyright__ = "Copyright (C) 2015 The OctoPrint Project - Released under terms import os -from jinja2.loaders import FileSystemLoader, TemplateNotFound, split_template_path +from jinja2.loaders import FileSystemLoader, PrefixLoader, ChoiceLoader, \ + TemplateNotFound, split_template_path class FilteredFileSystemLoader(FileSystemLoader): """ @@ -48,3 +49,60 @@ class FilteredFileSystemLoader(FileSystemLoader): filter_results = map(lambda x: not os.path.exists(os.path.join(x, path)) or self.path_filter(os.path.join(x, path)), self.searchpath) return all(filter_results) + + +def get_all_template_paths(loader): + def walk_folder(folder): + files = [] + walk_dir = os.walk(folder, followlinks=True) + for dirpath, dirnames, filenames in walk_dir: + for filename in filenames: + path = os.path.join(dirpath, filename) + files.append(path) + return files + + def collect_templates_for_loader(loader): + if isinstance(loader, FilteredFileSystemLoader): + result = [] + for folder in loader.searchpath: + result += walk_folder(folder) + return filter(loader.path_filter, result) + + elif isinstance(loader, FileSystemLoader): + result = [] + for folder in loader.searchpath: + result += walk_folder(folder) + return result + + elif isinstance(loader, PrefixLoader): + result = [] + for subloader in loader.mapping.values(): + result += collect_templates_for_loader(subloader) + return result + + elif isinstance(loader, ChoiceLoader): + result = [] + for subloader in loader.loaders: + result += collect_templates_for_loader(subloader) + return result + + return [] + + return collect_templates_for_loader(loader) + + +def get_all_asset_paths(env): + result = [] + for bundle in env: + for content in bundle.resolve_contents(): + try: + if not content: + continue + path = content[1] + if not os.path.isfile(path): + continue + result.append(path) + except IndexError: + # intentionally ignored + pass + return result