From ebfc6f2f26451f4d45b2437c76359e277e6296d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gina=20H=C3=A4u=C3=9Fge?= Date: Mon, 13 Nov 2017 15:36:46 +0100 Subject: [PATCH] More dynamic detection of plugin mixins Instead of hardcoding a list of supported mixins, we now instead define a base class (or a list thereof) from which mixins might inherit and use inspection to determine the set of mixins each plugin implementation implements. Not only will that reduce the risk of forgetting to add a new mixin to the whitelist, but it will also allow mixins defined by plugins. --- src/octoprint/plugin/__init__.py | 26 +++---------- src/octoprint/plugin/core.py | 31 ++++++++++----- tests/plugin/test_core.py | 65 ++++++++++++++++++++++++++++++-- 3 files changed, 89 insertions(+), 33 deletions(-) diff --git a/src/octoprint/plugin/__init__.py b/src/octoprint/plugin/__init__.py index 8885ded5..fa24ccf0 100644 --- a/src/octoprint/plugin/__init__.py +++ b/src/octoprint/plugin/__init__.py @@ -44,7 +44,7 @@ def _validate_plugin(phase, plugin_info): setattr(plugin_info.instance, PluginInfo.attr_hooks, hooks) return True -def plugin_manager(init=False, plugin_folders=None, plugin_types=None, plugin_entry_points=None, plugin_disabled_list=None, +def plugin_manager(init=False, plugin_folders=None, plugin_bases=None, plugin_entry_points=None, plugin_disabled_list=None, plugin_blacklist=None, plugin_restart_needing_hooks=None, plugin_obsolete_hooks=None, plugin_validators=None): """ @@ -59,9 +59,8 @@ def plugin_manager(init=False, plugin_folders=None, plugin_types=None, plugin_en plugin_folders (list): A list of folders (as strings containing the absolute path to them) in which to look for potential plugin modules. If not provided this defaults to the configured ``plugins`` base folder and ``src/plugins`` within OctoPrint's code base. - plugin_types (list): A list of recognized plugin types for which to look for provided implementations. If not - provided this defaults to the plugin types found in :mod:`octoprint.plugin.types` without - :class:`~octoprint.plugin.OctoPrintPlugin`. + plugin_bases (list): A list of recognized plugin base classes for which to look for provided implementations. If not + provided this defaults to :class:`~octoprint.plugin.OctoPrintPlugin`. plugin_entry_points (list): A list of entry points pointing to modules which to load as plugins. If not provided this defaults to the entry point ``octoprint.plugin``. plugin_disabled_list (list): A list of plugin identifiers that are currently disabled. If not provided this @@ -90,21 +89,8 @@ def plugin_manager(init=False, plugin_folders=None, plugin_types=None, plugin_en else: if init: - if plugin_types is None: - plugin_types = [EnvironmentDetectionPlugin, - StartupPlugin, - ShutdownPlugin, - TemplatePlugin, - SettingsPlugin, - SimpleApiPlugin, - AssetPlugin, - BlueprintPlugin, - EventHandlerPlugin, - SlicerPlugin, - AppPlugin, - ProgressPlugin, - WizardPlugin, - UiPlugin] + if plugin_bases is None: + plugin_bases = [OctoPrintPlugin] if plugin_restart_needing_hooks is None: plugin_restart_needing_hooks = ["octoprint.server.http.*", @@ -119,7 +105,7 @@ def plugin_manager(init=False, plugin_folders=None, plugin_types=None, plugin_en plugin_validators.append(_validate_plugin) _instance = PluginManager(plugin_folders, - plugin_types, + plugin_bases, plugin_entry_points, logging_prefix="octoprint.plugins.", plugin_disabled_list=plugin_disabled_list, diff --git a/src/octoprint/plugin/core.py b/src/octoprint/plugin/core.py index 4231b56c..04b000c8 100644 --- a/src/octoprint/plugin/core.py +++ b/src/octoprint/plugin/core.py @@ -32,6 +32,7 @@ import imp from collections import defaultdict, namedtuple, OrderedDict import logging import fnmatch +import inspect import pkg_resources import pkginfo @@ -465,7 +466,7 @@ class PluginManager(object): It is able to discover plugins both through possible file system locations as well as customizable entry points. """ - def __init__(self, plugin_folders, plugin_types, plugin_entry_points, logging_prefix=None, + def __init__(self, plugin_folders, plugin_bases, plugin_entry_points, logging_prefix=None, plugin_disabled_list=None, plugin_blacklist=None, plugin_restart_needing_hooks=None, plugin_obsolete_hooks=None, plugin_validators=None): self.logger = logging.getLogger(__name__) @@ -474,8 +475,8 @@ class PluginManager(object): logging_prefix = "" if plugin_folders is None: plugin_folders = [] - if plugin_types is None: - plugin_types = [] + if plugin_bases is None: + plugin_bases = [] if plugin_entry_points is None: plugin_entry_points = [] if plugin_disabled_list is None: @@ -484,7 +485,7 @@ class PluginManager(object): plugin_blacklist = [] self.plugin_folders = plugin_folders - self.plugin_types = plugin_types + self.plugin_bases = plugin_bases self.plugin_entry_points = plugin_entry_points self.plugin_disabled_list = plugin_disabled_list self.plugin_blacklist = plugin_blacklist @@ -975,9 +976,9 @@ class PluginManager(object): # evaluate registered implementation if plugin.implementation: - for plugin_type in self.plugin_types: - if isinstance(plugin.implementation, plugin_type): - self.plugin_implementations_by_type[plugin_type].append((name, plugin.implementation)) + mixins = self.mixins_matching_bases(plugin.implementation.__class__, *self.plugin_bases) + for mixin in mixins: + self.plugin_implementations_by_type[mixin].append((name, plugin.implementation)) self.plugin_implementations[name] = plugin.implementation @@ -1000,9 +1001,10 @@ class PluginManager(object): if name in self.plugin_implementations: del self.plugin_implementations[name] - for plugin_type in self.plugin_types: + mixins = self.mixins_matching_bases(plugin.implementation.__class__, *self.plugin_bases) + for mixin in mixins: try: - self.plugin_implementations_by_type[plugin_type].remove((name, plugin.implementation)) + self.plugin_implementations_by_type[mixin].remove((name, plugin.implementation)) except ValueError: # that's ok, the plugin was just not registered for the type pass @@ -1087,6 +1089,17 @@ class PluginManager(object): return any(map(lambda h: fnmatch.fnmatch(hook, h), hooks)) + @staticmethod + def mixins_matching_bases(klass, *bases): + result = set() + for c in inspect.getmro(klass): + if c == klass or c in bases: + # ignore the exact class and our bases + continue + if issubclass(c, bases): + result.add(c) + return result + @staticmethod def has_any_of_mixins(plugin, *mixins): """ diff --git a/tests/plugin/test_core.py b/tests/plugin/test_core.py index cb25a4bd..5ac072b4 100644 --- a/tests/plugin/test_core.py +++ b/tests/plugin/test_core.py @@ -5,6 +5,48 @@ import ddt import octoprint.plugin import octoprint.plugin.core +##~~ Helpers for testing mixin type extraction + +class A(object): + pass + + +class A_1(A): + pass + + +class A_2(A): + pass + + +class A_3(A): + pass + + +class A1_1(A_1): + pass + + +class B(object): + pass + + +class B_1(B): + pass + + +class C(object): + pass + + +class C_1(C): + pass + + +class D(object): + pass + + @ddt.ddt class PluginTestCase(unittest.TestCase): @@ -18,12 +60,10 @@ class PluginTestCase(unittest.TestCase): self.plugin_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "_plugins") plugin_folders = [self.plugin_folder] - plugin_types = [octoprint.plugin.SettingsPlugin, - octoprint.plugin.StartupPlugin, - octoprint.plugin.AssetPlugin] + plugin_bases = [octoprint.plugin.OctoPrintPlugin] plugin_entry_points = None self.plugin_manager = octoprint.plugin.core.PluginManager(plugin_folders, - plugin_types, + plugin_bases, plugin_entry_points, plugin_disabled_list=[], logging_prefix="logging_prefix.") @@ -274,3 +314,20 @@ class PluginTestCase(unittest.TestCase): result = octoprint.plugin.core.PluginManager.has_any_of_mixins(plugin, octoprint.plugin.RestartNeedingPlugin) self.assertFalse(result) + + @ddt.data( + ((A1_1, A_2, B_1, C_1), (A, C), (A_1, A1_1, A_2, C_1)), + ((A1_1, A_2, B_1, C_1), (B,), (B_1,)), + + # not a subclass + ((A1_1, A_2, B_1, C_1), (D,), ()), + + # subclass only of base + ((A,), (A,), ()) + ) + @ddt.unpack + def test_mixins_matching_bases(self, bases_to_set, bases_to_check, expected): + Foo = type("Foo", bases_to_set, dict()) + actual = octoprint.plugin.core.PluginManager.mixins_matching_bases(Foo, *bases_to_check) + self.assertSetEqual(actual, set(expected)) +