"""
Shared infrastructure: connection pools and sessions.

This module provides centralized, reusable connections for:
- Redis (async connection pool, plus shared sync clients)
- aiohttp (shared ClientSession)
- SQLite (connection pool for frequently accessed databases)

Usage:
    from shared import (
        get_redis_async_client,
        get_redis_sync_client,
        get_aiohttp_session,
        get_sqlite_pool,
    )

Call ``startup()`` when the application starts and ``shutdown()`` when it
stops.
"""

import asyncio
import logging
from typing import Optional, Dict
from contextlib import asynccontextmanager

import aiohttp
import redis.asyncio as redis_async
import redis as redis_sync
import aiosqlite

from lyric_search.sources import private

logger = logging.getLogger(__name__)

# =============================================================================
# Redis Connection Pool
# =============================================================================

# Connection settings shared by the async pool and both sync clients so the
# three can never drift apart.
_REDIS_HOST = "127.0.0.1"
_REDIS_PORT = 6379

_redis_async_pool: Optional[redis_async.ConnectionPool] = None
_redis_async_client: Optional[redis_async.Redis] = None
_redis_sync_client: Optional[redis_sync.Redis] = None
_redis_sync_client_decoded: Optional[redis_sync.Redis] = None


def _create_redis_pool() -> redis_async.ConnectionPool:
    """Create a shared Redis connection pool."""
    return redis_async.ConnectionPool(
        host=_REDIS_HOST,
        port=_REDIS_PORT,
        password=private.REDIS_PW,
        max_connections=50,
        decode_responses=False,  # Default; callers can decode as needed
    )


def get_redis_async_pool() -> redis_async.ConnectionPool:
    """Get or create the shared async Redis connection pool."""
    global _redis_async_pool
    if _redis_async_pool is None:
        _redis_async_pool = _create_redis_pool()
    return _redis_async_pool


def get_redis_async_client() -> redis_async.Redis:
    """Get or create a shared async Redis client using the connection pool."""
    global _redis_async_client
    if _redis_async_client is None:
        _redis_async_client = redis_async.Redis(
            connection_pool=get_redis_async_pool()
        )
    return _redis_async_client


def _create_redis_sync_client(decode_responses: bool) -> redis_sync.Redis:
    """Create a sync Redis client using the shared connection settings."""
    return redis_sync.Redis(
        host=_REDIS_HOST,
        port=_REDIS_PORT,
        password=private.REDIS_PW,
        decode_responses=decode_responses,
    )


def get_redis_sync_client(decode_responses: bool = True) -> redis_sync.Redis:
    """
    Get or create a shared sync Redis client.

    We maintain two separate clients: one with decode_responses=True, one
    with decode_responses=False, since this setting affects all operations.
    """
    global _redis_sync_client, _redis_sync_client_decoded
    if decode_responses:
        if _redis_sync_client_decoded is None:
            _redis_sync_client_decoded = _create_redis_sync_client(True)
        return _redis_sync_client_decoded
    if _redis_sync_client is None:
        _redis_sync_client = _create_redis_sync_client(False)
    return _redis_sync_client


async def close_redis_pools() -> None:
    """Close Redis connections. Call on app shutdown."""
    global _redis_async_pool, _redis_async_client
    global _redis_sync_client, _redis_sync_client_decoded

    # Each global is detached before its close is awaited, so a concurrent
    # getter builds a fresh object instead of handing out one being closed.
    client, _redis_async_client = _redis_async_client, None
    if client is not None:
        # redis-py >= 5.0.1 deprecates the async Redis.close() in favour of
        # aclose(); fall back to close() on older versions.
        close = getattr(client, "aclose", None) or client.close
        await close()

    pool, _redis_async_pool = _redis_async_pool, None
    if pool is not None:
        await pool.disconnect()

    sync_clients = (_redis_sync_client, _redis_sync_client_decoded)
    _redis_sync_client = None
    _redis_sync_client_decoded = None
    for sync_client in sync_clients:
        if sync_client is not None:
            sync_client.close()

    logger.info("Redis connections closed")


# =============================================================================
# aiohttp Shared Session
# =============================================================================

_aiohttp_session: Optional[aiohttp.ClientSession] = None


async def get_aiohttp_session() -> aiohttp.ClientSession:
    """
    Get or create a shared aiohttp ClientSession.

    The session pools connections through its TCPConnector (100 total,
    30 per host). A new session is created if the previous one was closed.
    """
    global _aiohttp_session
    if _aiohttp_session is None or _aiohttp_session.closed:
        timeout = aiohttp.ClientTimeout(total=30, connect=10)
        connector = aiohttp.TCPConnector(
            limit=100,  # Total connection pool size
            limit_per_host=30,  # Max connections per host
            ttl_dns_cache=300,  # DNS cache TTL
            keepalive_timeout=60,
        )
        _aiohttp_session = aiohttp.ClientSession(
            timeout=timeout,
            connector=connector,
        )
        logger.info("Created shared aiohttp session")
    return _aiohttp_session


async def close_aiohttp_session() -> None:
    """Close the shared aiohttp session. Call on app shutdown."""
    global _aiohttp_session
    # Always drop the reference, even if the session was already closed.
    session, _aiohttp_session = _aiohttp_session, None
    if session is not None and not session.closed:
        await session.close()
        logger.info("aiohttp session closed")


# =============================================================================
# SQLite Connection Pool
# =============================================================================


class SQLitePool:
    """
    Simple SQLite connection pool for async access.

    Maintains a pool of connections per database file to avoid
    opening/closing connections on every request. At most
    ``max_connections`` connections are open per database file.
    """

    def __init__(self, max_connections: int = 5):
        # db_path -> queue of idle connections ready to be borrowed.
        self._pools: Dict[str, asyncio.Queue] = {}
        self._max_connections = max_connections
        # db_path -> lock guarding that database's connection count.
        self._locks: Dict[str, asyncio.Lock] = {}
        # db_path -> number of open connections (idle + checked out).
        self._connection_counts: Dict[str, int] = {}

    async def _get_pool(self, db_path: str) -> asyncio.Queue:
        """Get or create a connection pool for the given database."""
        if db_path not in self._pools:
            self._pools[db_path] = asyncio.Queue(maxsize=self._max_connections)
            self._locks[db_path] = asyncio.Lock()
            self._connection_counts[db_path] = 0
        return self._pools[db_path]

    @asynccontextmanager
    async def connection(self, db_path: str, timeout: float = 5.0):
        """
        Get a connection from the pool.

        Usage:
            async with sqlite_pool.connection("/path/to/db.db") as conn:
                async with conn.execute("SELECT ...") as cursor:
                    ...

        Args:
            db_path: Database file to connect to; each path has its own pool.
            timeout: Seconds to wait for a free connection when the pool is
                at its limit; also passed to ``aiosqlite.connect``.

        Raises:
            asyncio.TimeoutError: No connection became free within ``timeout``.

        If the body raises, any open transaction is rolled back before the
        connection is reused. This stops a later borrower's ``commit()``
        from persisting the failed caller's partial writes.
        """
        pool = await self._get_pool(db_path)
        lock = self._locks[db_path]

        conn = await self._checkout(pool, lock, db_path, timeout)
        try:
            conn = await self._revalidate(conn, db_path, timeout)
        except BaseException:
            # No usable connection could be produced; free its slot.
            await self._discard(pool, lock, db_path, conn)
            raise

        reusable = True
        try:
            yield conn
        except BaseException:
            # Mark unusable first so an interrupted rollback still discards it.
            reusable = False
            reusable = await self._rollback_quietly(conn)
            raise
        finally:
            await self._checkin(pool, lock, db_path, conn, reusable)

    async def _checkout(
        self,
        pool: asyncio.Queue,
        lock: asyncio.Lock,
        db_path: str,
        timeout: float,
    ) -> aiosqlite.Connection:
        """Take an idle connection, open a new one, or wait for one to free up."""
        try:
            return pool.get_nowait()
        except asyncio.QueueEmpty:
            pass

        # No idle connection: open a new one if still under the limit.
        async with lock:
            if self._connection_counts[db_path] < self._max_connections:
                conn = await aiosqlite.connect(db_path, timeout=timeout)
                self._connection_counts[db_path] += 1
                return conn

        # At the limit: wait for another borrower to check one back in.
        return await asyncio.wait_for(pool.get(), timeout=timeout)

    @staticmethod
    async def _revalidate(
        conn: aiosqlite.Connection, db_path: str, timeout: float
    ) -> aiosqlite.Connection:
        """Return ``conn`` if it still answers a query, else a fresh replacement."""
        try:
            # The context-manager form closes the probe cursor.
            async with conn.execute("SELECT 1"):
                pass
        except Exception:
            # Any failure (closed handle, I/O error, ...) means unusable.
            logger.debug(
                "Replacing broken SQLite connection to %s", db_path, exc_info=True
            )
        else:
            return conn

        try:
            await conn.close()
        except Exception:
            logger.debug(
                "Error closing broken SQLite connection to %s",
                db_path,
                exc_info=True,
            )
        return await aiosqlite.connect(db_path, timeout=timeout)

    @staticmethod
    async def _rollback_quietly(conn: aiosqlite.Connection) -> bool:
        """Roll back any open transaction; return False if ``conn`` must be dropped."""
        try:
            if conn.in_transaction:
                await conn.rollback()
        except Exception:
            logger.debug(
                "Rollback failed; discarding SQLite connection", exc_info=True
            )
            return False
        return True

    async def _checkin(
        self,
        pool: asyncio.Queue,
        lock: asyncio.Lock,
        db_path: str,
        conn: aiosqlite.Connection,
        reusable: bool,
    ) -> None:
        """Return ``conn`` to its pool, or close it if it cannot be reused."""
        # close_all() never drains a detached pool again, so parking a
        # connection in one would leak it; close it instead.
        if reusable and self._pools.get(db_path) is pool:
            try:
                pool.put_nowait(conn)
            except asyncio.QueueFull:
                pass
            else:
                return
        await self._discard(pool, lock, db_path, conn)

    async def _discard(
        self,
        pool: asyncio.Queue,
        lock: asyncio.Lock,
        db_path: str,
        conn: aiosqlite.Connection,
    ) -> None:
        """Close ``conn`` and release its slot in the per-database count."""
        try:
            await conn.close()
        except Exception:
            logger.debug(
                "Error closing SQLite connection to %s", db_path, exc_info=True
            )
        async with lock:
            # Only the live pool's counter tracks this connection. After
            # close_all(), the detached pool's counter no longer exists.
            if self._pools.get(db_path) is pool:
                self._connection_counts[db_path] -= 1

    async def close_all(self) -> None:
        """Close all idle connections in all pools and forget the pools.

        Connections checked out at this moment are closed by their
        borrowers when they are checked back in.
        """
        # Detach everything before awaiting any close. A concurrent
        # check-in then sees its pool is gone and closes its own connection.
        pools = self._pools
        self._pools = {}
        self._locks = {}
        self._connection_counts = {}

        for db_path, pool in pools.items():
            while True:
                try:
                    conn = pool.get_nowait()
                except asyncio.QueueEmpty:
                    break
                try:
                    await conn.close()
                except Exception:
                    # Keep going so one bad connection doesn't leak the rest.
                    logger.warning(
                        "Failed to close SQLite connection to %s",
                        db_path,
                        exc_info=True,
                    )
        logger.info("SQLite pools closed")


# Global SQLite pool instance
_sqlite_pool: Optional[SQLitePool] = None


def get_sqlite_pool() -> SQLitePool:
    """Get the shared SQLite connection pool."""
    global _sqlite_pool
    if _sqlite_pool is None:
        _sqlite_pool = SQLitePool(max_connections=5)
    return _sqlite_pool


async def close_sqlite_pools() -> None:
    """Close all SQLite pools. Call on app shutdown."""
    global _sqlite_pool
    # Detach first so callers during shutdown get a fresh pool, not a dying one.
    pool, _sqlite_pool = _sqlite_pool, None
    if pool is not None:
        await pool.close_all()


# =============================================================================
# Lifecycle Management
# =============================================================================


async def startup() -> None:
    """Initialize all shared resources. Call on app startup."""
    # Pre-warm Redis connection; a failing ping propagates to the caller.
    client = get_redis_async_client()
    await client.ping()
    logger.info("Shared infrastructure initialized")


async def shutdown() -> None:
    """Clean up all shared resources. Call on app shutdown.

    Each resource is closed independently, so one failing close does not
    leave the others open. Failures are logged, not raised.
    """
    for closer in (close_aiohttp_session, close_redis_pools, close_sqlite_pools):
        try:
            await closer()
        except Exception:
            logger.exception("Error during shutdown in %s", closer.__name__)
    logger.info("Shared infrastructure shutdown complete")