diff --git a/lyric_search/sources/redis_cache.py b/lyric_search/sources/redis_cache.py index e632dfd..74cb200 100644 --- a/lyric_search/sources/redis_cache.py +++ b/lyric_search/sources/redis_cache.py @@ -6,7 +6,7 @@ import sys import regex from regex import Pattern import asyncio -from typing import Union, Optional +from typing import Union, Optional, cast, List sys.path.insert(1, "..") from lyric_search import notifier @@ -14,7 +14,7 @@ from lyric_search.constructors import LyricsResult import redis.asyncio as redis from redis.commands.search.query import Query # type: ignore from redis.commands.search.indexDefinition import IndexDefinition, IndexType # type: ignore -from redis.commands.search.field import TextField # type: ignore +from redis.commands.search.field import TextField, Field # type: ignore from redis.commands.json.path import Path # type: ignore from . import private @@ -46,17 +46,27 @@ class RedisCache: except Exception as e: logging.debug("Failed to create redis create_index task: %s", str(e)) + async def _ensure_connection(self) -> None: + """Ensure Redis connection is active, reconnect if necessary.""" + try: + await self.redis_client.ping() + except Exception: + logging.debug("Redis connection lost, attempting to reconnect.") + self.redis_client = redis.Redis(password=private.REDIS_PW) + await self.redis_client.ping() # Test the new connection + async def create_index(self) -> None: """Create Index""" try: - schema = ( + await self._ensure_connection() + schema = [ TextField("$.search_artist", as_name="artist"), TextField("$.search_song", as_name="song"), TextField("$.src", as_name="src"), TextField("$.lyrics", as_name="lyrics"), - ) - result = await self.redis_client.ft().create_index( - schema, + ] + result = await self.redis_client.ft().create_index( # type: ignore + cast(List[Field], schema), definition=IndexDefinition( prefix=["lyrics:"], index_type=IndexType.JSON ), @@ -98,6 +108,7 @@ class RedisCache: None """ try: + await self._ensure_connection() src = src.strip().lower() await self.redis_client.incr(f"returned:{src}") except Exception as e: @@ -113,6 +124,7 @@ class RedisCache: dict: In the form {'source': count, 'source2': count, ...} """ try: + await self._ensure_connection() sources: list = ["cache", "lrclib", "genius", "failed"] counts: dict[str, int] = {} for src in sources: @@ -142,6 +154,7 @@ class RedisCache: """ try: + await self._ensure_connection() fuzzy_artist = None fuzzy_song = None is_random_search = artist == "!" and song == "!" @@ -167,7 +180,7 @@ class RedisCache: result["id"].split(":", maxsplit=1)[1], dict(json.loads(result["json"])), ) - for result in search_res.docs + for result in search_res.docs # type: ignore ] # type: ignore if not search_res_out: logging.debug( @@ -189,13 +202,13 @@ class RedisCache: result["id"].split(":", maxsplit=1)[1], dict(json.loads(result["json"])), ) - for result in search_res.docs + for result in search_res.docs # type: ignore ] # type: ignore else: random_redis_key: str = await self.redis_client.randomkey() out_id: str = str(random_redis_key).split(":", maxsplit=1)[1][:-1] - search_res = await self.redis_client.json().get(random_redis_key) + search_res = await self.redis_client.json().get(random_redis_key) # type: ignore search_res_out = [(out_id, search_res)] if not search_res_out and self.notify_warnings: @@ -218,6 +231,7 @@ class RedisCache: None """ try: + await self._ensure_connection() (search_artist, search_song) = self.sanitize_input( lyr_result.artist, lyr_result.song ) @@ -237,7 +251,7 @@ class RedisCache: "liked": 0, } newkey: str = f"lyrics:000{sqlite_id}" - jsonset: bool = await self.redis_client.json().set( + jsonset: bool = await self.redis_client.json().set( # type: ignore newkey, Path.root_path(), redis_mapping ) if not jsonset: