diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 61c4bcd46..4346d64e6 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -111,14 +111,18 @@ TIMEOUT_DURATION = 3 if WEBSOCKET_MANAGER == "redis": log.debug("Using Redis to manage websockets.") - SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL) - USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL) - USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL) + sentinel_hosts=WEBSOCKET_SENTINEL_HOSTS.split(',') + sentinel_port=int(WEBSOCKET_SENTINEL_PORT) + sentinels=[(host, sentinel_port) for host in sentinel_hosts] + SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL, sentinels) + USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL, sentinels) + USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL, sentinels) clean_up_lock = RedisLock( redis_url=WEBSOCKET_REDIS_URL, lock_name="usage_cleanup_lock", timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT, + sentinels, ) aquire_func = clean_up_lock.aquire_lock renew_func = clean_up_lock.renew_lock diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py index f5628ee1e..284dd3290 100644 --- a/backend/open_webui/socket/utils.py +++ b/backend/open_webui/socket/utils.py @@ -16,13 +16,35 @@ def parse_redis_sentinel_url(redis_url): "db": int(parsed_url.path.lstrip("/") or 0), } +def get_redis_connection(redis_url, sentinels, decode_responses=True): + """ + Creates a Redis connection from either a standard Redis URL or uses special + parsing to setup a Sentinel connection, if given an array of host/port tuples. + """ + if sentinels: + redis_config = parse_redis_sentinel_url(redis_url) + sentinel = redis.sentinel.Sentinel( + self.sentinels, + port=redis_config['port'], + db=redis_config['db'], + username=redis_config['username'], + password=redis_config['password'], + decode_responses=decode_responses + } + + # Get a master connection from Sentinel + return sentinel.master_for(redis_config['service']) + else: + # Standard Redis connection + return redis.Redis.from_url(redis_url, decode_responses=decode_responses) + class RedisLock: - def __init__(self, redis_url, lock_name, timeout_secs): + def __init__(self, redis_url, lock_name, timeout_secs, sentinels=[]): self.lock_name = lock_name self.lock_id = str(uuid.uuid4()) self.timeout_secs = timeout_secs self.lock_obtained = False - self.redis = redis.Redis.from_url(redis_url, decode_responses=True) + self.redis = get_redis_connection(redis_url, sentinels, decode_responses=True) def aquire_lock(self): # nx=True will only set this key if it _hasn't_ already been set @@ -44,9 +66,9 @@ class RedisLock: class RedisDict: - def __init__(self, name, redis_url): + def __init__(self, name, redis_url, sentinels=[]): self.name = name - self.redis = redis.Redis.from_url(redis_url, decode_responses=True) + self.redis = get_redis_connection(redis_url, sentinels, decode_responses=True) def __setitem__(self, key, value): serialized_value = json.dumps(value)