From 10ccf8c8ebe175f15155d7ef6a2fbc4b7c2b7bd6 Mon Sep 17 00:00:00 2001 From: codey Date: Thu, 18 Dec 2025 07:51:47 -0500 Subject: [PATCH] performance: db/aiohttp connection pooling --- base.py | 17 ++ endpoints/lighting.py | 10 +- endpoints/radio.py | 48 ++--- endpoints/rand_msg.py | 12 +- lyric_search/sources/redis_cache.py | 11 +- shared.py | 290 ++++++++++++++++++++++++++++ utils/sr_wrapper.py | 6 +- 7 files changed, 350 insertions(+), 44 deletions(-) create mode 100644 shared.py diff --git a/base.py b/base.py index 8b546bf..37e13ba 100644 --- a/base.py +++ b/base.py @@ -4,12 +4,23 @@ import sys sys.path.insert(0, ".") import logging import asyncio + +# Install uvloop for better async performance (2-4x speedup on I/O) +try: + import uvloop + + uvloop.install() + logging.info("uvloop installed successfully") +except ImportError: + logging.warning("uvloop not available, using default asyncio event loop") + from contextlib import asynccontextmanager from typing import Any from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from scalar_fastapi import get_scalar_api_reference from lyric_search.sources import redis_cache +import shared # Shared connection pools logging.basicConfig(level=logging.INFO) logging.getLogger("aiosqlite").setLevel(logging.WARNING) @@ -35,6 +46,9 @@ async def lifespan(app: FastAPI): uvicorn_access_logger = logging.getLogger("uvicorn.access") uvicorn_access_logger.disabled = True + # Initialize shared infrastructure (Redis pool, aiohttp session, SQLite pool) + await shared.startup() + # Start Radio playlists if "radio" in _routes and hasattr(_routes["radio"], "on_start"): await _routes["radio"].on_start() @@ -55,6 +69,9 @@ async def lifespan(app: FastAPI): if "trip" in _routes and hasattr(_routes["trip"], "shutdown"): await _routes["trip"].shutdown() + # Clean up shared infrastructure + await shared.shutdown() + logger.info("Application shutdown complete") diff --git a/endpoints/lighting.py b/endpoints/lighting.py index 853aadb..5d45fec 100644 --- a/endpoints/lighting.py +++ b/endpoints/lighting.py @@ -24,9 +24,7 @@ import aiohttp from fastapi import FastAPI, Depends, HTTPException, Request from fastapi_throttle import RateLimiter from fastapi.responses import JSONResponse -import redis -from lyric_search.sources import private from auth.deps import get_current_user from dotenv import load_dotenv @@ -72,10 +70,10 @@ class Lighting: self.util = util self.constants = constants - # Redis for state persistence - self.redis_client = redis.Redis( - password=private.REDIS_PW, decode_responses=True - ) + # Redis for state persistence - use shared sync client + import shared + + self.redis_client = shared.get_redis_sync_client(decode_responses=True) self.lighting_key = "lighting:state" # Cync configuration from environment diff --git a/endpoints/radio.py b/endpoints/radio.py index 773826c..8ab41fe 100644 --- a/endpoints/radio.py +++ b/endpoints/radio.py @@ -686,23 +686,23 @@ class Radio(FastAPI): async def _send_lrc_to_client(self, websocket: WebSocket, station: str, track_data: dict): """Send cached LRC data to a specific client asynchronously. Only sends if LRC exists in cache.""" - logging.info(f"[LRC Send] Checking cached LRC for station {station}") - logging.info(f"[LRC Send] Current track: {track_data.get('artist', 'Unknown')} - {track_data.get('song', 'Unknown')}") + logging.debug(f"[LRC Send] Checking cached LRC for station {station}") + logging.debug(f"[LRC Send] Current track: {track_data.get('artist', 'Unknown')} - {track_data.get('song', 'Unknown')}") try: # Only send if LRC is in cache cached_lrc = self.lrc_cache.get(station) - logging.info(f"[LRC Send] Cache status for station {station}: {'Found' if cached_lrc else 'Not found'}") + logging.debug(f"[LRC Send] Cache status for station {station}: {'Found' if cached_lrc else 'Not found'}") if cached_lrc: - logging.info("[LRC Send] Sending cached LRC to client") + logging.debug("[LRC Send] Sending cached LRC to client") lrc_data: dict = { "type": "lrc", "data": cached_lrc, "source": "Cache" } await websocket.send_text(json.dumps(lrc_data)) - logging.info("[LRC Send] Successfully sent cached LRC to client") + logging.debug("[LRC Send] Successfully sent cached LRC to client") else: - logging.info(f"[LRC Send] No cached LRC available for station {station}") + logging.debug(f"[LRC Send] No cached LRC available for station {station}") except Exception as e: logging.error(f"[LRC Send] Failed to send cached LRC to client: {e}") logging.error(f"[LRC Send] Error details: {traceback.format_exc()}") @@ -711,34 +711,34 @@ class Radio(FastAPI): """Send cached LRC data to a specific client asynchronously. Only sends if valid LRC exists in cache.""" try: track_info = f"{track_data.get('artist', 'Unknown')} - {track_data.get('song', 'Unknown')}" - logging.info(f"[LRC Send {id(websocket)}] Starting LRC send for {track_info}") - logging.info(f"[LRC Send {id(websocket)}] Cache keys before lock: {list(self.lrc_cache.keys())}") + logging.debug(f"[LRC Send {id(websocket)}] Starting LRC send for {track_info}") + logging.debug(f"[LRC Send {id(websocket)}] Cache keys before lock: {list(self.lrc_cache.keys())}") # Get cached LRC with lock to ensure consistency async with self.lrc_cache_locks[station]: - logging.info(f"[LRC Send {id(websocket)}] Got cache lock") + logging.debug(f"[LRC Send {id(websocket)}] Got cache lock") cached_lrc = self.lrc_cache.get(station) - logging.info(f"[LRC Send {id(websocket)}] Cache keys during lock: {list(self.lrc_cache.keys())}") - logging.info(f"[LRC Send {id(websocket)}] Cache entry length: {len(cached_lrc) if cached_lrc else 0}") + logging.debug(f"[LRC Send {id(websocket)}] Cache keys during lock: {list(self.lrc_cache.keys())}") + logging.debug(f"[LRC Send {id(websocket)}] Cache entry length: {len(cached_lrc) if cached_lrc else 0}") # Only send if we have actual lyrics if cached_lrc: - logging.info(f"[LRC Send {id(websocket)}] Preparing to send {len(cached_lrc)} bytes of LRC") + logging.debug(f"[LRC Send {id(websocket)}] Preparing to send {len(cached_lrc)} bytes of LRC") lrc_data: dict = { "type": "lrc", "data": cached_lrc, "source": "Cache" } await websocket.send_text(json.dumps(lrc_data)) - logging.info(f"[LRC Send {id(websocket)}] Successfully sent LRC") + logging.debug(f"[LRC Send {id(websocket)}] Successfully sent LRC") else: - logging.info(f"[LRC Send {id(websocket)}] No LRC in cache") + logging.debug(f"[LRC Send {id(websocket)}] No LRC in cache") # If we have no cache entry, let's check if a fetch is needed async with self.lrc_cache_locks[station]: - logging.info(f"[LRC Send {id(websocket)}] Checking if fetch needed") + logging.debug(f"[LRC Send {id(websocket)}] Checking if fetch needed") # Only attempt fetch if we're the first to notice missing lyrics if station not in self.lrc_cache: - logging.info(f"[LRC Send {id(websocket)}] Initiating LRC fetch") + logging.debug(f"[LRC Send {id(websocket)}] Initiating LRC fetch") lrc, source = await self._fetch_and_cache_lrc(station, track_data) if lrc: self.lrc_cache[station] = lrc @@ -748,7 +748,7 @@ class Radio(FastAPI): "source": source } await websocket.send_text(json.dumps(lrc_data)) - logging.info(f"[LRC Send {id(websocket)}] Sent newly fetched LRC") + logging.debug(f"[LRC Send {id(websocket)}] Sent newly fetched LRC") except Exception as e: logging.error(f"[LRC Send {id(websocket)}] Failed: {e}") logging.error(f"[LRC Send {id(websocket)}] Error details: {traceback.format_exc()}") @@ -761,25 +761,25 @@ class Radio(FastAPI): duration: Optional[int] = track_data.get("duration") if not (artist and title): - logging.info("[LRC] Missing artist or title, skipping fetch") + logging.debug("[LRC] Missing artist or title, skipping fetch") return None, "None" - logging.info(f"[LRC] Starting fetch for {station}: {artist} - {title}") + logging.debug(f"[LRC] Starting fetch for {station}: {artist} - {title}") # Try LRCLib first with timeout try: async with asyncio.timeout(10.0): # 10 second timeout - logging.info("[LRC] Trying LRCLib") + logging.debug("[LRC] Trying LRCLib") lrclib_result = await self.lrclib.search(artist, title, plain=False, raw=True) if lrclib_result and lrclib_result.lyrics and isinstance(lrclib_result.lyrics, str): - logging.info("[LRC] Found from LRCLib") + logging.debug("[LRC] Found from LRCLib") return lrclib_result.lyrics, "LRCLib" except asyncio.TimeoutError: logging.warning("[LRC] LRCLib fetch timed out") except Exception as e: logging.error(f"[LRC] LRCLib fetch error: {e}") - logging.info("[LRC] LRCLib fetch completed without results") + logging.debug("[LRC] LRCLib fetch completed without results") # Try SR as fallback with timeout try: @@ -788,14 +788,14 @@ class Radio(FastAPI): artist, title, duration=duration ) if lrc: - logging.info("[LRC] Found from SR") + logging.debug("[LRC] Found from SR") return lrc, "SR" except asyncio.TimeoutError: logging.warning("[LRC] SR fetch timed out") except Exception as e: logging.error(f"[LRC] SR fetch error: {e}") - logging.info("[LRC] No lyrics found from any source") + logging.debug("[LRC] No lyrics found from any source") return None, "None" except Exception as e: logging.error(f"[LRC] Error fetching lyrics: {e}") diff --git a/endpoints/rand_msg.py b/endpoints/rand_msg.py index b1d19c3..3b849db 100644 --- a/endpoints/rand_msg.py +++ b/endpoints/rand_msg.py @@ -1,11 +1,11 @@ import os import random from typing import LiteralString, Optional, Union -import aiosqlite as sqlite3 from fastapi import FastAPI, Depends from fastapi_throttle import RateLimiter from fastapi.responses import JSONResponse from .constructors import RandMsgRequest +import shared # Use shared SQLite pool class RandMsg(FastAPI): @@ -103,11 +103,11 @@ class RandMsg(FastAPI): } ) - async with sqlite3.connect(database=randmsg_db_path, timeout=1) as _db: - async with await _db.execute(db_query) as _cursor: - if not isinstance(_cursor, sqlite3.Cursor): - return JSONResponse(content={"err": True}) - result: Optional[sqlite3.Row] = await _cursor.fetchone() + # Use shared SQLite pool for connection reuse + sqlite_pool = shared.get_sqlite_pool() + async with sqlite_pool.connection(randmsg_db_path, timeout=1) as _db: + async with _db.execute(db_query) as _cursor: + result = await _cursor.fetchone() if not result: return JSONResponse(content={"err": True}) (result_id, result_msg) = result diff --git a/lyric_search/sources/redis_cache.py b/lyric_search/sources/redis_cache.py index 33812d6..4cb74e9 100644 --- a/lyric_search/sources/redis_cache.py +++ b/lyric_search/sources/redis_cache.py @@ -16,7 +16,7 @@ from redis.commands.search.query import Query # type: ignore from redis.commands.search.index_definition import IndexDefinition, IndexType # type: ignore from redis.commands.search.field import TextField, Field # type: ignore from redis.commands.json.path import Path # type: ignore -from . import private +import shared # Use shared Redis pool logger = logging.getLogger() log_level = logging.getLevelName(logger.level) @@ -34,7 +34,8 @@ class RedisCache: """ def __init__(self) -> None: - self.redis_client: redis.Redis = redis.Redis(password=private.REDIS_PW) + # Use shared Redis client from connection pool + self.redis_client: redis.Redis = shared.get_redis_async_client() self.notifier = notifier.DiscordNotifier() self.notify_warnings = False self.regexes: list[Pattern] = [ @@ -51,9 +52,9 @@ class RedisCache: try: await self.redis_client.ping() except Exception: - logging.debug("Redis connection lost, attempting to reconnect.") - self.redis_client = redis.Redis(password=private.REDIS_PW) - await self.redis_client.ping() # Test the new connection + logging.debug("Redis connection lost, refreshing client from pool.") + # Get fresh client from shared pool + self.redis_client = shared.get_redis_async_client() async def create_index(self) -> None: """Create Index""" diff --git a/shared.py b/shared.py new file mode 100644 index 0000000..62a5426 --- /dev/null +++ b/shared.py @@ -0,0 +1,290 @@ +""" +Shared infrastructure: connection pools and sessions. + +This module provides centralized, reusable connections for: +- Redis (async connection pool) +- aiohttp (shared ClientSession) +- SQLite (connection pool for frequently accessed databases) + +Usage: + from shared import get_redis_client, get_aiohttp_session, get_sqlite_pool +""" + +import asyncio +import logging +from typing import Optional, Dict +from contextlib import asynccontextmanager + +import aiohttp +import redis.asyncio as redis_async +import redis as redis_sync +import aiosqlite + +from lyric_search.sources import private + +logger = logging.getLogger(__name__) + +# ============================================================================= +# Redis Connection Pool +# ============================================================================= + +_redis_async_pool: Optional[redis_async.ConnectionPool] = None +_redis_async_client: Optional[redis_async.Redis] = None +_redis_sync_client: Optional[redis_sync.Redis] = None +_redis_sync_client_decoded: Optional[redis_sync.Redis] = None + + +def _create_redis_pool() -> redis_async.ConnectionPool: + """Create a shared Redis connection pool.""" + return redis_async.ConnectionPool( + host="127.0.0.1", + port=6379, + password=private.REDIS_PW, + max_connections=50, + decode_responses=False, # Default; callers can decode as needed + ) + + +def get_redis_async_pool() -> redis_async.ConnectionPool: + """Get or create the shared async Redis connection pool.""" + global _redis_async_pool + if _redis_async_pool is None: + _redis_async_pool = _create_redis_pool() + return _redis_async_pool + + +def get_redis_async_client() -> redis_async.Redis: + """Get or create a shared async Redis client using the connection pool.""" + global _redis_async_client + if _redis_async_client is None: + _redis_async_client = redis_async.Redis(connection_pool=get_redis_async_pool()) + return _redis_async_client + + +def get_redis_sync_client(decode_responses: bool = True) -> redis_sync.Redis: + """ + Get or create a shared sync Redis client. + + We maintain two separate clients: one with decode_responses=True, + one with decode_responses=False, since this setting affects all operations. + """ + global _redis_sync_client, _redis_sync_client_decoded + + if decode_responses: + if _redis_sync_client_decoded is None: + _redis_sync_client_decoded = redis_sync.Redis( + host="127.0.0.1", + port=6379, + password=private.REDIS_PW, + decode_responses=True, + ) + return _redis_sync_client_decoded + else: + if _redis_sync_client is None: + _redis_sync_client = redis_sync.Redis( + host="127.0.0.1", + port=6379, + password=private.REDIS_PW, + decode_responses=False, + ) + return _redis_sync_client + + +async def close_redis_pools() -> None: + """Close Redis connections. Call on app shutdown.""" + global _redis_async_pool, _redis_async_client, _redis_sync_client, _redis_sync_client_decoded + + if _redis_async_client: + await _redis_async_client.close() + _redis_async_client = None + + if _redis_async_pool: + await _redis_async_pool.disconnect() + _redis_async_pool = None + + if _redis_sync_client: + _redis_sync_client.close() + _redis_sync_client = None + + if _redis_sync_client_decoded: + _redis_sync_client_decoded.close() + _redis_sync_client_decoded = None + + logger.info("Redis connections closed") + + +# ============================================================================= +# aiohttp Shared Session +# ============================================================================= + +_aiohttp_session: Optional[aiohttp.ClientSession] = None + + +async def get_aiohttp_session() -> aiohttp.ClientSession: + """ + Get or create a shared aiohttp ClientSession. + + The session uses connection pooling internally (default: 100 connections). + """ + global _aiohttp_session + if _aiohttp_session is None or _aiohttp_session.closed: + timeout = aiohttp.ClientTimeout(total=30, connect=10) + connector = aiohttp.TCPConnector( + limit=100, # Total connection pool size + limit_per_host=30, # Max connections per host + ttl_dns_cache=300, # DNS cache TTL + keepalive_timeout=60, + ) + _aiohttp_session = aiohttp.ClientSession( + timeout=timeout, + connector=connector, + ) + logger.info("Created shared aiohttp session") + return _aiohttp_session + + +async def close_aiohttp_session() -> None: + """Close the shared aiohttp session. Call on app shutdown.""" + global _aiohttp_session + if _aiohttp_session and not _aiohttp_session.closed: + await _aiohttp_session.close() + _aiohttp_session = None + logger.info("aiohttp session closed") + + +# ============================================================================= +# SQLite Connection Pool +# ============================================================================= + + +class SQLitePool: + """ + Simple SQLite connection pool for async access. + + Maintains a pool of connections per database file to avoid + opening/closing connections on every request. + """ + + def __init__(self, max_connections: int = 5): + self._pools: Dict[str, asyncio.Queue] = {} + self._max_connections = max_connections + self._locks: Dict[str, asyncio.Lock] = {} + self._connection_counts: Dict[str, int] = {} + + async def _get_pool(self, db_path: str) -> asyncio.Queue: + """Get or create a connection pool for the given database.""" + if db_path not in self._pools: + self._pools[db_path] = asyncio.Queue(maxsize=self._max_connections) + self._locks[db_path] = asyncio.Lock() + self._connection_counts[db_path] = 0 + return self._pools[db_path] + + @asynccontextmanager + async def connection(self, db_path: str, timeout: float = 5.0): + """ + Get a connection from the pool. + + Usage: + async with sqlite_pool.connection("/path/to/db.db") as conn: + async with conn.execute("SELECT ...") as cursor: + ... + """ + pool = await self._get_pool(db_path) + lock = self._locks[db_path] + conn: Optional[aiosqlite.Connection] = None + + # Try to get an existing connection from the pool + try: + conn = pool.get_nowait() + except asyncio.QueueEmpty: + # No available connection, create one if under limit + async with lock: + if self._connection_counts[db_path] < self._max_connections: + conn = await aiosqlite.connect(db_path, timeout=timeout) + self._connection_counts[db_path] += 1 + + # If still no connection (at limit), wait for one + if conn is None: + conn = await asyncio.wait_for(pool.get(), timeout=timeout) + + try: + # Verify connection is still valid + if conn is not None: + try: + await conn.execute("SELECT 1") + except Exception: + # Connection is broken, create a new one + try: + await conn.close() + except Exception: + pass + conn = await aiosqlite.connect(db_path, timeout=timeout) + + yield conn + finally: + # Return connection to pool + if conn is not None: + try: + pool.put_nowait(conn) + except asyncio.QueueFull: + # Pool is full, close this connection + await conn.close() + async with lock: + self._connection_counts[db_path] -= 1 + + async def close_all(self) -> None: + """Close all connections in all pools.""" + for db_path, pool in self._pools.items(): + while not pool.empty(): + try: + conn = pool.get_nowait() + await conn.close() + except asyncio.QueueEmpty: + break + self._connection_counts[db_path] = 0 + + self._pools.clear() + self._locks.clear() + self._connection_counts.clear() + logger.info("SQLite pools closed") + + +# Global SQLite pool instance +_sqlite_pool: Optional[SQLitePool] = None + + +def get_sqlite_pool() -> SQLitePool: + """Get the shared SQLite connection pool.""" + global _sqlite_pool + if _sqlite_pool is None: + _sqlite_pool = SQLitePool(max_connections=5) + return _sqlite_pool + + +async def close_sqlite_pools() -> None: + """Close all SQLite pools. Call on app shutdown.""" + global _sqlite_pool + if _sqlite_pool: + await _sqlite_pool.close_all() + _sqlite_pool = None + + +# ============================================================================= +# Lifecycle Management +# ============================================================================= + + +async def startup() -> None: + """Initialize all shared resources. Call on app startup.""" + # Pre-warm Redis connection + client = get_redis_async_client() + await client.ping() + logger.info("Shared infrastructure initialized") + + +async def shutdown() -> None: + """Clean up all shared resources. Call on app shutdown.""" + await close_aiohttp_session() + await close_redis_pools() + await close_sqlite_pools() + logger.info("Shared infrastructure shutdown complete") diff --git a/utils/sr_wrapper.py b/utils/sr_wrapper.py index 0e50198..f27448f 100644 --- a/utils/sr_wrapper.py +++ b/utils/sr_wrapper.py @@ -1148,10 +1148,10 @@ class SRUtil: async def get_lrc_by_track_id(self, track_id: int) -> Optional[str]: """Get LRC lyrics by track ID.""" - logging.info(f"SR: Fetching metadata for track ID {track_id}") + logging.debug(f"SR: Fetching metadata for track ID {track_id}") metadata = await self.get_metadata_by_track_id(track_id) lrc = metadata.get("lyrics") if metadata else None - logging.info(f"SR: LRC {'found' if lrc else 'not found'}") + logging.debug(f"SR: LRC {'found' if lrc else 'not found'}") return lrc async def get_lrc_by_artist_song( @@ -1162,7 +1162,7 @@ class SRUtil: duration: Optional[int] = None, ) -> Optional[str]: """Get LRC lyrics by artist and song, optionally filtering by album and duration.""" - logging.info(f"SR: Searching tracks for {artist} - {song}") + logging.debug(f"SR: Searching tracks for {artist} - {song}") tracks = await self.get_tracks_by_artist_song(artist, song) logging.info(f"SR: Found {len(tracks) if tracks else 0} tracks") if not tracks: