"""
|
|
Shared infrastructure: connection pools and sessions.
|
|
|
|
This module provides centralized, reusable connections for:
|
|
- Redis (async connection pool)
|
|
- aiohttp (shared ClientSession)
|
|
- SQLite (connection pool for frequently accessed databases)
|
|
|
|
Usage:
|
|
from shared import get_redis_client, get_aiohttp_session, get_sqlite_pool
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Optional, Dict
|
|
from contextlib import asynccontextmanager
|
|
|
|
import aiohttp
|
|
import redis.asyncio as redis_async
|
|
import redis as redis_sync
|
|
import aiosqlite
|
|
|
|
from lyric_search.sources import private
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# =============================================================================
|
|
# Redis Connection Pool
|
|
# =============================================================================
|
|
|
|
# Lazily-created singletons: populated by the getter functions below and
# reset to None by close_redis_pools().
_redis_async_pool: Optional[redis_async.ConnectionPool] = None  # shared async connection pool
_redis_async_client: Optional[redis_async.Redis] = None  # async client backed by the pool
_redis_sync_client: Optional[redis_sync.Redis] = None  # sync client, raw bytes (decode_responses=False)
_redis_sync_client_decoded: Optional[redis_sync.Redis] = None  # sync client, decode_responses=True
|
|
|
|
|
|
def _create_redis_pool() -> redis_async.ConnectionPool:
    """Build the shared async Redis connection pool.

    The pool targets the local Redis instance and is capped at 50
    connections. Responses are left as raw bytes (decode_responses=False);
    callers decode as needed.
    """
    pool_config = {
        "host": "127.0.0.1",
        "port": 6379,
        "password": private.REDIS_PW,
        "max_connections": 50,
        "decode_responses": False,  # default; callers can decode as needed
    }
    return redis_async.ConnectionPool(**pool_config)
|
|
|
|
|
|
def get_redis_async_pool() -> redis_async.ConnectionPool:
    """Return the process-wide async Redis connection pool, creating it lazily."""
    global _redis_async_pool
    if _redis_async_pool is not None:
        return _redis_async_pool
    _redis_async_pool = _create_redis_pool()
    return _redis_async_pool
|
|
|
|
|
|
def get_redis_async_client() -> redis_async.Redis:
    """Return the shared async Redis client, creating it lazily.

    The client is backed by the shared connection pool so every caller
    reuses the same set of sockets.
    """
    global _redis_async_client
    if _redis_async_client is None:
        shared_pool = get_redis_async_pool()
        _redis_async_client = redis_async.Redis(connection_pool=shared_pool)
    return _redis_async_client
|
|
|
|
|
|
def get_redis_sync_client(decode_responses: bool = True) -> redis_sync.Redis:
    """
    Get or create a shared sync Redis client.

    We maintain two separate clients: one with decode_responses=True,
    one with decode_responses=False, since this setting affects all
    operations on a client.

    Args:
        decode_responses: If True (default), return the client that decodes
            replies to str; otherwise return the raw-bytes client.

    Returns:
        The cached sync Redis client for the requested decoding mode.
    """
    global _redis_sync_client, _redis_sync_client_decoded

    def _new_client(decode: bool) -> redis_sync.Redis:
        # Single construction point so both cached clients always share
        # the same host/port/password configuration.
        return redis_sync.Redis(
            host="127.0.0.1",
            port=6379,
            password=private.REDIS_PW,
            decode_responses=decode,
        )

    if decode_responses:
        if _redis_sync_client_decoded is None:
            _redis_sync_client_decoded = _new_client(True)
        return _redis_sync_client_decoded

    if _redis_sync_client is None:
        _redis_sync_client = _new_client(False)
    return _redis_sync_client
|
|
|
|
|
|
async def close_redis_pools() -> None:
    """Close Redis connections. Call on app shutdown.

    Closes the async client first, then disconnects its pool, then closes
    both sync clients. Each module-level handle is reset to None so the
    getters can lazily recreate connections if used again.
    """
    global _redis_async_pool, _redis_async_client, _redis_sync_client, _redis_sync_client_decoded

    if _redis_async_client is not None:
        await _redis_async_client.close()
        _redis_async_client = None

    if _redis_async_pool is not None:
        await _redis_async_pool.disconnect()
        _redis_async_pool = None

    if _redis_sync_client is not None:
        _redis_sync_client.close()
        _redis_sync_client = None

    if _redis_sync_client_decoded is not None:
        _redis_sync_client_decoded.close()
        _redis_sync_client_decoded = None

    logger.info("Redis connections closed")
|
|
|
|
|
|
# =============================================================================
|
|
# aiohttp Shared Session
|
|
# =============================================================================
|
|
|
|
# Lazily-created shared HTTP session; managed by get_aiohttp_session()
# and close_aiohttp_session().
_aiohttp_session: Optional[aiohttp.ClientSession] = None
|
|
|
|
|
|
async def get_aiohttp_session() -> aiohttp.ClientSession:
    """
    Get or create a shared aiohttp ClientSession.

    A new session is created on first use, or if the previous one was
    closed. The session pools connections internally via its TCPConnector.
    """
    global _aiohttp_session

    needs_new = _aiohttp_session is None or _aiohttp_session.closed
    if needs_new:
        _aiohttp_session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=30, connect=10),
            connector=aiohttp.TCPConnector(
                limit=100,          # total connection pool size
                limit_per_host=30,  # max connections per host
                ttl_dns_cache=300,  # DNS cache TTL (seconds)
                keepalive_timeout=60,
            ),
        )
        logger.info("Created shared aiohttp session")
    return _aiohttp_session
|
|
|
|
|
|
async def close_aiohttp_session() -> None:
    """Close the shared aiohttp session. Call on app shutdown.

    Safe to call multiple times: if the session is missing or already
    closed, there is nothing to close, but any stale reference is still
    dropped so get_aiohttp_session() starts fresh next time.
    """
    global _aiohttp_session
    if _aiohttp_session is None:
        return

    if not _aiohttp_session.closed:
        await _aiohttp_session.close()
        logger.info("aiohttp session closed")
    # Clear the reference even if the session was already closed; the
    # original left a dead handle bound here in that case.
    _aiohttp_session = None
|
|
|
|
|
|
# =============================================================================
|
|
# SQLite Connection Pool
|
|
# =============================================================================
|
|
|
|
|
|
class SQLitePool:
    """
    Simple SQLite connection pool for async access.

    Maintains a pool of connections per database file to avoid
    opening/closing connections on every request.
    """

    def __init__(self, max_connections: int = 5):
        # One queue of idle connections per database path.
        self._pools: Dict[str, asyncio.Queue] = {}
        self._max_connections = max_connections
        # Per-database lock guarding the connection counter.
        self._locks: Dict[str, asyncio.Lock] = {}
        # Number of live connections (idle + checked out) per database.
        self._connection_counts: Dict[str, int] = {}

    async def _get_pool(self, db_path: str) -> asyncio.Queue:
        """Get or create a connection pool for the given database."""
        if db_path not in self._pools:
            self._pools[db_path] = asyncio.Queue(maxsize=self._max_connections)
            self._locks[db_path] = asyncio.Lock()
            self._connection_counts[db_path] = 0
        return self._pools[db_path]

    @asynccontextmanager
    async def connection(self, db_path: str, timeout: float = 5.0):
        """
        Get a connection from the pool.

        Args:
            db_path: Path to the SQLite database file.
            timeout: Seconds to wait for a free connection when the pool is
                at its limit; also passed to aiosqlite.connect as the SQLite
                busy timeout.

        Raises:
            asyncio.TimeoutError: If no connection becomes available in time.

        Usage:
            async with sqlite_pool.connection("/path/to/db.db") as conn:
                async with conn.execute("SELECT ...") as cursor:
                    ...
        """
        pool = await self._get_pool(db_path)
        lock = self._locks[db_path]
        conn: Optional[aiosqlite.Connection] = None

        # Try to get an existing connection from the pool.
        try:
            conn = pool.get_nowait()
        except asyncio.QueueEmpty:
            # No idle connection; create one if we are under the limit.
            async with lock:
                if self._connection_counts[db_path] < self._max_connections:
                    conn = await aiosqlite.connect(db_path, timeout=timeout)
                    self._connection_counts[db_path] += 1

        # Still no connection (at the limit) -> wait for one to be returned.
        if conn is None:
            conn = await asyncio.wait_for(pool.get(), timeout=timeout)

        try:
            # Verify the connection is still valid before handing it out.
            try:
                await conn.execute("SELECT 1")
            except Exception:
                # Connection is broken: close it and replace it.
                try:
                    await conn.close()
                except Exception:
                    pass
                # BUGFIX: null the reference first so the finally clause can
                # never return the dead connection to the pool if the
                # replacement connect() below fails.
                conn = None
                try:
                    conn = await aiosqlite.connect(db_path, timeout=timeout)
                except Exception:
                    # The tracked connection is gone and no replacement was
                    # made; release its slot before propagating the error.
                    async with lock:
                        self._connection_counts[db_path] -= 1
                    raise

            yield conn
        finally:
            # Return the connection to the pool for reuse.
            if conn is not None:
                try:
                    pool.put_nowait(conn)
                except asyncio.QueueFull:
                    # Pool is full; close this surplus connection.
                    await conn.close()
                    async with lock:
                        self._connection_counts[db_path] -= 1

    async def close_all(self) -> None:
        """Close all connections in all pools."""
        for db_path, pool in self._pools.items():
            while not pool.empty():
                try:
                    conn = pool.get_nowait()
                    await conn.close()
                except asyncio.QueueEmpty:
                    break
            self._connection_counts[db_path] = 0

        self._pools.clear()
        self._locks.clear()
        self._connection_counts.clear()
        logger.info("SQLite pools closed")
|
|
|
|
|
|
# Global SQLite pool instance, created lazily by get_sqlite_pool()
# and torn down by close_sqlite_pools().
_sqlite_pool: Optional[SQLitePool] = None
|
|
|
|
|
|
def get_sqlite_pool() -> SQLitePool:
    """Return the shared SQLite connection pool, creating it on first use."""
    global _sqlite_pool
    if _sqlite_pool is not None:
        return _sqlite_pool
    _sqlite_pool = SQLitePool(max_connections=5)
    return _sqlite_pool
|
|
|
|
|
|
async def close_sqlite_pools() -> None:
    """Close all SQLite pools. Call on app shutdown."""
    global _sqlite_pool
    if _sqlite_pool is None:
        return
    await _sqlite_pool.close_all()
    _sqlite_pool = None
|
|
|
|
|
|
# =============================================================================
|
|
# Lifecycle Management
|
|
# =============================================================================
|
|
|
|
|
|
async def startup() -> None:
    """Initialize all shared resources. Call on app startup.

    Pre-warms the async Redis connection with a PING so the first real
    request does not pay the connection-setup cost.
    """
    redis_client = get_redis_async_client()
    await redis_client.ping()
    logger.info("Shared infrastructure initialized")
|
|
|
|
|
|
async def shutdown() -> None:
    """Clean up all shared resources. Call on app shutdown."""
    # Close in the same order as before: HTTP session, Redis, SQLite pools.
    for close_resource in (
        close_aiohttp_session,
        close_redis_pools,
        close_sqlite_pools,
    ):
        await close_resource()
    logger.info("Shared infrastructure shutdown complete")
|