diff --git a/src/octoprint/users.py b/src/octoprint/users.py index c5bb27d6..56ddf786 100644 --- a/src/octoprint/users.py +++ b/src/octoprint/users.py @@ -26,7 +26,7 @@ class UserManager(object): def __init__(self): self._logger = logging.getLogger(__name__) self._session_users_by_session = dict() - self._session_users_by_userid = dict() + self._sessionids_by_userid = dict() self._enabled = True @property @@ -61,9 +61,10 @@ class UserManager(object): self._session_users_by_session[user.get_session()] = user userid = user.get_id() - if not userid in self._session_users_by_userid: - self._session_users_by_userid[userid] = [] - self._session_users_by_userid[userid].append(user) + if not userid in self._sessionids_by_userid: + self._sessionids_by_userid[userid] = set() + + self._sessionids_by_userid[userid].add(user.get_session()) self._logger.debug("Logged in user: %r" % user) @@ -80,15 +81,16 @@ class UserManager(object): return userid = user.get_id() - if userid in self._session_users_by_userid: - users_by_userid = self._session_users_by_userid[userid] - for u in users_by_userid: - if u.get_session() == user.get_session(): - users_by_userid.remove(u) - break + sessionid = user.get_session() - if user.get_session() in self._session_users_by_session: - del self._session_users_by_session[user.get_session()] + if userid in self._sessionids_by_userid: + try: + self._sessionids_by_userid[userid].remove(sessionid) + except KeyError: + pass + + if sessionid in self._session_users_by_session: + del self._session_users_by_session[sessionid] self._logger.debug("Logged out user: %r" % user) @@ -165,13 +167,12 @@ class UserManager(object): pass def removeUser(self, username): - if username in self._session_users_by_userid: - users = self._session_users_by_userid[username] - sessions = [user.get_session() for user in users if isinstance(user, SessionUser)] + if username in self._sessionids_by_userid: + sessions = self._sessionids_by_userid[username] for session in sessions: if session in self._session_users_by_session: del self._session_users_by_session[session] - del self._session_users_by_userid[username] + del self._sessionids_by_userid[username] def findUser(self, userid=None, session=None): if session is not None and session in self._session_users_by_session: @@ -217,6 +218,9 @@ class FilebasedUserManager(UserManager): if "settings" in attributes: settings = attributes["settings"] self._users[name] = User(name, attributes["password"], attributes["active"], attributes["roles"], apikey=apikey, settings=settings) + for sessionid in self._sessionids_by_userid.get(name, set()): + if sessionid in self._session_users_by_session: + self._session_users_by_session[sessionid].update_user(self._users[name]) else: self._customized = False @@ -421,7 +425,7 @@ class User(UserMixin): self._roles = roles self._apikey = apikey - if not settings: + if settings is None: settings = dict() self._settings = settings @@ -505,7 +509,6 @@ class User(UserMixin): class SessionUser(User): def __init__(self, user): self._user = user - User.__init__(self, user._username, user._passwordHash, user._active, user._roles, user._apikey, user._settings) import string import random @@ -515,7 +518,7 @@ class SessionUser(User): self._created = time.time() def __getattribute__(self, item): - if item in ("get_session", "_user", "_session", "_created"): + if item in ("get_session", "update_user", "_user", "_session", "_created"): return object.__getattribute__(self, item) else: return getattr(object.__getattribute__(self, "_user"), item) @@ -529,6 +532,9 @@ class SessionUser(User): def get_session(self): return self._session + def update_user(self, user): + self._user = user + def __repr__(self): return "SessionUser(id=%s,name=%s,active=%r,user=%r,admin=%r,session=%s,created=%s)" % (self.get_id(), self.get_name(), self.is_active(), self.is_user(), self.is_admin(), self._session, self._created)