Add connection check to Redis methods and refactor index creation logic

This commit is contained in:
2025-09-27 09:22:56 -04:00
parent 061aed296f
commit 566f1f6692

View File

@@ -6,7 +6,7 @@ import sys
import regex
from regex import Pattern
import asyncio
from typing import Union, Optional
from typing import Union, Optional, cast, List
sys.path.insert(1, "..")
from lyric_search import notifier
@@ -14,7 +14,7 @@ from lyric_search.constructors import LyricsResult
import redis.asyncio as redis
from redis.commands.search.query import Query # type: ignore
from redis.commands.search.indexDefinition import IndexDefinition, IndexType # type: ignore
from redis.commands.search.field import TextField # type: ignore
from redis.commands.search.field import TextField, Field # type: ignore
from redis.commands.json.path import Path # type: ignore
from . import private
@@ -46,17 +46,27 @@ class RedisCache:
except Exception as e:
logging.debug("Failed to create redis create_index task: %s", str(e))
async def _ensure_connection(self) -> None:
"""Ensure Redis connection is active, reconnect if necessary."""
try:
await self.redis_client.ping()
except Exception:
logging.debug("Redis connection lost, attempting to reconnect.")
self.redis_client = redis.Redis(password=private.REDIS_PW)
await self.redis_client.ping() # Test the new connection
async def create_index(self) -> None:
"""Create Index"""
try:
schema = (
await self._ensure_connection()
schema = [
TextField("$.search_artist", as_name="artist"),
TextField("$.search_song", as_name="song"),
TextField("$.src", as_name="src"),
TextField("$.lyrics", as_name="lyrics"),
)
result = await self.redis_client.ft().create_index(
schema,
]
result = await self.redis_client.ft().create_index( # type: ignore
cast(List[Field], schema),
definition=IndexDefinition(
prefix=["lyrics:"], index_type=IndexType.JSON
),
@@ -98,6 +108,7 @@ class RedisCache:
None
"""
try:
await self._ensure_connection()
src = src.strip().lower()
await self.redis_client.incr(f"returned:{src}")
except Exception as e:
@@ -113,6 +124,7 @@ class RedisCache:
dict: In the form {'source': count, 'source2': count, ...}
"""
try:
await self._ensure_connection()
sources: list = ["cache", "lrclib", "genius", "failed"]
counts: dict[str, int] = {}
for src in sources:
@@ -142,6 +154,7 @@ class RedisCache:
"""
try:
await self._ensure_connection()
fuzzy_artist = None
fuzzy_song = None
is_random_search = artist == "!" and song == "!"
@@ -167,7 +180,7 @@ class RedisCache:
result["id"].split(":", maxsplit=1)[1],
dict(json.loads(result["json"])),
)
for result in search_res.docs
for result in search_res.docs # type: ignore
] # type: ignore
if not search_res_out:
logging.debug(
@@ -189,13 +202,13 @@ class RedisCache:
result["id"].split(":", maxsplit=1)[1],
dict(json.loads(result["json"])),
)
for result in search_res.docs
for result in search_res.docs # type: ignore
] # type: ignore
else:
random_redis_key: str = await self.redis_client.randomkey()
out_id: str = str(random_redis_key).split(":", maxsplit=1)[1][:-1]
search_res = await self.redis_client.json().get(random_redis_key)
search_res = await self.redis_client.json().get(random_redis_key) # type: ignore
search_res_out = [(out_id, search_res)]
if not search_res_out and self.notify_warnings:
@@ -218,6 +231,7 @@ class RedisCache:
None
"""
try:
await self._ensure_connection()
(search_artist, search_song) = self.sanitize_input(
lyr_result.artist, lyr_result.song
)
@@ -237,7 +251,7 @@ class RedisCache:
"liked": 0,
}
newkey: str = f"lyrics:000{sqlite_id}"
jsonset: bool = await self.redis_client.json().set(
jsonset: bool = await self.redis_client.json().set( # type: ignore
newkey, Path.root_path(), redis_mapping
)
if not jsonset: