diff --git a/base.py b/base.py index 14e7dbd..5f15fb8 100644 --- a/base.py +++ b/base.py @@ -4,6 +4,7 @@ import sys sys.path.insert(0, ".") import logging import asyncio +from contextlib import asynccontextmanager from typing import Any from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware @@ -13,20 +14,61 @@ from lyric_search.sources import redis_cache logging.basicConfig(level=logging.INFO) logging.getLogger("aiosqlite").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING) +logging.getLogger("python_multipart.multipart").setLevel(logging.WARNING) +logging.getLogger("streamrip").setLevel(logging.WARNING) +logging.getLogger("utils.sr_wrapper").setLevel(logging.WARNING) logger = logging.getLogger() loop = asyncio.get_event_loop() + +# Pre-import endpoint modules so we can wire up lifespan +constants = importlib.import_module("constants").Constants() + +# Will be set after app creation +_routes: dict = {} + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Lifespan context manager for startup/shutdown events.""" + # Startup + uvicorn_access_logger = logging.getLogger("uvicorn.access") + uvicorn_access_logger.disabled = True + + # Start Radio playlists + if "radio" in _routes and hasattr(_routes["radio"], "on_start"): + await _routes["radio"].on_start() + + # Start endpoint background tasks + if "trip" in _routes and hasattr(_routes["trip"], "startup"): + await _routes["trip"].startup() + if "lighting" in _routes and hasattr(_routes["lighting"], "startup"): + await _routes["lighting"].startup() + + logger.info("Application startup complete") + + yield + + # Shutdown + if "lighting" in _routes and hasattr(_routes["lighting"], "shutdown"): + await _routes["lighting"].shutdown() + if "trip" in _routes and hasattr(_routes["trip"], "shutdown"): + await _routes["trip"].shutdown() + + logger.info("Application shutdown complete") + + app = FastAPI( title="codey.lol API", version="1.0", contact={"name": "codey"}, redirect_slashes=False, loop=loop, - docs_url="/docs", # Swagger UI (default) - redoc_url="/redoc", # ReDoc UI (default, but explicitly set) + docs_url=None, # Disabled - using Scalar at /docs instead + redoc_url="/redoc", + lifespan=lifespan, ) -constants = importlib.import_module("constants").Constants() util = importlib.import_module("util").Utilities(app, constants) origins = [ @@ -48,8 +90,8 @@ app.add_middleware( ) # type: ignore -# Add Scalar API documentation endpoint (before blacklist routes) -@app.get("/scalar", include_in_schema=False) +# Scalar API documentation at /docs (replaces default Swagger UI) +@app.get("/docs", include_in_schema=False) def scalar_docs(): return get_scalar_api_reference(openapi_url="/openapi.json", title="codey.lol API") @@ -72,7 +114,7 @@ def base_head(): @app.get("/{path}", include_in_schema=False) def disallow_get_any(request: Request, var: Any = None): path = request.path_params["path"] - allowed_paths = ["widget", "misc/no", "docs", "redoc", "scalar", "openapi.json"] + allowed_paths = ["widget", "misc/no", "docs", "redoc", "openapi.json"] logging.info( f"Checking path: {path}, allowed: {path in allowed_paths or path.split('/', maxsplit=1)[0] in allowed_paths}" ) @@ -99,7 +141,7 @@ End Blacklisted Routes Actionable Routes """ -routes: dict = { +_routes.update({ "randmsg": importlib.import_module("endpoints.rand_msg").RandMsg( app, util, constants ), @@ -116,12 +158,12 @@ routes: dict = { "lighting": importlib.import_module("endpoints.lighting").Lighting( app, util, constants ), -} 
+}) # Misc endpoint depends on radio endpoint instance -radio_endpoint = routes.get("radio") +radio_endpoint = _routes.get("radio") if radio_endpoint: - routes["misc"] = importlib.import_module("endpoints.misc").Misc( + _routes["misc"] = importlib.import_module("endpoints.misc").Misc( app, util, constants, radio_endpoint ) @@ -133,12 +175,5 @@ End Actionable Routes Startup """ - -async def on_start(): - uvicorn_access_logger = logging.getLogger("uvicorn.access") - uvicorn_access_logger.disabled = True - - -app.add_event_handler("startup", on_start) redis = redis_cache.RedisCache() loop.create_task(redis.create_index()) diff --git a/endpoints/lighting.py b/endpoints/lighting.py index 02ceba4..a1314b1 100644 --- a/endpoints/lighting.py +++ b/endpoints/lighting.py @@ -1,741 +1,628 @@ +""" +Cync Lighting Control API + +This module provides a FastAPI endpoint for controlling Cync smart lights. +It maintains a persistent connection to the Cync cloud service and handles +authentication, token caching, and connection lifecycle management. + +Key behaviors: +- pycync uses a TCP/TLS connection that requires login acknowledgment before commands work +- Commands are sent through a WiFi-connected "hub" device to the Bluetooth mesh +- The TCP manager auto-reconnects on disconnect with a 10-second delay +- We wait for the connection to be fully ready before sending commands +""" + import logging import json import os import time -import aiohttp import asyncio -import traceback -from datetime import datetime +from typing import Optional, Any +from dataclasses import dataclass + +import aiohttp from fastapi import FastAPI, Depends, HTTPException, Request from fastapi_throttle import RateLimiter from fastapi.responses import JSONResponse import redis + from lyric_search.sources import private from auth.deps import get_current_user from dotenv import load_dotenv + from pycync.user import User # type: ignore -from pycync.cync import Cync as Cync # type: ignore +from pycync.cync import Cync # type: ignore from pycync import Auth # type: ignore from pycync.exceptions import TwoFactorRequiredError, AuthFailedError # type: ignore -import inspect -import getpass -from typing import Optional -# Configure logging to write to a file for specific events -logging.basicConfig( - filename="cync_auth_events.log", - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", -) +# Configure logging +logger = logging.getLogger(__name__) -def _mask_token(token: Optional[str]) -> str: - """Mask sensitive token data for logging, showing only first/last 4 chars.""" - if not token or len(token) < 8: - return "" - return f"{token[:4]}...{token[-4:]}" +@dataclass +class CyncConnectionState: + """Track the state of our Cync connection.""" + session: Optional[aiohttp.ClientSession] = None + auth: Optional[Auth] = None + cync_api: Optional[Cync] = None + user: Optional[User] = None + connected_at: Optional[float] = None + last_command_at: Optional[float] = None -def _log_token_state(user, context: str): - """Log masked token state for debugging.""" - if not user: - logging.info(f"{context} - No user object available") - return - - try: - logging.info( - f"{context} - Token state: access=%s refresh=%s expires_at=%s", - _mask_token(getattr(user, "access_token", None)), - _mask_token(getattr(user, "refresh_token", None)), - getattr(user, "expires_at", None), - ) - except Exception as e: - logging.error(f"Error logging token state: {e}") - - -class Lighting(FastAPI): - async def _close_session_safely(self): - """Safely 
close the current session if it exists.""" - if self.session and not getattr(self.session, "closed", True): - try: - await self.session.close() - # Wait a bit for the underlying connections to close - await asyncio.sleep(0.1) - except Exception as e: - logging.warning(f"Error closing session: {e}") - self.session = None - self.auth = None - self.cync_api = None - - async def _test_connection_health(self) -> bool: - """Test if the current connection is healthy by making a simple API call.""" - if ( - not self.cync_api - or not self.session - or getattr(self.session, "closed", True) - ): - return False - - try: - # Make a simple API call to test connectivity - devices = self.cync_api.get_devices() - # Just check if we get a response without errors - return devices is not None - except Exception as e: - logging.warning(f"Connection health check failed: {e}") - return False - - async def ensure_cync_connection(self, force_reconnect: bool = False): - """Ensure aiohttp session and Cync API are alive, re-create if needed.""" - # Check required environment variables - missing_vars = [] - if not self.cync_email: - missing_vars.append("CYNC_EMAIL") - if not self.cync_password: - missing_vars.append("CYNC_PASSWORD") - if not self.cync_device_name: - missing_vars.append("CYNC_DEVICE_NAME") - if missing_vars: - raise Exception( - f"Missing required environment variables: {', '.join(missing_vars)}" - ) - - # Cast to str after check to silence linter - cync_email: str = self.cync_email # type: ignore - cync_password: str = self.cync_password # type: ignore - - # If force_reconnect is True or connection is unhealthy, rebuild everything - if force_reconnect or not await self._test_connection_health(): - logging.info( - "Connection unhealthy or force reconnect requested. Rebuilding connection..." 
- ) - - # Clean up existing connection - await self._close_session_safely() - - # Create new session with timeout configuration - timeout = aiohttp.ClientTimeout(total=30, connect=10) - connector = aiohttp.TCPConnector( - limit=100, - limit_per_host=30, - ttl_dns_cache=300, - use_dns_cache=True, - keepalive_timeout=60, - enable_cleanup_closed=True, - ) - self.session = aiohttp.ClientSession(timeout=timeout, connector=connector) - - # Load cached token and check validity - self.cync_user = None - cached_user = self._load_cached_user() - token_status = None - - if cached_user and hasattr(cached_user, "expires_at"): - # Add buffer time - consider token expired if less than 5 minutes remaining - buffer_time = 300 # 5 minutes - if cached_user.expires_at > (time.time() + buffer_time): - token_status = "valid" - else: - token_status = "expired" - else: - token_status = "no cached user or missing expires_at" - - logging.info(f"Cync token status: {token_status}") - - if token_status == "valid" and cached_user is not None: - # Use cached token - self.auth = Auth( - session=self.session, - user=cached_user, - username=cync_email, - password=cync_password, - ) - self.cync_user = cached_user - logging.info("Reusing valid cached token, no 2FA required.") - else: - # Need fresh login - clear any cached user that's expired - if token_status == "expired": - try: - os.remove(self.token_cache_path) - logging.info("Removed expired token cache") - except (OSError, FileNotFoundError): - pass - - logging.info("Initializing new Auth instance...") - self.auth = Auth( - session=self.session, - username=cync_email, - password=cync_password, - ) - try: - logging.info("Attempting fresh login...") - self.cync_user = await self.auth.login() - _log_token_state(self.cync_user, "After fresh login") - self._save_cached_user(self.cync_user) - logging.info("Fresh login successful") - except TwoFactorRequiredError: - twofa_code = os.getenv("CYNC_2FA_CODE") - if not twofa_code: - print("Cync 2FA required. 
Please enter your code:") - twofa_code = getpass.getpass("2FA Code: ") - if twofa_code: - logging.info("Retrying Cync login with 2FA code.") - try: - self.cync_user = await self.auth.login( - two_factor_code=twofa_code - ) - self._save_cached_user(self.cync_user) - logging.info("Logged in with 2FA successfully.") - except Exception as e: - logging.error("Cync 2FA login failed: %s", e) - logging.info( - "2FA failure details: Code=%s, User=%s", - twofa_code, - self.cync_user, - ) - raise Exception("Cync 2FA code invalid or not accepted.") - else: - logging.error("Cync 2FA required but no code provided.") - raise Exception("Cync 2FA required.") - except AuthFailedError as e: - logging.error("Failed to authenticate with Cync API: %s", e) - raise Exception("Cync authentication failed.") - - # Create new Cync API instance - try: - logging.info("Creating Cync API instance...") - _log_token_state(self.auth.user, "Before Cync.create") - self.cync_api = await Cync.create(self.auth) - logging.info("Cync API connection established successfully") - except Exception as e: - logging.error("Failed to create Cync API instance") - logging.error("Exception details: %s", str(e)) - logging.error("Traceback:\n%s", traceback.format_exc()) - - # Save diagnostic info - diagnostic_data = { - "timestamp": datetime.now().isoformat(), - "error_type": type(e).__name__, - "error_message": str(e), - "auth_state": { - "has_auth": bool(self.auth), - "has_user": bool(getattr(self.auth, "user", None)), - "user_state": { - "access_token": _mask_token( - getattr(self.auth.user, "access_token", None) - ) - if self.auth and self.auth.user - else None, - "refresh_token": _mask_token( - getattr(self.auth.user, "refresh_token", None) - ) - if self.auth and self.auth.user - else None, - "expires_at": getattr(self.auth.user, "expires_at", None) - if self.auth and self.auth.user - else None, - } - if self.auth and self.auth.user - else None, - }, - } - diagnostic_file = f"cync_api_failure-{int(time.time())}.json" - try: - with open(diagnostic_file, "w") as f: - json.dump(diagnostic_data, f, indent=2) - logging.info( - f"Saved API creation diagnostic data to {diagnostic_file}" - ) - except Exception as save_error: - logging.error(f"Failed to save diagnostic data: {save_error}") - raise - - # Final validation - if ( - not self.cync_api - or not self.session - or getattr(self.session, "closed", True) - ): - logging.error("Connection validation failed after setup") - _log_token_state( - getattr(self.auth, "user", None), "Failed connection validation" - ) - raise Exception("Failed to establish proper Cync connection") - +class Lighting: """ - Lighting Endpoints + Cync Lighting Controller + + Manages authentication and device control for Cync smart lights. + Uses pycync library which maintains a TCP connection for device commands. 
""" - - def __init__(self, app: FastAPI, util, constants) -> None: - """Initialize Lighting endpoints and persistent Cync connection.""" + + # Configuration + TOKEN_EXPIRY_BUFFER = 300 # Consider token expired 5 min before actual expiry + CONNECTION_READY_TIMEOUT = 15 # Max seconds to wait for TCP connection to be ready + COMMAND_DELAY = 0.3 # Delay between sequential commands + MAX_RETRIES = 3 + + def __init__(self, app: FastAPI, util: Any, constants: Any) -> None: load_dotenv() - self.app: FastAPI = app + + self.app = app self.util = util self.constants = constants + + # Redis for state persistence self.redis_client = redis.Redis( - password=private.REDIS_PW, decode_responses=True + password=private.REDIS_PW, + decode_responses=True ) self.lighting_key = "lighting:state" - - # Cync config + + # Cync configuration from environment self.cync_email = os.getenv("CYNC_EMAIL") self.cync_password = os.getenv("CYNC_PASSWORD") self.cync_device_name = os.getenv("CYNC_DEVICE_NAME") self.token_cache_path = "cync_token.json" - self.session = None - self.auth = None - self.cync_user = None - self.cync_api = None - self.health_check_task: Optional[asyncio.Task] = None - - # Set up Cync connection at startup using FastAPI event - @app.on_event("startup") - async def startup_event(): - # Check required environment variables - missing_vars = [] - if not self.cync_email: - missing_vars.append("CYNC_EMAIL") - if not self.cync_password: - missing_vars.append("CYNC_PASSWORD") - if not self.cync_device_name: - missing_vars.append("CYNC_DEVICE_NAME") - if missing_vars: - raise Exception( - f"Missing required environment variables: {', '.join(missing_vars)}" - ) - - # Use ensure_cync_connection which has proper token caching - try: - await self.ensure_cync_connection() - logging.info("Cync lighting system initialized successfully") - except Exception as e: - logging.error(f"Failed to initialize Cync connection at startup: {e}") - # Don't raise - allow server to start, connection will be retried on first request - - # Schedule periodic token validation and connection health checks - self.health_check_task = asyncio.create_task(self._schedule_health_checks()) - - @app.on_event("shutdown") - async def shutdown_event(): - # Cancel health check task - if self.health_check_task and not self.health_check_task.done(): - self.health_check_task.cancel() - try: - await self.health_check_task - except asyncio.CancelledError: - logging.info("Health check task cancelled successfully") - pass - - # Clean up connections - await self._close_session_safely() - logging.info("Cync lighting system shut down cleanly") - - # Register endpoints - self.endpoints: dict = { - "lighting/state": self.get_lighting_state, - } - - for endpoint, handler in self.endpoints.items(): - self.app.add_api_route( - f"/{endpoint}", - handler, - methods=["GET"], - include_in_schema=True, - dependencies=[ - Depends(RateLimiter(times=25, seconds=2)), - Depends(get_current_user), - ], - ) - + + # Connection state + self._state = CyncConnectionState() + self._connection_lock = asyncio.Lock() + self._health_task: Optional[asyncio.Task] = None + + # Register routes + self._register_routes() + + def _register_routes(self) -> None: + """Register FastAPI routes.""" + common_deps = [ + Depends(RateLimiter(times=25, seconds=2)), + Depends(get_current_user), + ] + + self.app.add_api_route( + "/lighting/state", + self.get_lighting_state, + methods=["GET"], + dependencies=common_deps, + include_in_schema=False, + ) + self.app.add_api_route( "/lighting/state", 
self.set_lighting_state, methods=["POST"], - include_in_schema=True, - dependencies=[ - Depends(RateLimiter(times=25, seconds=2)), - Depends(get_current_user), - ], + dependencies=common_deps, + include_in_schema=False, ) - - async def _refresh_or_login(self): - if not self.auth: - logging.error("Auth object is not initialized.") - raise Exception("Cync authentication not initialized.") + + # ========================================================================= + # Lifecycle Management + # ========================================================================= + + async def startup(self) -> None: + """Initialize on app startup. Call from lifespan context manager.""" + self._validate_config() + try: - user = getattr(self.auth, "user", None) - _log_token_state(user, "Before refresh attempt") - - if user and hasattr(user, "expires_at") and user.expires_at > time.time(): - refresh = getattr(self.auth, "async_refresh_user_token", None) - if callable(refresh): - try: - logging.info("Attempting token refresh...") - result = refresh() - if inspect.isawaitable(result): - await result - logging.info( - "Token refresh completed successfully (awaited)" - ) - else: - logging.info("Token refresh completed (non-awaitable)") - except AuthFailedError as e: - logging.error("Token refresh failed with AuthFailedError") - logging.error("Exception details: %s", str(e)) - logging.error("Traceback:\n%s", traceback.format_exc()) - - # Save diagnostic info to file - diagnostic_data = { - "timestamp": datetime.now().isoformat(), - "error_type": "AuthFailedError", - "error_message": str(e), - "user_state": { - "access_token": _mask_token( - getattr(user, "access_token", None) - ), - "refresh_token": _mask_token( - getattr(user, "refresh_token", None) - ), - "expires_at": getattr(user, "expires_at", None), - }, - } - try: - diagnostic_file = ( - f"cync_auth_failure-{int(time.time())}.json" - ) - with open(diagnostic_file, "w") as f: - json.dump(diagnostic_data, f, indent=2) - logging.info(f"Saved diagnostic data to {diagnostic_file}") - except Exception as save_error: - logging.error( - f"Failed to save diagnostic data: {save_error}" - ) - raise - login = getattr(self.auth, "login", None) - if callable(login): - try: - result = login() - if inspect.isawaitable(result): - self.cync_user = await result - else: - self.cync_user = result - self._save_cached_user(self.cync_user) - logging.info("Logged in successfully.") - except TwoFactorRequiredError: - twofa_code = os.getenv("CYNC_2FA_CODE") - if not twofa_code: - # Prompt interactively if not set - print("Cync 2FA required. 
Please enter your code:") - twofa_code = getpass.getpass("2FA Code: ") - if twofa_code: - logging.info("Retrying Cync login with 2FA code.") - try: - result = login(two_factor_code=twofa_code) - if inspect.isawaitable(result): - self.cync_user = await result - else: - self.cync_user = result - self._save_cached_user(self.cync_user) - logging.info("Logged in with 2FA successfully.") - except Exception as e: - logging.error("Cync 2FA login failed: %s", e) - logging.info( - "2FA failure details: Code=%s, User=%s", - twofa_code, - self.cync_user, - ) - raise Exception("Cync 2FA code invalid or not accepted.") - else: - logging.error("Cync 2FA required but no code provided.") - raise Exception("Cync 2FA required.") - else: - raise Exception("Auth object missing login method.") - except AuthFailedError as e: - logging.error("Failed to authenticate with Cync API: %s", e) - raise Exception("Cync authentication failed.") + await self._connect() + logger.info("Cync lighting initialized successfully") except Exception as e: - logging.error("Unexpected error during authentication: %s", e) + logger.error(f"Failed to initialize Cync at startup: {e}") + # Don't raise - allow app to start, will retry on first request + + # Start background health monitoring + self._health_task = asyncio.create_task(self._health_monitor()) + + async def shutdown(self) -> None: + """Cleanup on app shutdown. Call from lifespan context manager.""" + if self._health_task: + self._health_task.cancel() + try: + await self._health_task + except asyncio.CancelledError: + pass + + await self._disconnect() + logger.info("Cync lighting shut down") + + def _validate_config(self) -> None: + """Validate required environment variables.""" + missing = [] + if not self.cync_email: + missing.append("CYNC_EMAIL") + if not self.cync_password: + missing.append("CYNC_PASSWORD") + if not self.cync_device_name: + missing.append("CYNC_DEVICE_NAME") + + if missing: + raise RuntimeError(f"Missing required env vars: {', '.join(missing)}") + + # ========================================================================= + # Connection Management + # ========================================================================= + + async def _connect(self, force: bool = False) -> None: + """ + Establish connection to Cync cloud. + + This creates the aiohttp session, authenticates, and initializes + the pycync API which starts its TCP connection. 
+ """ + async with self._connection_lock: + # Check if we need to connect + if not force and self._is_connection_valid(): + return + + logger.info("Establishing Cync connection...") + + # Clean up existing connection + await self._disconnect_unlocked() + + # Create HTTP session + timeout = aiohttp.ClientTimeout(total=30, connect=10) + self._state.session = aiohttp.ClientSession(timeout=timeout) + + # Authenticate + await self._authenticate() + + # Create Cync API (starts TCP connection) + logger.info("Creating Cync API instance...") + assert self._state.auth is not None # Set by _authenticate + self._state.cync_api = await Cync.create(self._state.auth) + + # Wait for TCP connection to be ready + await self._wait_for_connection_ready() + + self._state.connected_at = time.time() + logger.info("Cync connection established") + + async def _disconnect(self) -> None: + """Disconnect and cleanup resources.""" + async with self._connection_lock: + await self._disconnect_unlocked() + + async def _disconnect_unlocked(self) -> None: + """Disconnect without acquiring lock (internal use).""" + # Shutdown pycync TCP connection + if self._state.cync_api: + try: + # pycync's command client has a shut_down method + client = getattr(self._state.cync_api, '_command_client', None) + if client: + await client.shut_down() + except Exception as e: + logger.warning(f"Error shutting down Cync client: {e}") + + # Close HTTP session + if self._state.session and not self._state.session.closed: + await self._state.session.close() + await asyncio.sleep(0.1) # Allow cleanup + + # Reset state + self._state = CyncConnectionState() + + def _is_connection_valid(self) -> bool: + """Check if current connection is usable.""" + if not self._state.cync_api or not self._state.session: + return False + + if self._state.session.closed: + return False + + # Check token expiry + if self._is_token_expired(): + logger.info("Token expired or expiring soon") + return False + + return True + + def _is_token_expired(self) -> bool: + """Check if token is expired or will expire soon.""" + if not self._state.user: + return True + + expires_at = getattr(self._state.user, 'expires_at', 0) + return expires_at < (time.time() + self.TOKEN_EXPIRY_BUFFER) + + async def _wait_for_connection_ready(self) -> None: + """ + Wait for pycync TCP connection to be fully ready. + + pycync's TCP manager waits for login acknowledgment before sending + any commands. We need to wait for this to complete. 
+ """ + if not self._state.cync_api: + raise RuntimeError("Cync API not initialized") + + client = getattr(self._state.cync_api, '_command_client', None) + if not client: + logger.warning("Could not access command client") + return + + tcp_manager = getattr(client, '_tcp_manager', None) + if not tcp_manager: + logger.warning("Could not access TCP manager") + return + + # Wait for login to be acknowledged + start = time.time() + while not getattr(tcp_manager, '_login_acknowledged', False): + if time.time() - start > self.CONNECTION_READY_TIMEOUT: + raise TimeoutError("Timed out waiting for Cync login acknowledgment") + await asyncio.sleep(0.2) + logger.debug("Waiting for Cync TCP login acknowledgment...") + + # Give a tiny bit more time for device probing to start + await asyncio.sleep(0.5) + logger.info(f"Cync TCP connection ready (took {time.time() - start:.1f}s)") + + # ========================================================================= + # Authentication + # ========================================================================= + + async def _authenticate(self) -> None: + """Authenticate with Cync, using cached token if valid.""" + # Try cached token first + cached_user = self._load_cached_token() + + # These are validated by _validate_config at startup + assert self._state.session is not None + assert self.cync_email is not None + assert self.cync_password is not None + + if cached_user and not self._is_user_token_expired(cached_user): + logger.info("Using cached Cync token") + self._state.auth = Auth( + session=self._state.session, + user=cached_user, + username=self.cync_email, + password=self.cync_password, + ) + self._state.user = cached_user + return + + # Need fresh login + logger.info("Performing fresh Cync login...") + self._state.auth = Auth( + session=self._state.session, + username=self.cync_email, + password=self.cync_password, + ) + + try: + self._state.user = await self._state.auth.login() + self._save_cached_token(self._state.user) + logger.info("Cync login successful") + except TwoFactorRequiredError: + await self._handle_2fa() + except AuthFailedError as e: + logger.error(f"Cync authentication failed: {e}") raise - - async def _schedule_health_checks(self): - """Periodic health checks and token validation.""" + + async def _handle_2fa(self) -> None: + """Handle 2FA authentication.""" + import sys + + # Try environment variable first + twofa_code = os.getenv("CYNC_2FA_CODE") + + # If not set, prompt interactively + if not twofa_code: + print("\n" + "=" * 50) + print("CYNC 2FA REQUIRED") + print("=" * 50) + print("Check your email for the Cync verification code.") + print("Enter the code below (you have 60 seconds):") + print("=" * 50) + sys.stdout.flush() + + # Use asyncio to read with timeout + try: + loop = asyncio.get_event_loop() + twofa_code = await asyncio.wait_for( + loop.run_in_executor(None, input, "2FA Code: "), + timeout=60.0 + ) + twofa_code = twofa_code.strip() + except asyncio.TimeoutError: + logger.error("2FA code entry timed out") + raise RuntimeError("2FA code entry timed out") + + if not twofa_code: + logger.error("No 2FA code provided") + raise RuntimeError("Cync 2FA required but no code provided") + + logger.info("Retrying Cync login with 2FA code") + try: + assert self._state.auth is not None + self._state.user = await self._state.auth.login(two_factor_code=twofa_code) + self._save_cached_token(self._state.user) + logger.info("Cync 2FA login successful") + except Exception as e: + logger.error(f"Cync 2FA login failed: {e}") + raise + + def 
_is_user_token_expired(self, user: User) -> bool: + """Check if a user's token is expired.""" + expires_at = getattr(user, 'expires_at', 0) + return expires_at < (time.time() + self.TOKEN_EXPIRY_BUFFER) + + def _load_cached_token(self) -> Optional[User]: + """Load cached authentication token from disk.""" + try: + if not os.path.exists(self.token_cache_path): + return None + + with open(self.token_cache_path, 'r') as f: + data = json.load(f) + + return User( + access_token=data['access_token'], + refresh_token=data['refresh_token'], + authorize=data['authorize'], + user_id=data['user_id'], + expires_at=data['expires_at'], + ) + except Exception as e: + logger.warning(f"Failed to load cached token: {e}") + return None + + def _save_cached_token(self, user: User) -> None: + """Save authentication token to disk.""" + try: + data = { + 'access_token': user.access_token, + 'refresh_token': user.refresh_token, + 'authorize': user.authorize, + 'user_id': user.user_id, + 'expires_at': user.expires_at, + } + with open(self.token_cache_path, 'w') as f: + json.dump(data, f) + logger.debug("Saved Cync token to disk") + except Exception as e: + logger.warning(f"Failed to save token: {e}") + + def _clear_cached_token(self) -> None: + """Remove cached token file.""" + try: + if os.path.exists(self.token_cache_path): + os.remove(self.token_cache_path) + logger.info("Cleared cached token") + except OSError: + pass + + # ========================================================================= + # Health Monitoring + # ========================================================================= + + async def _health_monitor(self) -> None: + """Background task to monitor connection health and refresh tokens.""" while True: try: await asyncio.sleep(300) # Check every 5 minutes - - # Check token expiration (refresh if less than 10 minutes left) - if self.cync_user and hasattr(self.cync_user, "expires_at"): - expires_at = getattr(self.cync_user, "expires_at", 0) - time_until_expiry = expires_at - time.time() - if time_until_expiry < 600: # Less than 10 minutes - logging.info( - f"Token expires in {int(time_until_expiry / 60)} minutes. Refreshing..." - ) - try: - await self._refresh_or_login() - except Exception as e: - logging.error( - f"Token refresh failed during health check: {e}" - ) - - # Test connection health - if not await self._test_connection_health(): - logging.warning( - "Connection health check failed. Will reconnect on next API call." 
- ) - + + # Proactively refresh if token is expiring + if self._is_token_expired(): + logger.info("Token expiring, proactively reconnecting...") + try: + await self._connect(force=True) + except Exception as e: + logger.error(f"Proactive reconnection failed: {e}") + except asyncio.CancelledError: - logging.info("Health check task cancelled") break except Exception as e: - logging.error(f"Error during periodic health check: {e}") - # Continue the loop even on errors - - def _load_cached_user(self): - try: - if os.path.exists(self.token_cache_path): - with open(self.token_cache_path, "r") as f: - data = json.load(f) - return User( - access_token=data["access_token"], - refresh_token=data["refresh_token"], - authorize=data["authorize"], - user_id=data["user_id"], - expires_at=data["expires_at"], - ) - except Exception as e: - logging.warning("Failed to load cached Cync user: %s", e) - return None - - def _save_cached_user(self, user): - try: - data = { - "access_token": user.access_token, - "refresh_token": user.refresh_token, - "authorize": user.authorize, - "user_id": user.user_id, - "expires_at": user.expires_at, - } - with open(self.token_cache_path, "w") as f: - json.dump(data, f) - logging.info("Saved Cync user tokens to disk.") - except Exception as e: - logging.warning("Failed to save Cync user tokens: %s", e) - - async def get_lighting_state(self) -> JSONResponse: + logger.error(f"Health monitor error: {e}") + + # ========================================================================= + # Device Control + # ========================================================================= + + async def _get_device(self): + """Get the target light device.""" + if not self._state.cync_api: + raise RuntimeError("Cync not connected") + + devices = self._state.cync_api.get_devices() + if not devices: + raise RuntimeError("No devices found") + + device = next( + (d for d in devices if getattr(d, 'name', None) == self.cync_device_name), + None + ) + + if not device: + available = [getattr(d, 'name', 'unnamed') for d in devices] + raise RuntimeError( + f"Device '{self.cync_device_name}' not found. Available: {available}" + ) + + return device + + async def _send_commands( + self, + power: str, + brightness: Optional[int] = None, + rgb: Optional[tuple[int, int, int]] = None, + ) -> None: """ - Get the current lighting state. - - Returns: - - **JSONResponse**: Contains the current lighting state. + Send commands to the light device. + + Commands are sent sequentially with small delays to ensure + the TCP connection processes each one. 
""" + device = await self._get_device() + logger.info(f"Sending commands to device: {device.name}") + + # Power + if power == "on": + await device.turn_on() + logger.debug("Sent turn_on") + else: + await device.turn_off() + logger.debug("Sent turn_off") + await asyncio.sleep(self.COMMAND_DELAY) + + # Brightness + if brightness is not None: + await device.set_brightness(brightness) + logger.debug(f"Sent brightness: {brightness}") + await asyncio.sleep(self.COMMAND_DELAY) + + # Color + if rgb: + await device.set_rgb(rgb) + logger.debug(f"Sent RGB: {rgb}") + await asyncio.sleep(self.COMMAND_DELAY) + + self._state.last_command_at = time.time() + + # ========================================================================= + # API Endpoints + # ========================================================================= + + async def get_lighting_state(self, user=Depends(get_current_user)) -> JSONResponse: + """Get the current lighting state from Redis.""" + if "lighting" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") try: state = self.redis_client.get(self.lighting_key) if state: return JSONResponse(content=json.loads(str(state))) - else: - # Default state - default_state = { - "power": "off", - "brightness": 50, - "color": {"r": 255, "g": 255, "b": 255}, - } - return JSONResponse(content=default_state) + + # Default state + return JSONResponse(content={ + "power": "off", + "brightness": 50, + "color": {"r": 255, "g": 255, "b": 255}, + }) except Exception as e: - logging.error("Error getting lighting state: %s", e) + logger.error(f"Error getting lighting state: {e}") raise HTTPException(status_code=500, detail="Internal server error") - - async def set_lighting_state(self, request: Request) -> JSONResponse: - """ - Set the lighting state and apply it to the Cync device. 
- """ - logging.info("=== LIGHTING STATE REQUEST RECEIVED ===") + + async def set_lighting_state(self, request: Request, + user=Depends(get_current_user)) -> JSONResponse: + """Set the lighting state and apply to Cync device.""" try: + if "lighting" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") state = await request.json() - logging.info(f"Requested state: {state}") - # Validate state (basic validation) + logger.info(f"Lighting request: {state}") + + # Validate if not isinstance(state, dict): - raise HTTPException( - status_code=400, detail="State must be a JSON object" - ) - - # Store in Redis + raise HTTPException(status_code=400, detail="State must be a JSON object") + + power, brightness, rgb = self._parse_state(state) + + # Save to Redis (even if device command fails) self.redis_client.set(self.lighting_key, json.dumps(state)) - - await self.ensure_cync_connection() - - # Validate and extract state values - power = state.get("power", "off") - if power not in ["on", "off"]: - raise HTTPException( - status_code=400, detail=f"Invalid power state: {power}" - ) - - brightness = state.get("brightness", 50) - if not isinstance(brightness, (int, float)) or not (0 <= brightness <= 100): - raise HTTPException( - status_code=400, detail=f"Invalid brightness: {brightness}" - ) - - color = state.get("color") - if ( - color - and isinstance(color, dict) - and all(k in color for k in ["r", "g", "b"]) - ): - rgb = (color["r"], color["g"], color["b"]) - elif all(k in state for k in ["red", "green", "blue"]): - rgb = (state["red"], state["green"], state["blue"]) - for val, name in zip(rgb, ["red", "green", "blue"]): - if not isinstance(val, int) or not (0 <= val <= 255): - raise HTTPException( - status_code=400, detail=f"Invalid {name} color value: {val}" - ) - else: - rgb = None - - # Apply to Cync device with robust retry and error handling - max_retries = 3 - last_exception: Exception = Exception("No attempts made") - - for attempt in range(max_retries): - try: - # Ensure connection before each attempt - force_reconnect = attempt > 0 # Force reconnect on retries - await self.ensure_cync_connection(force_reconnect=force_reconnect) - - if not self.cync_api: - raise Exception("Cync API not available after connection setup") - - logging.info( - f"Attempt {attempt + 1}/{max_retries}: Getting devices from Cync API..." - ) - devices = self.cync_api.get_devices() - - if not devices: - raise Exception("No devices returned from Cync API") - - logging.info( - f"Devices returned: {[getattr(d, 'name', 'unnamed') for d in devices]}" - ) - - light = next( - ( - d - for d in devices - if hasattr(d, "name") and d.name == self.cync_device_name - ), - None, - ) - - if not light: - available_devices = [ - getattr(d, "name", "unnamed") for d in devices - ] - raise Exception( - f"Device '{self.cync_device_name}' not found. 
Available devices: {available_devices}" - ) - - logging.info( - f"Selected device: {getattr(light, 'name', 'unnamed')}" - ) - - # Execute device operations - operations_completed = [] - - # Set power - if power == "on": - result = await light.turn_on() - operations_completed.append(f"turn_on: {result}") - else: - result = await light.turn_off() - operations_completed.append(f"turn_off: {result}") - - # Set brightness - if "brightness" in state: - result = await light.set_brightness(brightness) - operations_completed.append( - f"set_brightness({brightness}): {result}" - ) - - # Set color - if rgb: - result = await light.set_rgb(rgb) - operations_completed.append(f"set_rgb({rgb}): {result}") - - logging.info( - f"All operations completed successfully: {operations_completed}" - ) - break # Success, exit retry loop - - except ( - aiohttp.ClientConnectionError, - aiohttp.ClientOSError, - aiohttp.ServerDisconnectedError, - aiohttp.ClientConnectorError, - ConnectionResetError, - ConnectionError, - OSError, - asyncio.TimeoutError, - ) as e: - last_exception = e - logging.warning( - f"Connection/network error (attempt {attempt + 1}/{max_retries}): {type(e).__name__}: {e}" - ) - if attempt < max_retries - 1: - # Wait a bit before retry to allow network/server recovery - await asyncio.sleep( - 2**attempt - ) # Exponential backoff: 1s, 2s, 4s - continue - - except (AuthFailedError, TwoFactorRequiredError) as e: - last_exception = e - logging.error( - f"Authentication error (attempt {attempt + 1}/{max_retries}): {e}" - ) - if attempt < max_retries - 1: - # Clear cached tokens on auth errors - try: - os.remove(self.token_cache_path) - logging.info("Cleared token cache due to auth error") - except (OSError, FileNotFoundError): - pass - await asyncio.sleep(1) - continue - - except Exception as e: - last_exception = e - error_msg = f"Unexpected error (attempt {attempt + 1}/{max_retries}): {type(e).__name__}: {e}" - logging.error(error_msg) - - # On unexpected errors, try reconnecting for next attempt - if attempt < max_retries - 1: - logging.warning( - "Forcing full reconnection due to unexpected error..." - ) - await asyncio.sleep(1) - continue - - # If we get here, all retries failed - logging.error( - f"All {max_retries} attempts failed. 
Last error: {type(last_exception).__name__}: {last_exception}" - ) - raise last_exception - - logging.info( - "Successfully applied state to device '%s': %s", - self.cync_device_name, - state, - ) - return JSONResponse( - content={ - "message": "Lighting state updated and applied", - "state": state, - } - ) + + # Apply to device with retries + await self._apply_state_with_retry(power, brightness, rgb) + + logger.info(f"Successfully applied state: power={power}, brightness={brightness}, rgb={rgb}") + return JSONResponse(content={ + "message": "Lighting state updated", + "state": state, + }) + except HTTPException: raise except Exception as e: - logging.error("Error setting lighting state: %s", e) - raise HTTPException(status_code=500, detail="Internal server error") + logger.error(f"Error setting lighting state: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + def _parse_state(self, state: dict) -> tuple[str, Optional[int], Optional[tuple]]: + """Parse and validate lighting state from request.""" + # Power + power = state.get("power", "off") + if power not in ("on", "off"): + raise HTTPException(status_code=400, detail=f"Invalid power: {power}") + + # Brightness + brightness = None + if "brightness" in state: + brightness = state["brightness"] + if not isinstance(brightness, (int, float)) or not (0 <= brightness <= 100): + raise HTTPException(status_code=400, detail=f"Invalid brightness: {brightness}") + brightness = int(brightness) + + # Color + rgb = None + color = state.get("color") + if color and isinstance(color, dict) and all(k in color for k in ("r", "g", "b")): + rgb = (color["r"], color["g"], color["b"]) + elif all(k in state for k in ("red", "green", "blue")): + rgb = (state["red"], state["green"], state["blue"]) + + if rgb: + for i, name in enumerate(("red", "green", "blue")): + if not isinstance(rgb[i], int) or not (0 <= rgb[i] <= 255): + raise HTTPException(status_code=400, detail=f"Invalid {name}: {rgb[i]}") + + return power, brightness, rgb + + async def _apply_state_with_retry( + self, + power: str, + brightness: Optional[int], + rgb: Optional[tuple], + ) -> None: + """Apply state to device with connection retry logic.""" + last_error: Optional[Exception] = None + + for attempt in range(self.MAX_RETRIES): + try: + # Ensure connection (force reconnect on retries) + await self._connect(force=(attempt > 0)) + + # Send commands + await self._send_commands(power, brightness, rgb) + return # Success + + except (AuthFailedError, TwoFactorRequiredError) as e: + last_error = e + logger.warning(f"Auth error on attempt {attempt + 1}: {e}") + self._clear_cached_token() + + except TimeoutError as e: + last_error = e + logger.warning(f"Timeout on attempt {attempt + 1}: {e}") + + except Exception as e: + last_error = e + logger.warning(f"Error on attempt {attempt + 1}: {type(e).__name__}: {e}") + + # Wait before retry (exponential backoff) + if attempt < self.MAX_RETRIES - 1: + wait_time = 2 ** attempt + logger.info(f"Retrying in {wait_time}s...") + await asyncio.sleep(wait_time) + + # All retries failed + logger.error(f"All {self.MAX_RETRIES} attempts failed") + raise last_error or RuntimeError("Failed to apply lighting state") diff --git a/endpoints/misc.py b/endpoints/misc.py index d759a85..6693a24 100644 --- a/endpoints/misc.py +++ b/endpoints/misc.py @@ -62,6 +62,7 @@ class Misc(FastAPI): self.upload_activity_image, methods=["POST"], dependencies=[Depends(RateLimiter(times=10, seconds=2))], + include_in_schema=False, ) logging.debug("Loading NaaS reasons") diff --git 
a/endpoints/radio.py b/endpoints/radio.py index be1d743..773826c 100644 --- a/endpoints/radio.py +++ b/endpoints/radio.py @@ -47,12 +47,12 @@ class Radio(FastAPI): self.sr_util = SRUtil() self.lrclib = LRCLib() self.lrc_cache: Dict[str, Optional[str]] = {} - self.lrc_cache_locks = {} + self.lrc_cache_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self.playlists_loaded: bool = False # WebSocket connection management self.active_connections: Dict[str, Set[WebSocket]] = {} # Initialize broadcast locks to prevent duplicate events - self.broadcast_locks = defaultdict(asyncio.Lock) + self.broadcast_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self.endpoints: dict = { "radio/np": self.radio_now_playing, "radio/request": self.radio_request, @@ -71,9 +71,9 @@ class Radio(FastAPI): if endpoint == "radio/album_art": methods = ["GET"] app.add_api_route( - f"/{endpoint}", handler, methods=methods, include_in_schema=True, + f"/{endpoint}", handler, methods=methods, include_in_schema=False, dependencies=[Depends( - RateLimiter(times=25, seconds=2))] if not endpoint == "radio/np" else None, + RateLimiter(times=25, seconds=2))], ) # Add WebSocket route @@ -83,12 +83,8 @@ class Radio(FastAPI): app.add_websocket_route("/radio/ws/{station}", websocket_route_handler) - app.add_event_handler("startup", self.on_start) - async def on_start(self) -> None: - # Initialize locks in the event loop - self.lrc_cache_locks = defaultdict(asyncio.Lock) - self.broadcast_locks = defaultdict(asyncio.Lock) + # Load playlists for all stations stations = ", ".join(self.radio_util.db_queries.keys()) logging.info("radio: Initializing stations:\n%s", stations) await self.radio_util.load_playlists() diff --git a/endpoints/rip.py b/endpoints/rip.py index dc7d113..1b317b7 100644 --- a/endpoints/rip.py +++ b/endpoints/rip.py @@ -1,5 +1,5 @@ import logging -from fastapi import FastAPI, Request, Response, Depends +from fastapi import FastAPI, Request, Response, Depends, HTTPException from fastapi_throttle import RateLimiter from fastapi.responses import JSONResponse from utils.sr_wrapper import SRUtil @@ -63,22 +63,42 @@ class RIP(FastAPI): "trip/bulk_fetch": self.bulk_fetch_handler, "trip/job/{job_id:path}": self.job_status_handler, "trip/jobs/list": self.job_list_handler, + "trip/auth/start": self.tidal_auth_start_handler, + "trip/auth/check": self.tidal_auth_check_handler, } + # Store pending device codes for auth flow + self._pending_device_codes: dict[str, str] = {} + for endpoint, handler in self.endpoints.items(): dependencies = [Depends(RateLimiter(times=8, seconds=2))] app.add_api_route( f"/{endpoint}", handler, - methods=["GET"] if endpoint != "trip/bulk_fetch" else ["POST"], + methods=["GET"] if endpoint not in ("trip/bulk_fetch", "trip/auth/check") else ["POST"], include_in_schema=False, dependencies=dependencies, ) + async def startup(self) -> None: + """Initialize Tidal keepalive. Call this from your app's lifespan context manager.""" + try: + await self.trip_util.start_keepalive() + logger.info("Tidal keepalive task started successfully") + except Exception as e: + logger.error(f"Failed to start Tidal keepalive task: {e}") + + async def shutdown(self) -> None: + """Stop Tidal keepalive. 
Call this from your app's lifespan context manager.""" + try: + await self.trip_util.stop_keepalive() + logger.info("Tidal keepalive task stopped successfully") + except Exception as e: + logger.error(f"Error stopping Tidal keepalive task: {e}") + def _format_job(self, job: Job): """ Helper to normalize job data into JSON. - Parameters: - job (Job): The job object to format. @@ -132,6 +152,8 @@ class RIP(FastAPI): Returns: - **Response**: JSON response with artists or 404. """ + if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") # support optional grouping to return one primary per display name # with `alternatives` for disambiguation (use ?group=true) group = bool(request.query_params.get("group", False)) @@ -154,6 +176,8 @@ class RIP(FastAPI): Returns: - **Response**: JSON response with albums or 404. """ + if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") albums = await self.trip_util.get_albums_by_artist_id(artist_id) if not albums: return Response(status_code=404, content="Not found") @@ -178,6 +202,8 @@ class RIP(FastAPI): Returns: - **Response**: JSON response with tracks or 404. """ + if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") tracks = await self.trip_util.get_tracks_by_album_id(album_id, quality) if not tracks: return Response(status_code=404, content="Not Found") @@ -198,6 +224,8 @@ class RIP(FastAPI): Returns: - **Response**: JSON response with tracks or 404. """ + if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") logging.critical("Searching for tracks by artist: %s, song: %s", artist, song) tracks = await self.trip_util.get_tracks_by_artist_song(artist, song) if not tracks: @@ -223,6 +251,8 @@ class RIP(FastAPI): Returns: - **Response**: JSON response with stream URL or 404. """ + if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") track = await self.trip_util.get_stream_url_by_track_id(track_id, quality) if not track: return Response(status_code=404, content="Not found") @@ -245,6 +275,8 @@ class RIP(FastAPI): Returns: - **Response**: JSON response with job info or error. """ + if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") if not data or not data.track_ids or not data.target: return JSONResponse( content={ @@ -296,7 +328,8 @@ class RIP(FastAPI): Returns: - **JSONResponse**: Job status and result or error. """ - + if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") job = None try: # Try direct fetch first @@ -334,6 +367,8 @@ class RIP(FastAPI): Returns: - **JSONResponse**: List of jobs. 
""" + if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") jobs_info = [] seen = set() @@ -385,3 +420,79 @@ class RIP(FastAPI): jobs_info.sort(key=job_sort_key, reverse=True) return {"jobs": jobs_info} + + async def tidal_auth_start_handler( + self, request: Request, user=Depends(get_current_user) + ): + """ + Start Tidal device authorization flow. + + Returns a URL that the user must visit to authorize the application. + After visiting the URL and authorizing, call /trip/auth/check to complete. + + Returns: + - **JSONResponse**: Contains device_code and verification_url. + """ + try: + if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") + device_code, verification_url = await self.trip_util.start_device_auth() + # Store device code for this session + self._pending_device_codes[user.get("sub", "default")] = device_code + return JSONResponse( + content={ + "device_code": device_code, + "verification_url": verification_url, + "message": "Visit the URL to authorize, then call /trip/auth/check", + } + ) + except Exception as e: + logger.error("Tidal auth start failed: %s", e) + return JSONResponse( + content={"error": str(e)}, + status_code=500, + ) + + async def tidal_auth_check_handler( + self, request: Request, user=Depends(get_current_user) + ): + """ + Check if Tidal device authorization is complete. + + Call this after the user has visited the verification URL. + + Returns: + - **JSONResponse**: Contains success status and message. + """ + if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []): + raise HTTPException(status_code=403, detail="Insufficient permissions") + device_code = self._pending_device_codes.get(user.get("sub", "default")) + if not device_code: + return JSONResponse( + content={"error": "No pending authorization. 
Call /trip/auth/start first."}, + status_code=400, + ) + + try: + success, error = await self.trip_util.check_device_auth(device_code) + if success: + # Clear the pending code + self._pending_device_codes.pop(user.get("sub", "default"), None) + return JSONResponse( + content={"success": True, "message": "Tidal authorization complete!"} + ) + elif error == "pending": + return JSONResponse( + content={"success": False, "pending": True, "message": "Waiting for user to authorize..."} + ) + else: + return JSONResponse( + content={"success": False, "error": error}, + status_code=400, + ) + except Exception as e: + logger.error("Tidal auth check failed: %s", e) + return JSONResponse( + content={"error": str(e)}, + status_code=500, + ) diff --git a/endpoints/yt.py b/endpoints/yt.py index d98489b..4976bbd 100644 --- a/endpoints/yt.py +++ b/endpoints/yt.py @@ -3,9 +3,9 @@ from fastapi import FastAPI, Depends from fastapi.responses import JSONResponse from fastapi_throttle import RateLimiter from typing import Optional, Union +from utils.yt_utils import sign_video_id from .constructors import ValidYTSearchRequest - class YT(FastAPI): """ YT Endpoints @@ -57,6 +57,7 @@ class YT(FastAPI): return JSONResponse( content={ "video_id": yt_video_id, + "video_token": sign_video_id(yt_video_id) if yt_video_id else None, "extras": yts_res[0], } ) diff --git a/utils/sr_wrapper.py b/utils/sr_wrapper.py index 5d328f5..79d46de 100644 --- a/utils/sr_wrapper.py +++ b/utils/sr_wrapper.py @@ -1,24 +1,46 @@ +# isort: skip_file from typing import Optional, Any, Callable from uuid import uuid4 from urllib.parse import urlparse +from pathlib import Path import hashlib import traceback import logging import random import asyncio +import json import os import aiohttp import time -from streamrip.client import TidalClient # type: ignore -from streamrip.config import Config as StreamripConfig # type: ignore -from dotenv import load_dotenv -from rapidfuzz import fuzz + +# Monkey-patch streamrip's Tidal client credentials BEFORE importing TidalClient +import streamrip.client.tidal as _tidal_module # type: ignore # noqa: E402 +_tidal_module.CLIENT_ID = "fX2JxdmntZWK0ixT" +_tidal_module.CLIENT_SECRET = "1Nn9AfDAjxrgJFJbKNWLeAyKGVGmINuXPPLHVXAvxAg=" +_tidal_module.AUTH = aiohttp.BasicAuth( + login=_tidal_module.CLIENT_ID, + password=_tidal_module.CLIENT_SECRET +) + +from streamrip.client import TidalClient # type: ignore # noqa: E402 +from streamrip.config import Config as StreamripConfig # type: ignore # noqa: E402 +from dotenv import load_dotenv # noqa: E402 +from rapidfuzz import fuzz # noqa: E402 + +# Path to persist Tidal tokens across restarts +TIDAL_TOKEN_CACHE_PATH = Path(__file__).parent.parent / "tidal_token.json" class MetadataFetchError(Exception): """Raised when metadata fetch permanently fails after retries.""" +# How long before token expiry to proactively refresh (seconds) +TIDAL_TOKEN_REFRESH_BUFFER = 600 # 10 minutes +# Maximum age of a session before forcing a fresh login (seconds) +TIDAL_SESSION_MAX_AGE = 1800 # 30 minutes + + # Suppress noisy logging from this module and from the `streamrip` library # We set propagate=False so messages don't bubble up to the root logger and # attach a NullHandler where appropriate to avoid "No handler found" warnings. 
@@ -47,27 +69,11 @@ class SRUtil: def __init__(self) -> None: """Initialize StreamRip utility.""" self.streamrip_config = StreamripConfig.defaults() - self.streamrip_config.session.tidal.user_id = os.getenv("tidal_user_id", "") - self.streamrip_config.session.tidal.access_token = os.getenv( - "tidal_access_token", "" - ) - self.streamrip_config.session.tidal.refresh_token = os.getenv( - "tidal_refresh_token", "" - ) - self.streamrip_config.session.tidal.token_expiry = os.getenv( - "tidal_token_expiry", "" - ) - self.streamrip_config.session.tidal.country_code = os.getenv( - "tidal_country_code", "" - ) - self.streamrip_config.session.tidal.quality = int( - os.getenv("tidal_default_quality", 2) - ) + self._load_tidal_config() self.streamrip_config.session.conversion.enabled = False self.streamrip_config.session.downloads.folder = os.getenv( "tidal_download_folder", "" ) - self.streamrip_config self.streamrip_client = TidalClient(self.streamrip_config) self.MAX_CONCURRENT_METADATA_REQUESTS = 2 self.METADATA_RATE_LIMIT = 1.25 @@ -82,19 +88,328 @@ class SRUtil: self.on_rate_limit: Optional[Callable[[Exception], Any]] = None # Internal flag to avoid repeated notifications for the same runtime self._rate_limit_notified = False + # Track when we last successfully logged in + self._last_login_time: Optional[float] = None + # Track last successful API call + self._last_successful_request: Optional[float] = None + # Keepalive task handle + self._keepalive_task: Optional[asyncio.Task] = None + # Keepalive interval in seconds + self.KEEPALIVE_INTERVAL = 180 # 3 minutes + + async def start_keepalive(self) -> None: + """Start the background keepalive task. + + This should be called once at startup to ensure the Tidal session + stays alive even during idle periods. 
+ """ + if self._keepalive_task and not self._keepalive_task.done(): + logging.info("Tidal keepalive task already running") + return + + # Ensure initial login + try: + await self._login_and_persist() + logging.info("Initial Tidal login successful") + except Exception as e: + logging.warning("Initial Tidal login failed: %s", e) + + self._keepalive_task = asyncio.create_task(self._keepalive_runner()) + logging.info("Tidal keepalive task started") + + async def stop_keepalive(self) -> None: + """Stop the background keepalive task.""" + if self._keepalive_task and not self._keepalive_task.done(): + self._keepalive_task.cancel() + try: + await self._keepalive_task + except asyncio.CancelledError: + pass + logging.info("Tidal keepalive task stopped") + + async def _keepalive_runner(self) -> None: + """Background task to keep the Tidal session alive.""" + while True: + try: + await asyncio.sleep(self.KEEPALIVE_INTERVAL) + + # Check if we've had recent activity + if self._last_successful_request: + time_since_last = time.time() - self._last_successful_request + if time_since_last < self.KEEPALIVE_INTERVAL: + # Recent activity, no need to ping + continue + + # Check if token is expiring soon and proactively refresh + if self._is_token_expiring_soon(): + logging.info("Tidal keepalive: Token expiring soon, refreshing...") + try: + await self._login_and_persist(force=True) + logging.info("Tidal keepalive: Token refresh successful") + except Exception as e: + logging.warning("Tidal keepalive: Token refresh failed: %s", e) + continue + + # Check if session is stale + if self._is_session_stale(): + logging.info("Tidal keepalive: Session stale, refreshing...") + try: + await self._login_and_persist(force=True) + logging.info("Tidal keepalive: Session refresh successful") + except Exception as e: + logging.warning("Tidal keepalive: Session refresh failed: %s", e) + continue + + # Make a lightweight API call to keep the session alive + if self.streamrip_client.logged_in: + try: + # Simple search to keep the connection alive + await self._safe_api_call( + self.streamrip_client.search, + media_type="artist", + query="test", + retries=1, + ) + logging.debug("Tidal keepalive ping successful") + except Exception as e: + logging.warning("Tidal keepalive ping failed: %s", e) + # Try to refresh the session + try: + await self._login_and_persist(force=True) + except Exception: + pass + + except asyncio.CancelledError: + logging.info("Tidal keepalive task cancelled") + break + except Exception as e: + logging.error("Error in Tidal keepalive task: %s", e) + + def _load_tidal_config(self) -> None: + """Load Tidal config from cache file if available, otherwise from env.""" + tidal = self.streamrip_config.session.tidal + cached = self._load_cached_tokens() + + if cached: + tidal.user_id = cached.get("user_id", "") + tidal.access_token = cached.get("access_token", "") + tidal.refresh_token = cached.get("refresh_token", "") + tidal.token_expiry = cached.get("token_expiry", "") + tidal.country_code = cached.get("country_code", os.getenv("tidal_country_code", "")) + else: + tidal.user_id = os.getenv("tidal_user_id", "") + tidal.access_token = os.getenv("tidal_access_token", "") + tidal.refresh_token = os.getenv("tidal_refresh_token", "") + tidal.token_expiry = os.getenv("tidal_token_expiry", "") + tidal.country_code = os.getenv("tidal_country_code", "") + + tidal.quality = int(os.getenv("tidal_default_quality", 2)) + + def _load_cached_tokens(self) -> Optional[dict]: + """Load cached tokens from disk if valid.""" + try: + 
if TIDAL_TOKEN_CACHE_PATH.exists(): + with open(TIDAL_TOKEN_CACHE_PATH, "r") as f: + data = json.load(f) + # Validate required fields exist + if all(k in data for k in ("access_token", "refresh_token", "token_expiry")): + logging.info("Loaded Tidal tokens from cache") + return data + except Exception as e: + logging.warning("Failed to load cached Tidal tokens: %s", e) + return None + + def _save_cached_tokens(self) -> None: + """Persist current tokens to disk for use across restarts.""" + try: + tidal = self.streamrip_config.session.tidal + data = { + "user_id": tidal.user_id, + "access_token": tidal.access_token, + "refresh_token": tidal.refresh_token, + "token_expiry": tidal.token_expiry, + "country_code": tidal.country_code, + } + with open(TIDAL_TOKEN_CACHE_PATH, "w") as f: + json.dump(data, f) + logging.info("Saved Tidal tokens to cache") + except Exception as e: + logging.warning("Failed to save Tidal tokens: %s", e) + + def _apply_new_tokens(self, auth_info: dict) -> None: + """Apply new tokens from device auth to config.""" + tidal = self.streamrip_config.session.tidal + tidal.user_id = str(auth_info.get("user_id", "")) + tidal.access_token = auth_info.get("access_token", "") + tidal.refresh_token = auth_info.get("refresh_token", "") + tidal.token_expiry = auth_info.get("token_expiry", "") + tidal.country_code = auth_info.get("country_code", tidal.country_code) + self._save_cached_tokens() + + async def start_device_auth(self) -> tuple[str, str]: + """Start device authorization flow. + + Returns: + tuple: (device_code, verification_url) - User should visit the URL to authorize. + """ + if not hasattr(self.streamrip_client, 'session') or not self.streamrip_client.session: + self.streamrip_client.session = await self.streamrip_client.get_session() + + device_code, verification_url = await self.streamrip_client._get_device_code() + return device_code, verification_url + + async def check_device_auth(self, device_code: str) -> tuple[bool, Optional[str]]: + """Check if user has completed device authorization. 
+ + Args: + device_code: The device code from start_device_auth() + + Returns: + tuple: (success, error_message) + - (True, None) if auth completed successfully + - (False, "pending") if user hasn't authorized yet + - (False, error_message) if auth failed + """ + status, auth_info = await self.streamrip_client._get_auth_status(device_code) + + if status == 0: + # Success - apply new tokens + self._apply_new_tokens(auth_info) + # Re-login with new tokens + self.streamrip_client.logged_in = False + try: + await self.streamrip_client.login() + self._save_cached_tokens() + return True, None + except Exception as e: + return False, f"Login after auth failed: {e}" + elif status == 2: + # Pending - user hasn't authorized yet + return False, "pending" + else: + # Failed + return False, "Authorization failed" + + def _is_token_expiring_soon(self) -> bool: + """Check if the token is about to expire within the buffer window.""" + tidal = self.streamrip_config.session.tidal + token_expiry = getattr(tidal, "token_expiry", None) + if not token_expiry: + return True # No expiry info means we should refresh + try: + # token_expiry is typically an ISO timestamp string + if isinstance(token_expiry, str): + from datetime import datetime + expiry_dt = datetime.fromisoformat(token_expiry.replace('Z', '+00:00')) + expiry_ts = expiry_dt.timestamp() + else: + expiry_ts = float(token_expiry) + return expiry_ts < (time.time() + TIDAL_TOKEN_REFRESH_BUFFER) + except Exception as e: + logging.warning("Failed to parse token expiry '%s': %s", token_expiry, e) + return True # Err on the side of refreshing + + def _is_session_stale(self) -> bool: + """Check if the login session is too old and should be refreshed.""" + if not self._last_login_time: + return True + session_age = time.time() - self._last_login_time + return session_age > TIDAL_SESSION_MAX_AGE + + async def _force_fresh_login(self) -> bool: + """Force a complete fresh login, ignoring logged_in state. + + Returns True if login succeeded, False otherwise. + """ + # Reset the logged_in flag to force a fresh login + self.streamrip_client.logged_in = False + + # Close existing session if present + if hasattr(self.streamrip_client, 'session') and self.streamrip_client.session: + try: + if not self.streamrip_client.session.closed: + await self.streamrip_client.session.close() + except Exception as e: + logging.warning("Error closing old session: %s", e) + # Use object.__setattr__ to bypass type checking for session reset + try: + object.__setattr__(self.streamrip_client, 'session', None) + except Exception: + pass # Session will be recreated on next login + + try: + logging.info("Forcing fresh Tidal login...") + await self.streamrip_client.login() + self._last_login_time = time.time() + self._save_cached_tokens() + logging.info("Fresh Tidal login successful") + return True + except Exception as e: + logging.warning("Forced Tidal login failed: %s - device re-auth may be required", e) + return False + + async def _login_and_persist(self, force: bool = False) -> None: + """Login to Tidal and persist any refreshed tokens. + + Args: + force: If True, force a fresh login even if already logged in. + + This method now checks for: + 1. Token expiry - refreshes if token is about to expire + 2. Session age - refreshes if session is too old + 3. logged_in state - logs in if not logged in + + If refresh fails, logs a warning but does not raise. 
+ """ + needs_login = force or not self.streamrip_client.logged_in + + # Check if token is expiring soon + if not needs_login and self._is_token_expiring_soon(): + logging.info("Tidal token expiring soon, will refresh") + needs_login = True + + # Check if session is too old + if not needs_login and self._is_session_stale(): + logging.info("Tidal session is stale, will refresh") + needs_login = True + + if not needs_login: + return + + try: + # Reset logged_in to ensure fresh login attempt + if force or self._is_token_expiring_soon(): + self.streamrip_client.logged_in = False + + await self.streamrip_client.login() + self._last_login_time = time.time() + # After login, tokens may have been refreshed - persist them + self._save_cached_tokens() + logging.info("Tidal login/refresh successful") + except Exception as e: + logging.warning("Tidal login/refresh failed: %s - device re-auth may be required", e) + # Don't mark as logged in on failure - let subsequent calls retry async def rate_limited_request(self, func, *args, **kwargs): + """Rate-limited wrapper that also ensures login before making requests.""" async with self.METADATA_SEMAPHORE: now = time.time() elapsed = now - self.LAST_METADATA_REQUEST if elapsed < self.METADATA_RATE_LIMIT: await asyncio.sleep(self.METADATA_RATE_LIMIT - elapsed) + + # Ensure we're logged in before making the request + try: + await self._login_and_persist() + except Exception as e: + logging.warning("Pre-request login failed in rate_limited_request: %s", e) + result = await func(*args, **kwargs) self.LAST_METADATA_REQUEST = time.time() return result async def _safe_api_call( - self, func, *args, retries: int = 2, backoff: float = 0.5, **kwargs + self, func, *args, retries: int = 3, backoff: float = 0.5, **kwargs ): """Call an async API function with resilient retry behavior. @@ -103,18 +418,32 @@ class SRUtil: attempt a `login()` and retry up to `retries` times. - On 400/429 responses (message contains '400' or '429'): retry with backoff without triggering login (to avoid excessive logins). + - On 401 (Unauthorized): force a fresh login and retry. Returns the result or raises the last exception. """ last_exc: Optional[Exception] = None for attempt in range(retries): try: - return await func(*args, **kwargs) + # Before each attempt, ensure we have a valid session + if attempt == 0: + # On first attempt, try to ensure logged in (checks token expiry) + # Wrapped in try/except so login failures don't block the API call + try: + await self._login_and_persist() + except Exception as login_err: + logging.warning("Pre-request login failed: %s (continuing anyway)", login_err) + + result = await func(*args, **kwargs) + # Track successful request + self._last_successful_request = time.time() + return result except AttributeError as e: # Probably missing/closed client internals: try re-login once last_exc = e + logging.warning("AttributeError in API call (attempt %d/%d): %s", attempt + 1, retries, e) try: - await self.streamrip_client.login() + await self._force_fresh_login() except Exception: pass continue @@ -144,6 +473,31 @@ class SRUtil: await asyncio.sleep(backoff * (2**attempt)) continue + # Treat 401 (Unauthorized) as an auth failure: force a fresh re-login then retry + is_401_error = ( + (isinstance(e, aiohttp.ClientResponseError) and getattr(e, "status", None) == 401) + or "401" in msg + or "unauthorized" in msg.lower() + ) + if is_401_error: + logging.warning( + "Received 401/Unauthorized from Tidal (attempt %d/%d). 
Forcing fresh re-login...", + attempt + 1, + retries, + ) + try: + # Use force=True to ensure we actually re-authenticate + login_success = await self._force_fresh_login() + if login_success: + logging.info("Forced re-login after 401 successful") + else: + logging.warning("Forced re-login after 401 failed - may need device re-auth") + except Exception as login_exc: + logging.warning("Forced login after 401 failed: %s", login_exc) + if attempt < retries - 1: + await asyncio.sleep(backoff * (2**attempt)) + continue + # Connection related errors — try to re-login then retry if ( isinstance( @@ -159,7 +513,7 @@ class SRUtil: or "closed" in msg.lower() ): try: - await self.streamrip_client.login() + await self._login_and_persist(force=True) except Exception: pass if attempt < retries - 1: @@ -434,8 +788,6 @@ class SRUtil: async def get_albums_by_artist_id(self, artist_id: int) -> Optional[list | dict]: """Get albums by artist ID. Retry login only on authentication failure. Rate limit and retry on 400/429.""" - import asyncio - artist_id_str: str = str(artist_id) albums_out: list[dict] = [] max_retries = 4 @@ -585,26 +937,26 @@ class SRUtil: TODO: Reimplement using StreamRip """ try: + # _safe_api_call already handles login, no need to call it here search_res = await self._safe_api_call( self.streamrip_client.search, media_type="track", query=f"{artist} - {song}", retries=3, ) - logging.critical("Result: %s", search_res) + logging.debug("Search result: %s", search_res) return ( search_res[0].get("items") if search_res and isinstance(search_res, list) else [] ) except Exception as e: - traceback.print_exc() - logging.critical("Search Exception: %s", str(e)) - if n < 3: + logging.warning("Search Exception: %s", str(e)) + if n < 2: # Reduce max retries from 3 to 2 n += 1 + await asyncio.sleep(0.5 * n) # Add backoff return await self.get_tracks_by_artist_song(artist, song, n) return [] - # return [] async def get_stream_url_by_track_id( self, track_id: int, quality: str = "FLAC" @@ -655,7 +1007,6 @@ class SRUtil: """ for attempt in range(1, self.MAX_METADATA_RETRIES + 1): try: - await self._safe_api_call(self.streamrip_client.login, retries=1) # Track metadata metadata = await self.rate_limited_request( self.streamrip_client.get_metadata, str(track_id), "track" @@ -734,7 +1085,6 @@ class SRUtil: bool """ try: - await self._safe_api_call(self.streamrip_client.login, retries=1) track_url = await self.get_stream_url_by_track_id(track_id) if not track_url: return False diff --git a/utils/yt_utils.py b/utils/yt_utils.py new file mode 100644 index 0000000..a39cef4 --- /dev/null +++ b/utils/yt_utils.py @@ -0,0 +1,25 @@ +from typing import Optional +import hmac +import hashlib +import time +import base64 +import os + +VIDEO_PROXY_SECRET = os.environ.get("VIDEO_PROXY_SECRET", "").encode() + +def sign_video_id(video_id: Optional[str|bool]) -> str: + """Generate a signed token for a video ID.""" + if not VIDEO_PROXY_SECRET or not video_id: + return "" # Return empty if no secret configured + + timestamp = int(time.time() * 1000) # milliseconds to match JS Date.now() + payload = f"{video_id}:{timestamp}" + signature = hmac.new( + VIDEO_PROXY_SECRET, + payload.encode(), + hashlib.sha256 + ).hexdigest() + + token_data = f"{payload}:{signature}" + # base64url encode (no padding, to match JS base64url) + return base64.urlsafe_b64encode(token_data.encode()).decode().rstrip("=") \ No newline at end of file
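
The sketches below are illustrative commentary on this change set, not part of the diff itself.

The Tidal device-authorization flow added to endpoints/trip.py is a two-step poll: a start endpoint hands back a verification URL, and the check handler shown above reports success, "pending", or an error. A hedged client-side polling sketch follows; the check path ("/trip/auth/check"), the HTTP methods, the auth headers, and the "verification_url" field are assumptions — only "/trip/auth/start" and the success/pending/error response shape are taken from the handler in the diff:

import asyncio
import aiohttp

async def authorize_tidal(base_url: str, headers: dict) -> bool:
    # Hypothetical client for the device-auth endpoints; everything except
    # /trip/auth/start and the success/pending/error fields is assumed.
    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.post(f"{base_url}/trip/auth/start") as resp:
            start = await resp.json()
        print("Authorize this device at:", start.get("verification_url"))
        for _ in range(60):                      # poll for up to ~5 minutes
            await asyncio.sleep(5)
            async with session.post(f"{base_url}/trip/auth/check") as resp:
                status = await resp.json()
            if status.get("success"):
                return True
            if not status.get("pending"):
                print("Authorization failed:", status.get("error"))
                return False
        return False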
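The keepalive added to utils/sr_wrapper.py is meant to be started once by whatever owns the long-lived SRUtil instance and stopped on shutdown. A minimal wiring sketch, assuming a single shared SRUtil; the host class here is hypothetical, only start_keepalive()/stop_keepalive() come from the diff:

from utils.sr_wrapper import SRUtil

class TripHost:
    """Hypothetical owner of the shared SRUtil instance."""

    def __init__(self) -> None:
        self.sr = SRUtil()

    async def startup(self) -> None:
        # Performs the initial login and spawns the background keepalive task.
        await self.sr.start_keepalive()

    async def shutdown(self) -> None:
        # Cancels the keepalive task and waits for it to exit.
        await self.sr.stop_keepalive()

start_keepalive() is idempotent — it returns early if the task is already running — so calling it from repeated startup events will not spawn duplicate ping loops.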
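Tokens are now persisted to tidal_token.json in the repository root so a restart can reuse them instead of forcing a device re-auth. The field names below match _save_cached_tokens(); the inspection script itself is only a sketch, and it mirrors the dual expiry handling in _is_token_expiring_soon() (ISO-8601 string or numeric epoch):

import json
import time
from datetime import datetime
from pathlib import Path

CACHE = Path("tidal_token.json")  # written by SRUtil._save_cached_tokens()

def describe_cache() -> None:
    data = json.loads(CACHE.read_text())
    expiry = data.get("token_expiry")
    # Same dual handling as _is_token_expiring_soon(): ISO string or epoch.
    if isinstance(expiry, str):
        expiry_ts = datetime.fromisoformat(expiry.replace("Z", "+00:00")).timestamp()
    else:
        expiry_ts = float(expiry)
    remaining_min = (expiry_ts - time.time()) / 60
    print(f"user_id={data.get('user_id')} country={data.get('country_code')}")
    print(f"access token expires in {remaining_min:.1f} minutes")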
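Finally, the new utils/yt_utils.py signs YouTube video IDs as base64url("<video_id>:<ms timestamp>:<hex HMAC-SHA256>") with the padding stripped. The verifying side is not part of this diff; a minimal sketch of what a proxy-side check could look like under the same assumptions (shared VIDEO_PROXY_SECRET, millisecond timestamps, and a freshness window that is purely illustrative):

# Hypothetical counterpart to sign_video_id(); not part of this diff.
import base64
import hashlib
import hmac
import os
import time
from typing import Optional

VIDEO_PROXY_SECRET = os.environ.get("VIDEO_PROXY_SECRET", "").encode()

def verify_video_token(token: str, max_age_ms: int = 3_600_000) -> Optional[str]:
    """Return the video_id if the token is valid and fresh, else None."""
    if not VIDEO_PROXY_SECRET or not token:
        return None
    try:
        # Re-add the padding stripped by sign_video_id() before decoding.
        padded = token + "=" * (-len(token) % 4)
        video_id, timestamp, signature = (
            base64.urlsafe_b64decode(padded).decode().rsplit(":", 2)
        )
        expected = hmac.new(
            VIDEO_PROXY_SECRET, f"{video_id}:{timestamp}".encode(), hashlib.sha256
        ).hexdigest()
        if not hmac.compare_digest(signature, expected):
            return None
        if int(time.time() * 1000) - int(timestamp) > max_age_ms:
            return None
        return video_id
    except Exception:
        return None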