diff --git a/src/octoprint/plugin/__init__.py b/src/octoprint/plugin/__init__.py index 8885ded5..fa24ccf0 100644 --- a/src/octoprint/plugin/__init__.py +++ b/src/octoprint/plugin/__init__.py @@ -44,7 +44,7 @@ def _validate_plugin(phase, plugin_info): setattr(plugin_info.instance, PluginInfo.attr_hooks, hooks) return True -def plugin_manager(init=False, plugin_folders=None, plugin_types=None, plugin_entry_points=None, plugin_disabled_list=None, +def plugin_manager(init=False, plugin_folders=None, plugin_bases=None, plugin_entry_points=None, plugin_disabled_list=None, plugin_blacklist=None, plugin_restart_needing_hooks=None, plugin_obsolete_hooks=None, plugin_validators=None): """ @@ -59,9 +59,8 @@ def plugin_manager(init=False, plugin_folders=None, plugin_types=None, plugin_en plugin_folders (list): A list of folders (as strings containing the absolute path to them) in which to look for potential plugin modules. If not provided this defaults to the configured ``plugins`` base folder and ``src/plugins`` within OctoPrint's code base. - plugin_types (list): A list of recognized plugin types for which to look for provided implementations. If not - provided this defaults to the plugin types found in :mod:`octoprint.plugin.types` without - :class:`~octoprint.plugin.OctoPrintPlugin`. + plugin_bases (list): A list of recognized plugin base classes for which to look for provided implementations. If not + provided this defaults to :class:`~octoprint.plugin.OctoPrintPlugin`. plugin_entry_points (list): A list of entry points pointing to modules which to load as plugins. If not provided this defaults to the entry point ``octoprint.plugin``. plugin_disabled_list (list): A list of plugin identifiers that are currently disabled. If not provided this @@ -90,21 +89,8 @@ def plugin_manager(init=False, plugin_folders=None, plugin_types=None, plugin_en else: if init: - if plugin_types is None: - plugin_types = [EnvironmentDetectionPlugin, - StartupPlugin, - ShutdownPlugin, - TemplatePlugin, - SettingsPlugin, - SimpleApiPlugin, - AssetPlugin, - BlueprintPlugin, - EventHandlerPlugin, - SlicerPlugin, - AppPlugin, - ProgressPlugin, - WizardPlugin, - UiPlugin] + if plugin_bases is None: + plugin_bases = [OctoPrintPlugin] if plugin_restart_needing_hooks is None: plugin_restart_needing_hooks = ["octoprint.server.http.*", @@ -119,7 +105,7 @@ def plugin_manager(init=False, plugin_folders=None, plugin_types=None, plugin_en plugin_validators.append(_validate_plugin) _instance = PluginManager(plugin_folders, - plugin_types, + plugin_bases, plugin_entry_points, logging_prefix="octoprint.plugins.", plugin_disabled_list=plugin_disabled_list, diff --git a/src/octoprint/plugin/core.py b/src/octoprint/plugin/core.py index 4231b56c..04b000c8 100644 --- a/src/octoprint/plugin/core.py +++ b/src/octoprint/plugin/core.py @@ -32,6 +32,7 @@ import imp from collections import defaultdict, namedtuple, OrderedDict import logging import fnmatch +import inspect import pkg_resources import pkginfo @@ -465,7 +466,7 @@ class PluginManager(object): It is able to discover plugins both through possible file system locations as well as customizable entry points. """ - def __init__(self, plugin_folders, plugin_types, plugin_entry_points, logging_prefix=None, + def __init__(self, plugin_folders, plugin_bases, plugin_entry_points, logging_prefix=None, plugin_disabled_list=None, plugin_blacklist=None, plugin_restart_needing_hooks=None, plugin_obsolete_hooks=None, plugin_validators=None): self.logger = logging.getLogger(__name__) @@ -474,8 +475,8 @@ class PluginManager(object): logging_prefix = "" if plugin_folders is None: plugin_folders = [] - if plugin_types is None: - plugin_types = [] + if plugin_bases is None: + plugin_bases = [] if plugin_entry_points is None: plugin_entry_points = [] if plugin_disabled_list is None: @@ -484,7 +485,7 @@ class PluginManager(object): plugin_blacklist = [] self.plugin_folders = plugin_folders - self.plugin_types = plugin_types + self.plugin_bases = plugin_bases self.plugin_entry_points = plugin_entry_points self.plugin_disabled_list = plugin_disabled_list self.plugin_blacklist = plugin_blacklist @@ -975,9 +976,9 @@ class PluginManager(object): # evaluate registered implementation if plugin.implementation: - for plugin_type in self.plugin_types: - if isinstance(plugin.implementation, plugin_type): - self.plugin_implementations_by_type[plugin_type].append((name, plugin.implementation)) + mixins = self.mixins_matching_bases(plugin.implementation.__class__, *self.plugin_bases) + for mixin in mixins: + self.plugin_implementations_by_type[mixin].append((name, plugin.implementation)) self.plugin_implementations[name] = plugin.implementation @@ -1000,9 +1001,10 @@ class PluginManager(object): if name in self.plugin_implementations: del self.plugin_implementations[name] - for plugin_type in self.plugin_types: + mixins = self.mixins_matching_bases(plugin.implementation.__class__, *self.plugin_bases) + for mixin in mixins: try: - self.plugin_implementations_by_type[plugin_type].remove((name, plugin.implementation)) + self.plugin_implementations_by_type[mixin].remove((name, plugin.implementation)) except ValueError: # that's ok, the plugin was just not registered for the type pass @@ -1087,6 +1089,17 @@ class PluginManager(object): return any(map(lambda h: fnmatch.fnmatch(hook, h), hooks)) + @staticmethod + def mixins_matching_bases(klass, *bases): + result = set() + for c in inspect.getmro(klass): + if c == klass or c in bases: + # ignore the exact class and our bases + continue + if issubclass(c, bases): + result.add(c) + return result + @staticmethod def has_any_of_mixins(plugin, *mixins): """ diff --git a/tests/plugin/test_core.py b/tests/plugin/test_core.py index cb25a4bd..5ac072b4 100644 --- a/tests/plugin/test_core.py +++ b/tests/plugin/test_core.py @@ -5,6 +5,48 @@ import ddt import octoprint.plugin import octoprint.plugin.core +##~~ Helpers for testing mixin type extraction + +class A(object): + pass + + +class A_1(A): + pass + + +class A_2(A): + pass + + +class A_3(A): + pass + + +class A1_1(A_1): + pass + + +class B(object): + pass + + +class B_1(B): + pass + + +class C(object): + pass + + +class C_1(C): + pass + + +class D(object): + pass + + @ddt.ddt class PluginTestCase(unittest.TestCase): @@ -18,12 +60,10 @@ class PluginTestCase(unittest.TestCase): self.plugin_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "_plugins") plugin_folders = [self.plugin_folder] - plugin_types = [octoprint.plugin.SettingsPlugin, - octoprint.plugin.StartupPlugin, - octoprint.plugin.AssetPlugin] + plugin_bases = [octoprint.plugin.OctoPrintPlugin] plugin_entry_points = None self.plugin_manager = octoprint.plugin.core.PluginManager(plugin_folders, - plugin_types, + plugin_bases, plugin_entry_points, plugin_disabled_list=[], logging_prefix="logging_prefix.") @@ -274,3 +314,20 @@ class PluginTestCase(unittest.TestCase): result = octoprint.plugin.core.PluginManager.has_any_of_mixins(plugin, octoprint.plugin.RestartNeedingPlugin) self.assertFalse(result) + + @ddt.data( + ((A1_1, A_2, B_1, C_1), (A, C), (A_1, A1_1, A_2, C_1)), + ((A1_1, A_2, B_1, C_1), (B,), (B_1,)), + + # not a subclass + ((A1_1, A_2, B_1, C_1), (D,), ()), + + # subclass only of base + ((A,), (A,), ()) + ) + @ddt.unpack + def test_mixins_matching_bases(self, bases_to_set, bases_to_check, expected): + Foo = type("Foo", bases_to_set, dict()) + actual = octoprint.plugin.core.PluginManager.mixins_matching_bases(Foo, *bases_to_check) + self.assertSetEqual(actual, set(expected)) +