From 7ab4d125481b4f9ead137999efaa508da2689699 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gina=20H=C3=A4u=C3=9Fge?= Date: Thu, 29 Jun 2017 15:07:17 +0200 Subject: [PATCH] Better error resilience against wrong user manager Also improved get_class by using importlib instead of complicated climbing through the whole module tree. --- src/octoprint/server/__init__.py | 2 +- src/octoprint/util/__init__.py | 18 +++++++++--------- tests/util/test_misc.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 10 deletions(-) create mode 100644 tests/util/test_misc.py diff --git a/src/octoprint/server/__init__.py b/src/octoprint/server/__init__.py index bf922010..0b0b9f5b 100644 --- a/src/octoprint/server/__init__.py +++ b/src/octoprint/server/__init__.py @@ -210,7 +210,7 @@ class Server(object): try: clazz = octoprint.util.get_class(userManagerName) userManager = clazz() - except AttributeError as e: + except: self._logger.exception("Could not instantiate user manager {}, falling back to FilebasedUserManager!".format(userManagerName)) userManager = octoprint.users.FilebasedUserManager() finally: diff --git a/src/octoprint/util/__init__.py b/src/octoprint/util/__init__.py index 84dd955f..d43733ce 100644 --- a/src/octoprint/util/__init__.py +++ b/src/octoprint/util/__init__.py @@ -174,8 +174,6 @@ def get_class(name): """ Retrieves the class object for a given fully qualified class name. - Taken from http://stackoverflow.com/a/452981/2028598. - Arguments: name (string): The fully qualified class name, including all modules separated by ``.`` @@ -183,15 +181,17 @@ def get_class(name): type: The class if it could be found. Raises: - AttributeError: The class could not be found. + ImportError """ - parts = name.split(".") - module = ".".join(parts[:-1]) - m = __import__(module) - for comp in parts[1:]: - m = getattr(m, comp) - return m + import importlib + + mod_name, cls_name = name.rsplit(".", 1) + m = importlib.import_module(mod_name) + try: + return getattr(m, cls_name) + except AttributeError: + raise ImportError("No module named " + name) def get_exception_string(): diff --git a/tests/util/test_misc.py b/tests/util/test_misc.py new file mode 100644 index 00000000..644654b7 --- /dev/null +++ b/tests/util/test_misc.py @@ -0,0 +1,31 @@ +# coding=utf-8 +from __future__ import absolute_import + +__license__ = 'GNU Affero General Public License http://www.gnu.org/licenses/agpl.html' +__copyright__ = "Copyright (C) 2017 The OctoPrint Project - Released under terms of the AGPLv3 License" + + +import unittest + +import octoprint.util + +class MiscTestCase(unittest.TestCase): + + def test_get_class(self): + octoprint.util.get_class("octoprint.users.FilebasedUserManager") + + def test_get_class_wrongmodule(self): + try: + octoprint.util.get_class("octoprint2.users.FilebasedUserManager") + self.fail("This should have thrown an ImportError") + except ImportError: + # success + pass + + def test_get_class_wrongclass(self): + try: + octoprint.util.get_class("octoprint.users.FilebasedUserManagerBzzztWrong") + self.fail("This should have thrown an ImportError") + except ImportError: + # success + pass