Add connection check to Redis methods and refactor index creation logic

This commit is contained in:
2025-09-27 09:22:56 -04:00
parent 061aed296f
commit 566f1f6692

View File

@@ -6,7 +6,7 @@ import sys
import regex import regex
from regex import Pattern from regex import Pattern
import asyncio import asyncio
from typing import Union, Optional from typing import Union, Optional, cast, List
sys.path.insert(1, "..") sys.path.insert(1, "..")
from lyric_search import notifier from lyric_search import notifier
@@ -14,7 +14,7 @@ from lyric_search.constructors import LyricsResult
import redis.asyncio as redis import redis.asyncio as redis
from redis.commands.search.query import Query # type: ignore from redis.commands.search.query import Query # type: ignore
from redis.commands.search.indexDefinition import IndexDefinition, IndexType # type: ignore from redis.commands.search.indexDefinition import IndexDefinition, IndexType # type: ignore
from redis.commands.search.field import TextField # type: ignore from redis.commands.search.field import TextField, Field # type: ignore
from redis.commands.json.path import Path # type: ignore from redis.commands.json.path import Path # type: ignore
from . import private from . import private
@@ -46,17 +46,27 @@ class RedisCache:
except Exception as e: except Exception as e:
logging.debug("Failed to create redis create_index task: %s", str(e)) logging.debug("Failed to create redis create_index task: %s", str(e))
async def _ensure_connection(self) -> None:
"""Ensure Redis connection is active, reconnect if necessary."""
try:
await self.redis_client.ping()
except Exception:
logging.debug("Redis connection lost, attempting to reconnect.")
self.redis_client = redis.Redis(password=private.REDIS_PW)
await self.redis_client.ping() # Test the new connection
async def create_index(self) -> None: async def create_index(self) -> None:
"""Create Index""" """Create Index"""
try: try:
schema = ( await self._ensure_connection()
schema = [
TextField("$.search_artist", as_name="artist"), TextField("$.search_artist", as_name="artist"),
TextField("$.search_song", as_name="song"), TextField("$.search_song", as_name="song"),
TextField("$.src", as_name="src"), TextField("$.src", as_name="src"),
TextField("$.lyrics", as_name="lyrics"), TextField("$.lyrics", as_name="lyrics"),
) ]
result = await self.redis_client.ft().create_index( result = await self.redis_client.ft().create_index( # type: ignore
schema, cast(List[Field], schema),
definition=IndexDefinition( definition=IndexDefinition(
prefix=["lyrics:"], index_type=IndexType.JSON prefix=["lyrics:"], index_type=IndexType.JSON
), ),
@@ -98,6 +108,7 @@ class RedisCache:
None None
""" """
try: try:
await self._ensure_connection()
src = src.strip().lower() src = src.strip().lower()
await self.redis_client.incr(f"returned:{src}") await self.redis_client.incr(f"returned:{src}")
except Exception as e: except Exception as e:
@@ -113,6 +124,7 @@ class RedisCache:
dict: In the form {'source': count, 'source2': count, ...} dict: In the form {'source': count, 'source2': count, ...}
""" """
try: try:
await self._ensure_connection()
sources: list = ["cache", "lrclib", "genius", "failed"] sources: list = ["cache", "lrclib", "genius", "failed"]
counts: dict[str, int] = {} counts: dict[str, int] = {}
for src in sources: for src in sources:
@@ -142,6 +154,7 @@ class RedisCache:
""" """
try: try:
await self._ensure_connection()
fuzzy_artist = None fuzzy_artist = None
fuzzy_song = None fuzzy_song = None
is_random_search = artist == "!" and song == "!" is_random_search = artist == "!" and song == "!"
@@ -167,7 +180,7 @@ class RedisCache:
result["id"].split(":", maxsplit=1)[1], result["id"].split(":", maxsplit=1)[1],
dict(json.loads(result["json"])), dict(json.loads(result["json"])),
) )
for result in search_res.docs for result in search_res.docs # type: ignore
] # type: ignore ] # type: ignore
if not search_res_out: if not search_res_out:
logging.debug( logging.debug(
@@ -189,13 +202,13 @@ class RedisCache:
result["id"].split(":", maxsplit=1)[1], result["id"].split(":", maxsplit=1)[1],
dict(json.loads(result["json"])), dict(json.loads(result["json"])),
) )
for result in search_res.docs for result in search_res.docs # type: ignore
] # type: ignore ] # type: ignore
else: else:
random_redis_key: str = await self.redis_client.randomkey() random_redis_key: str = await self.redis_client.randomkey()
out_id: str = str(random_redis_key).split(":", maxsplit=1)[1][:-1] out_id: str = str(random_redis_key).split(":", maxsplit=1)[1][:-1]
search_res = await self.redis_client.json().get(random_redis_key) search_res = await self.redis_client.json().get(random_redis_key) # type: ignore
search_res_out = [(out_id, search_res)] search_res_out = [(out_id, search_res)]
if not search_res_out and self.notify_warnings: if not search_res_out and self.notify_warnings:
@@ -218,6 +231,7 @@ class RedisCache:
None None
""" """
try: try:
await self._ensure_connection()
(search_artist, search_song) = self.sanitize_input( (search_artist, search_song) = self.sanitize_input(
lyr_result.artist, lyr_result.song lyr_result.artist, lyr_result.song
) )
@@ -237,7 +251,7 @@ class RedisCache:
"liked": 0, "liked": 0,
} }
newkey: str = f"lyrics:000{sqlite_id}" newkey: str = f"lyrics:000{sqlite_id}"
jsonset: bool = await self.redis_client.json().set( jsonset: bool = await self.redis_client.json().set( # type: ignore
newkey, Path.root_path(), redis_mapping newkey, Path.root_path(), redis_mapping
) )
if not jsonset: if not jsonset: