misc
This commit is contained in:
69
base.py
69
base.py
@@ -4,6 +4,7 @@ import sys
|
|||||||
sys.path.insert(0, ".")
|
sys.path.insert(0, ".")
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@@ -13,20 +14,61 @@ from lyric_search.sources import redis_cache
|
|||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logging.getLogger("aiosqlite").setLevel(logging.WARNING)
|
logging.getLogger("aiosqlite").setLevel(logging.WARNING)
|
||||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("python_multipart.multipart").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("streamrip").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("utils.sr_wrapper").setLevel(logging.WARNING)
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
# Pre-import endpoint modules so we can wire up lifespan
|
||||||
|
constants = importlib.import_module("constants").Constants()
|
||||||
|
|
||||||
|
# Will be set after app creation
|
||||||
|
_routes: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""Lifespan context manager for startup/shutdown events."""
|
||||||
|
# Startup
|
||||||
|
uvicorn_access_logger = logging.getLogger("uvicorn.access")
|
||||||
|
uvicorn_access_logger.disabled = True
|
||||||
|
|
||||||
|
# Start Radio playlists
|
||||||
|
if "radio" in _routes and hasattr(_routes["radio"], "on_start"):
|
||||||
|
await _routes["radio"].on_start()
|
||||||
|
|
||||||
|
# Start endpoint background tasks
|
||||||
|
if "trip" in _routes and hasattr(_routes["trip"], "startup"):
|
||||||
|
await _routes["trip"].startup()
|
||||||
|
if "lighting" in _routes and hasattr(_routes["lighting"], "startup"):
|
||||||
|
await _routes["lighting"].startup()
|
||||||
|
|
||||||
|
logger.info("Application startup complete")
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Shutdown
|
||||||
|
if "lighting" in _routes and hasattr(_routes["lighting"], "shutdown"):
|
||||||
|
await _routes["lighting"].shutdown()
|
||||||
|
if "trip" in _routes and hasattr(_routes["trip"], "shutdown"):
|
||||||
|
await _routes["trip"].shutdown()
|
||||||
|
|
||||||
|
logger.info("Application shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="codey.lol API",
|
title="codey.lol API",
|
||||||
version="1.0",
|
version="1.0",
|
||||||
contact={"name": "codey"},
|
contact={"name": "codey"},
|
||||||
redirect_slashes=False,
|
redirect_slashes=False,
|
||||||
loop=loop,
|
loop=loop,
|
||||||
docs_url="/docs", # Swagger UI (default)
|
docs_url=None, # Disabled - using Scalar at /docs instead
|
||||||
redoc_url="/redoc", # ReDoc UI (default, but explicitly set)
|
redoc_url="/redoc",
|
||||||
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
constants = importlib.import_module("constants").Constants()
|
|
||||||
util = importlib.import_module("util").Utilities(app, constants)
|
util = importlib.import_module("util").Utilities(app, constants)
|
||||||
|
|
||||||
origins = [
|
origins = [
|
||||||
@@ -48,8 +90,8 @@ app.add_middleware(
|
|||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# Add Scalar API documentation endpoint (before blacklist routes)
|
# Scalar API documentation at /docs (replaces default Swagger UI)
|
||||||
@app.get("/scalar", include_in_schema=False)
|
@app.get("/docs", include_in_schema=False)
|
||||||
def scalar_docs():
|
def scalar_docs():
|
||||||
return get_scalar_api_reference(openapi_url="/openapi.json", title="codey.lol API")
|
return get_scalar_api_reference(openapi_url="/openapi.json", title="codey.lol API")
|
||||||
|
|
||||||
@@ -72,7 +114,7 @@ def base_head():
|
|||||||
@app.get("/{path}", include_in_schema=False)
|
@app.get("/{path}", include_in_schema=False)
|
||||||
def disallow_get_any(request: Request, var: Any = None):
|
def disallow_get_any(request: Request, var: Any = None):
|
||||||
path = request.path_params["path"]
|
path = request.path_params["path"]
|
||||||
allowed_paths = ["widget", "misc/no", "docs", "redoc", "scalar", "openapi.json"]
|
allowed_paths = ["widget", "misc/no", "docs", "redoc", "openapi.json"]
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Checking path: {path}, allowed: {path in allowed_paths or path.split('/', maxsplit=1)[0] in allowed_paths}"
|
f"Checking path: {path}, allowed: {path in allowed_paths or path.split('/', maxsplit=1)[0] in allowed_paths}"
|
||||||
)
|
)
|
||||||
@@ -99,7 +141,7 @@ End Blacklisted Routes
|
|||||||
Actionable Routes
|
Actionable Routes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
routes: dict = {
|
_routes.update({
|
||||||
"randmsg": importlib.import_module("endpoints.rand_msg").RandMsg(
|
"randmsg": importlib.import_module("endpoints.rand_msg").RandMsg(
|
||||||
app, util, constants
|
app, util, constants
|
||||||
),
|
),
|
||||||
@@ -116,12 +158,12 @@ routes: dict = {
|
|||||||
"lighting": importlib.import_module("endpoints.lighting").Lighting(
|
"lighting": importlib.import_module("endpoints.lighting").Lighting(
|
||||||
app, util, constants
|
app, util, constants
|
||||||
),
|
),
|
||||||
}
|
})
|
||||||
|
|
||||||
# Misc endpoint depends on radio endpoint instance
|
# Misc endpoint depends on radio endpoint instance
|
||||||
radio_endpoint = routes.get("radio")
|
radio_endpoint = _routes.get("radio")
|
||||||
if radio_endpoint:
|
if radio_endpoint:
|
||||||
routes["misc"] = importlib.import_module("endpoints.misc").Misc(
|
_routes["misc"] = importlib.import_module("endpoints.misc").Misc(
|
||||||
app, util, constants, radio_endpoint
|
app, util, constants, radio_endpoint
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -133,12 +175,5 @@ End Actionable Routes
|
|||||||
Startup
|
Startup
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def on_start():
|
|
||||||
uvicorn_access_logger = logging.getLogger("uvicorn.access")
|
|
||||||
uvicorn_access_logger.disabled = True
|
|
||||||
|
|
||||||
|
|
||||||
app.add_event_handler("startup", on_start)
|
|
||||||
redis = redis_cache.RedisCache()
|
redis = redis_cache.RedisCache()
|
||||||
loop.create_task(redis.create_index())
|
loop.create_task(redis.create_index())
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -62,6 +62,7 @@ class Misc(FastAPI):
|
|||||||
self.upload_activity_image,
|
self.upload_activity_image,
|
||||||
methods=["POST"],
|
methods=["POST"],
|
||||||
dependencies=[Depends(RateLimiter(times=10, seconds=2))],
|
dependencies=[Depends(RateLimiter(times=10, seconds=2))],
|
||||||
|
include_in_schema=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.debug("Loading NaaS reasons")
|
logging.debug("Loading NaaS reasons")
|
||||||
|
|||||||
@@ -47,12 +47,12 @@ class Radio(FastAPI):
|
|||||||
self.sr_util = SRUtil()
|
self.sr_util = SRUtil()
|
||||||
self.lrclib = LRCLib()
|
self.lrclib = LRCLib()
|
||||||
self.lrc_cache: Dict[str, Optional[str]] = {}
|
self.lrc_cache: Dict[str, Optional[str]] = {}
|
||||||
self.lrc_cache_locks = {}
|
self.lrc_cache_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||||
self.playlists_loaded: bool = False
|
self.playlists_loaded: bool = False
|
||||||
# WebSocket connection management
|
# WebSocket connection management
|
||||||
self.active_connections: Dict[str, Set[WebSocket]] = {}
|
self.active_connections: Dict[str, Set[WebSocket]] = {}
|
||||||
# Initialize broadcast locks to prevent duplicate events
|
# Initialize broadcast locks to prevent duplicate events
|
||||||
self.broadcast_locks = defaultdict(asyncio.Lock)
|
self.broadcast_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||||
self.endpoints: dict = {
|
self.endpoints: dict = {
|
||||||
"radio/np": self.radio_now_playing,
|
"radio/np": self.radio_now_playing,
|
||||||
"radio/request": self.radio_request,
|
"radio/request": self.radio_request,
|
||||||
@@ -71,9 +71,9 @@ class Radio(FastAPI):
|
|||||||
if endpoint == "radio/album_art":
|
if endpoint == "radio/album_art":
|
||||||
methods = ["GET"]
|
methods = ["GET"]
|
||||||
app.add_api_route(
|
app.add_api_route(
|
||||||
f"/{endpoint}", handler, methods=methods, include_in_schema=True,
|
f"/{endpoint}", handler, methods=methods, include_in_schema=False,
|
||||||
dependencies=[Depends(
|
dependencies=[Depends(
|
||||||
RateLimiter(times=25, seconds=2))] if not endpoint == "radio/np" else None,
|
RateLimiter(times=25, seconds=2))],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add WebSocket route
|
# Add WebSocket route
|
||||||
@@ -83,12 +83,8 @@ class Radio(FastAPI):
|
|||||||
|
|
||||||
app.add_websocket_route("/radio/ws/{station}", websocket_route_handler)
|
app.add_websocket_route("/radio/ws/{station}", websocket_route_handler)
|
||||||
|
|
||||||
app.add_event_handler("startup", self.on_start)
|
|
||||||
|
|
||||||
async def on_start(self) -> None:
|
async def on_start(self) -> None:
|
||||||
# Initialize locks in the event loop
|
# Load playlists for all stations
|
||||||
self.lrc_cache_locks = defaultdict(asyncio.Lock)
|
|
||||||
self.broadcast_locks = defaultdict(asyncio.Lock)
|
|
||||||
stations = ", ".join(self.radio_util.db_queries.keys())
|
stations = ", ".join(self.radio_util.db_queries.keys())
|
||||||
logging.info("radio: Initializing stations:\n%s", stations)
|
logging.info("radio: Initializing stations:\n%s", stations)
|
||||||
await self.radio_util.load_playlists()
|
await self.radio_util.load_playlists()
|
||||||
|
|||||||
119
endpoints/rip.py
119
endpoints/rip.py
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from fastapi import FastAPI, Request, Response, Depends
|
from fastapi import FastAPI, Request, Response, Depends, HTTPException
|
||||||
from fastapi_throttle import RateLimiter
|
from fastapi_throttle import RateLimiter
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from utils.sr_wrapper import SRUtil
|
from utils.sr_wrapper import SRUtil
|
||||||
@@ -63,22 +63,42 @@ class RIP(FastAPI):
|
|||||||
"trip/bulk_fetch": self.bulk_fetch_handler,
|
"trip/bulk_fetch": self.bulk_fetch_handler,
|
||||||
"trip/job/{job_id:path}": self.job_status_handler,
|
"trip/job/{job_id:path}": self.job_status_handler,
|
||||||
"trip/jobs/list": self.job_list_handler,
|
"trip/jobs/list": self.job_list_handler,
|
||||||
|
"trip/auth/start": self.tidal_auth_start_handler,
|
||||||
|
"trip/auth/check": self.tidal_auth_check_handler,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Store pending device codes for auth flow
|
||||||
|
self._pending_device_codes: dict[str, str] = {}
|
||||||
|
|
||||||
for endpoint, handler in self.endpoints.items():
|
for endpoint, handler in self.endpoints.items():
|
||||||
dependencies = [Depends(RateLimiter(times=8, seconds=2))]
|
dependencies = [Depends(RateLimiter(times=8, seconds=2))]
|
||||||
app.add_api_route(
|
app.add_api_route(
|
||||||
f"/{endpoint}",
|
f"/{endpoint}",
|
||||||
handler,
|
handler,
|
||||||
methods=["GET"] if endpoint != "trip/bulk_fetch" else ["POST"],
|
methods=["GET"] if endpoint not in ("trip/bulk_fetch", "trip/auth/check") else ["POST"],
|
||||||
include_in_schema=False,
|
include_in_schema=False,
|
||||||
dependencies=dependencies,
|
dependencies=dependencies,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def startup(self) -> None:
|
||||||
|
"""Initialize Tidal keepalive. Call this from your app's lifespan context manager."""
|
||||||
|
try:
|
||||||
|
await self.trip_util.start_keepalive()
|
||||||
|
logger.info("Tidal keepalive task started successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start Tidal keepalive task: {e}")
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""Stop Tidal keepalive. Call this from your app's lifespan context manager."""
|
||||||
|
try:
|
||||||
|
await self.trip_util.stop_keepalive()
|
||||||
|
logger.info("Tidal keepalive task stopped successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error stopping Tidal keepalive task: {e}")
|
||||||
|
|
||||||
def _format_job(self, job: Job):
|
def _format_job(self, job: Job):
|
||||||
"""
|
"""
|
||||||
Helper to normalize job data into JSON.
|
Helper to normalize job data into JSON.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- job (Job): The job object to format.
|
- job (Job): The job object to format.
|
||||||
|
|
||||||
@@ -132,6 +152,8 @@ class RIP(FastAPI):
|
|||||||
Returns:
|
Returns:
|
||||||
- **Response**: JSON response with artists or 404.
|
- **Response**: JSON response with artists or 404.
|
||||||
"""
|
"""
|
||||||
|
if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []):
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
# support optional grouping to return one primary per display name
|
# support optional grouping to return one primary per display name
|
||||||
# with `alternatives` for disambiguation (use ?group=true)
|
# with `alternatives` for disambiguation (use ?group=true)
|
||||||
group = bool(request.query_params.get("group", False))
|
group = bool(request.query_params.get("group", False))
|
||||||
@@ -154,6 +176,8 @@ class RIP(FastAPI):
|
|||||||
Returns:
|
Returns:
|
||||||
- **Response**: JSON response with albums or 404.
|
- **Response**: JSON response with albums or 404.
|
||||||
"""
|
"""
|
||||||
|
if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []):
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
albums = await self.trip_util.get_albums_by_artist_id(artist_id)
|
albums = await self.trip_util.get_albums_by_artist_id(artist_id)
|
||||||
if not albums:
|
if not albums:
|
||||||
return Response(status_code=404, content="Not found")
|
return Response(status_code=404, content="Not found")
|
||||||
@@ -178,6 +202,8 @@ class RIP(FastAPI):
|
|||||||
Returns:
|
Returns:
|
||||||
- **Response**: JSON response with tracks or 404.
|
- **Response**: JSON response with tracks or 404.
|
||||||
"""
|
"""
|
||||||
|
if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []):
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
tracks = await self.trip_util.get_tracks_by_album_id(album_id, quality)
|
tracks = await self.trip_util.get_tracks_by_album_id(album_id, quality)
|
||||||
if not tracks:
|
if not tracks:
|
||||||
return Response(status_code=404, content="Not Found")
|
return Response(status_code=404, content="Not Found")
|
||||||
@@ -198,6 +224,8 @@ class RIP(FastAPI):
|
|||||||
Returns:
|
Returns:
|
||||||
- **Response**: JSON response with tracks or 404.
|
- **Response**: JSON response with tracks or 404.
|
||||||
"""
|
"""
|
||||||
|
if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []):
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
logging.critical("Searching for tracks by artist: %s, song: %s", artist, song)
|
logging.critical("Searching for tracks by artist: %s, song: %s", artist, song)
|
||||||
tracks = await self.trip_util.get_tracks_by_artist_song(artist, song)
|
tracks = await self.trip_util.get_tracks_by_artist_song(artist, song)
|
||||||
if not tracks:
|
if not tracks:
|
||||||
@@ -223,6 +251,8 @@ class RIP(FastAPI):
|
|||||||
Returns:
|
Returns:
|
||||||
- **Response**: JSON response with stream URL or 404.
|
- **Response**: JSON response with stream URL or 404.
|
||||||
"""
|
"""
|
||||||
|
if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []):
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
track = await self.trip_util.get_stream_url_by_track_id(track_id, quality)
|
track = await self.trip_util.get_stream_url_by_track_id(track_id, quality)
|
||||||
if not track:
|
if not track:
|
||||||
return Response(status_code=404, content="Not found")
|
return Response(status_code=404, content="Not found")
|
||||||
@@ -245,6 +275,8 @@ class RIP(FastAPI):
|
|||||||
Returns:
|
Returns:
|
||||||
- **Response**: JSON response with job info or error.
|
- **Response**: JSON response with job info or error.
|
||||||
"""
|
"""
|
||||||
|
if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []):
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
if not data or not data.track_ids or not data.target:
|
if not data or not data.track_ids or not data.target:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={
|
content={
|
||||||
@@ -296,7 +328,8 @@ class RIP(FastAPI):
|
|||||||
Returns:
|
Returns:
|
||||||
- **JSONResponse**: Job status and result or error.
|
- **JSONResponse**: Job status and result or error.
|
||||||
"""
|
"""
|
||||||
|
if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []):
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
job = None
|
job = None
|
||||||
try:
|
try:
|
||||||
# Try direct fetch first
|
# Try direct fetch first
|
||||||
@@ -334,6 +367,8 @@ class RIP(FastAPI):
|
|||||||
Returns:
|
Returns:
|
||||||
- **JSONResponse**: List of jobs.
|
- **JSONResponse**: List of jobs.
|
||||||
"""
|
"""
|
||||||
|
if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []):
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
jobs_info = []
|
jobs_info = []
|
||||||
seen = set()
|
seen = set()
|
||||||
|
|
||||||
@@ -385,3 +420,79 @@ class RIP(FastAPI):
|
|||||||
jobs_info.sort(key=job_sort_key, reverse=True)
|
jobs_info.sort(key=job_sort_key, reverse=True)
|
||||||
|
|
||||||
return {"jobs": jobs_info}
|
return {"jobs": jobs_info}
|
||||||
|
|
||||||
|
async def tidal_auth_start_handler(
|
||||||
|
self, request: Request, user=Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Start Tidal device authorization flow.
|
||||||
|
|
||||||
|
Returns a URL that the user must visit to authorize the application.
|
||||||
|
After visiting the URL and authorizing, call /trip/auth/check to complete.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- **JSONResponse**: Contains device_code and verification_url.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []):
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
|
device_code, verification_url = await self.trip_util.start_device_auth()
|
||||||
|
# Store device code for this session
|
||||||
|
self._pending_device_codes[user.get("sub", "default")] = device_code
|
||||||
|
return JSONResponse(
|
||||||
|
content={
|
||||||
|
"device_code": device_code,
|
||||||
|
"verification_url": verification_url,
|
||||||
|
"message": "Visit the URL to authorize, then call /trip/auth/check",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Tidal auth start failed: %s", e)
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": str(e)},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def tidal_auth_check_handler(
|
||||||
|
self, request: Request, user=Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Check if Tidal device authorization is complete.
|
||||||
|
|
||||||
|
Call this after the user has visited the verification URL.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- **JSONResponse**: Contains success status and message.
|
||||||
|
"""
|
||||||
|
if "trip" not in user.get("roles", []) and "admin" not in user.get("roles", []):
|
||||||
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
|
device_code = self._pending_device_codes.get(user.get("sub", "default"))
|
||||||
|
if not device_code:
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": "No pending authorization. Call /trip/auth/start first."},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
success, error = await self.trip_util.check_device_auth(device_code)
|
||||||
|
if success:
|
||||||
|
# Clear the pending code
|
||||||
|
self._pending_device_codes.pop(user.get("sub", "default"), None)
|
||||||
|
return JSONResponse(
|
||||||
|
content={"success": True, "message": "Tidal authorization complete!"}
|
||||||
|
)
|
||||||
|
elif error == "pending":
|
||||||
|
return JSONResponse(
|
||||||
|
content={"success": False, "pending": True, "message": "Waiting for user to authorize..."}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return JSONResponse(
|
||||||
|
content={"success": False, "error": error},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Tidal auth check failed: %s", e)
|
||||||
|
return JSONResponse(
|
||||||
|
content={"error": str(e)},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ from fastapi import FastAPI, Depends
|
|||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi_throttle import RateLimiter
|
from fastapi_throttle import RateLimiter
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
from utils.yt_utils import sign_video_id
|
||||||
from .constructors import ValidYTSearchRequest
|
from .constructors import ValidYTSearchRequest
|
||||||
|
|
||||||
|
|
||||||
class YT(FastAPI):
|
class YT(FastAPI):
|
||||||
"""
|
"""
|
||||||
YT Endpoints
|
YT Endpoints
|
||||||
@@ -57,6 +57,7 @@ class YT(FastAPI):
|
|||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={
|
content={
|
||||||
"video_id": yt_video_id,
|
"video_id": yt_video_id,
|
||||||
|
"video_token": sign_video_id(yt_video_id) if yt_video_id else None,
|
||||||
"extras": yts_res[0],
|
"extras": yts_res[0],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,24 +1,46 @@
|
|||||||
|
# isort: skip_file
|
||||||
from typing import Optional, Any, Callable
|
from typing import Optional, Any, Callable
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
from pathlib import Path
|
||||||
import hashlib
|
import hashlib
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import time
|
import time
|
||||||
from streamrip.client import TidalClient # type: ignore
|
|
||||||
from streamrip.config import Config as StreamripConfig # type: ignore
|
# Monkey-patch streamrip's Tidal client credentials BEFORE importing TidalClient
|
||||||
from dotenv import load_dotenv
|
import streamrip.client.tidal as _tidal_module # type: ignore # noqa: E402
|
||||||
from rapidfuzz import fuzz
|
_tidal_module.CLIENT_ID = "fX2JxdmntZWK0ixT"
|
||||||
|
_tidal_module.CLIENT_SECRET = "1Nn9AfDAjxrgJFJbKNWLeAyKGVGmINuXPPLHVXAvxAg="
|
||||||
|
_tidal_module.AUTH = aiohttp.BasicAuth(
|
||||||
|
login=_tidal_module.CLIENT_ID,
|
||||||
|
password=_tidal_module.CLIENT_SECRET
|
||||||
|
)
|
||||||
|
|
||||||
|
from streamrip.client import TidalClient # type: ignore # noqa: E402
|
||||||
|
from streamrip.config import Config as StreamripConfig # type: ignore # noqa: E402
|
||||||
|
from dotenv import load_dotenv # noqa: E402
|
||||||
|
from rapidfuzz import fuzz # noqa: E402
|
||||||
|
|
||||||
|
# Path to persist Tidal tokens across restarts
|
||||||
|
TIDAL_TOKEN_CACHE_PATH = Path(__file__).parent.parent / "tidal_token.json"
|
||||||
|
|
||||||
|
|
||||||
class MetadataFetchError(Exception):
|
class MetadataFetchError(Exception):
|
||||||
"""Raised when metadata fetch permanently fails after retries."""
|
"""Raised when metadata fetch permanently fails after retries."""
|
||||||
|
|
||||||
|
|
||||||
|
# How long before token expiry to proactively refresh (seconds)
|
||||||
|
TIDAL_TOKEN_REFRESH_BUFFER = 600 # 10 minutes
|
||||||
|
# Maximum age of a session before forcing a fresh login (seconds)
|
||||||
|
TIDAL_SESSION_MAX_AGE = 1800 # 30 minutes
|
||||||
|
|
||||||
|
|
||||||
# Suppress noisy logging from this module and from the `streamrip` library
|
# Suppress noisy logging from this module and from the `streamrip` library
|
||||||
# We set propagate=False so messages don't bubble up to the root logger and
|
# We set propagate=False so messages don't bubble up to the root logger and
|
||||||
# attach a NullHandler where appropriate to avoid "No handler found" warnings.
|
# attach a NullHandler where appropriate to avoid "No handler found" warnings.
|
||||||
@@ -47,27 +69,11 @@ class SRUtil:
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize StreamRip utility."""
|
"""Initialize StreamRip utility."""
|
||||||
self.streamrip_config = StreamripConfig.defaults()
|
self.streamrip_config = StreamripConfig.defaults()
|
||||||
self.streamrip_config.session.tidal.user_id = os.getenv("tidal_user_id", "")
|
self._load_tidal_config()
|
||||||
self.streamrip_config.session.tidal.access_token = os.getenv(
|
|
||||||
"tidal_access_token", ""
|
|
||||||
)
|
|
||||||
self.streamrip_config.session.tidal.refresh_token = os.getenv(
|
|
||||||
"tidal_refresh_token", ""
|
|
||||||
)
|
|
||||||
self.streamrip_config.session.tidal.token_expiry = os.getenv(
|
|
||||||
"tidal_token_expiry", ""
|
|
||||||
)
|
|
||||||
self.streamrip_config.session.tidal.country_code = os.getenv(
|
|
||||||
"tidal_country_code", ""
|
|
||||||
)
|
|
||||||
self.streamrip_config.session.tidal.quality = int(
|
|
||||||
os.getenv("tidal_default_quality", 2)
|
|
||||||
)
|
|
||||||
self.streamrip_config.session.conversion.enabled = False
|
self.streamrip_config.session.conversion.enabled = False
|
||||||
self.streamrip_config.session.downloads.folder = os.getenv(
|
self.streamrip_config.session.downloads.folder = os.getenv(
|
||||||
"tidal_download_folder", ""
|
"tidal_download_folder", ""
|
||||||
)
|
)
|
||||||
self.streamrip_config
|
|
||||||
self.streamrip_client = TidalClient(self.streamrip_config)
|
self.streamrip_client = TidalClient(self.streamrip_config)
|
||||||
self.MAX_CONCURRENT_METADATA_REQUESTS = 2
|
self.MAX_CONCURRENT_METADATA_REQUESTS = 2
|
||||||
self.METADATA_RATE_LIMIT = 1.25
|
self.METADATA_RATE_LIMIT = 1.25
|
||||||
@@ -82,19 +88,328 @@ class SRUtil:
|
|||||||
self.on_rate_limit: Optional[Callable[[Exception], Any]] = None
|
self.on_rate_limit: Optional[Callable[[Exception], Any]] = None
|
||||||
# Internal flag to avoid repeated notifications for the same runtime
|
# Internal flag to avoid repeated notifications for the same runtime
|
||||||
self._rate_limit_notified = False
|
self._rate_limit_notified = False
|
||||||
|
# Track when we last successfully logged in
|
||||||
|
self._last_login_time: Optional[float] = None
|
||||||
|
# Track last successful API call
|
||||||
|
self._last_successful_request: Optional[float] = None
|
||||||
|
# Keepalive task handle
|
||||||
|
self._keepalive_task: Optional[asyncio.Task] = None
|
||||||
|
# Keepalive interval in seconds
|
||||||
|
self.KEEPALIVE_INTERVAL = 180 # 3 minutes
|
||||||
|
|
||||||
|
async def start_keepalive(self) -> None:
|
||||||
|
"""Start the background keepalive task.
|
||||||
|
|
||||||
|
This should be called once at startup to ensure the Tidal session
|
||||||
|
stays alive even during idle periods.
|
||||||
|
"""
|
||||||
|
if self._keepalive_task and not self._keepalive_task.done():
|
||||||
|
logging.info("Tidal keepalive task already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ensure initial login
|
||||||
|
try:
|
||||||
|
await self._login_and_persist()
|
||||||
|
logging.info("Initial Tidal login successful")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Initial Tidal login failed: %s", e)
|
||||||
|
|
||||||
|
self._keepalive_task = asyncio.create_task(self._keepalive_runner())
|
||||||
|
logging.info("Tidal keepalive task started")
|
||||||
|
|
||||||
|
async def stop_keepalive(self) -> None:
|
||||||
|
"""Stop the background keepalive task."""
|
||||||
|
if self._keepalive_task and not self._keepalive_task.done():
|
||||||
|
self._keepalive_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._keepalive_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logging.info("Tidal keepalive task stopped")
|
||||||
|
|
||||||
|
async def _keepalive_runner(self) -> None:
|
||||||
|
"""Background task to keep the Tidal session alive."""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self.KEEPALIVE_INTERVAL)
|
||||||
|
|
||||||
|
# Check if we've had recent activity
|
||||||
|
if self._last_successful_request:
|
||||||
|
time_since_last = time.time() - self._last_successful_request
|
||||||
|
if time_since_last < self.KEEPALIVE_INTERVAL:
|
||||||
|
# Recent activity, no need to ping
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if token is expiring soon and proactively refresh
|
||||||
|
if self._is_token_expiring_soon():
|
||||||
|
logging.info("Tidal keepalive: Token expiring soon, refreshing...")
|
||||||
|
try:
|
||||||
|
await self._login_and_persist(force=True)
|
||||||
|
logging.info("Tidal keepalive: Token refresh successful")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Tidal keepalive: Token refresh failed: %s", e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if session is stale
|
||||||
|
if self._is_session_stale():
|
||||||
|
logging.info("Tidal keepalive: Session stale, refreshing...")
|
||||||
|
try:
|
||||||
|
await self._login_and_persist(force=True)
|
||||||
|
logging.info("Tidal keepalive: Session refresh successful")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Tidal keepalive: Session refresh failed: %s", e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Make a lightweight API call to keep the session alive
|
||||||
|
if self.streamrip_client.logged_in:
|
||||||
|
try:
|
||||||
|
# Simple search to keep the connection alive
|
||||||
|
await self._safe_api_call(
|
||||||
|
self.streamrip_client.search,
|
||||||
|
media_type="artist",
|
||||||
|
query="test",
|
||||||
|
retries=1,
|
||||||
|
)
|
||||||
|
logging.debug("Tidal keepalive ping successful")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Tidal keepalive ping failed: %s", e)
|
||||||
|
# Try to refresh the session
|
||||||
|
try:
|
||||||
|
await self._login_and_persist(force=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logging.info("Tidal keepalive task cancelled")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Error in Tidal keepalive task: %s", e)
|
||||||
|
|
||||||
|
def _load_tidal_config(self) -> None:
|
||||||
|
"""Load Tidal config from cache file if available, otherwise from env."""
|
||||||
|
tidal = self.streamrip_config.session.tidal
|
||||||
|
cached = self._load_cached_tokens()
|
||||||
|
|
||||||
|
if cached:
|
||||||
|
tidal.user_id = cached.get("user_id", "")
|
||||||
|
tidal.access_token = cached.get("access_token", "")
|
||||||
|
tidal.refresh_token = cached.get("refresh_token", "")
|
||||||
|
tidal.token_expiry = cached.get("token_expiry", "")
|
||||||
|
tidal.country_code = cached.get("country_code", os.getenv("tidal_country_code", ""))
|
||||||
|
else:
|
||||||
|
tidal.user_id = os.getenv("tidal_user_id", "")
|
||||||
|
tidal.access_token = os.getenv("tidal_access_token", "")
|
||||||
|
tidal.refresh_token = os.getenv("tidal_refresh_token", "")
|
||||||
|
tidal.token_expiry = os.getenv("tidal_token_expiry", "")
|
||||||
|
tidal.country_code = os.getenv("tidal_country_code", "")
|
||||||
|
|
||||||
|
tidal.quality = int(os.getenv("tidal_default_quality", 2))
|
||||||
|
|
||||||
|
def _load_cached_tokens(self) -> Optional[dict]:
|
||||||
|
"""Load cached tokens from disk if valid."""
|
||||||
|
try:
|
||||||
|
if TIDAL_TOKEN_CACHE_PATH.exists():
|
||||||
|
with open(TIDAL_TOKEN_CACHE_PATH, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
# Validate required fields exist
|
||||||
|
if all(k in data for k in ("access_token", "refresh_token", "token_expiry")):
|
||||||
|
logging.info("Loaded Tidal tokens from cache")
|
||||||
|
return data
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Failed to load cached Tidal tokens: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _save_cached_tokens(self) -> None:
|
||||||
|
"""Persist current tokens to disk for use across restarts."""
|
||||||
|
try:
|
||||||
|
tidal = self.streamrip_config.session.tidal
|
||||||
|
data = {
|
||||||
|
"user_id": tidal.user_id,
|
||||||
|
"access_token": tidal.access_token,
|
||||||
|
"refresh_token": tidal.refresh_token,
|
||||||
|
"token_expiry": tidal.token_expiry,
|
||||||
|
"country_code": tidal.country_code,
|
||||||
|
}
|
||||||
|
with open(TIDAL_TOKEN_CACHE_PATH, "w") as f:
|
||||||
|
json.dump(data, f)
|
||||||
|
logging.info("Saved Tidal tokens to cache")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Failed to save Tidal tokens: %s", e)
|
||||||
|
|
||||||
|
def _apply_new_tokens(self, auth_info: dict) -> None:
|
||||||
|
"""Apply new tokens from device auth to config."""
|
||||||
|
tidal = self.streamrip_config.session.tidal
|
||||||
|
tidal.user_id = str(auth_info.get("user_id", ""))
|
||||||
|
tidal.access_token = auth_info.get("access_token", "")
|
||||||
|
tidal.refresh_token = auth_info.get("refresh_token", "")
|
||||||
|
tidal.token_expiry = auth_info.get("token_expiry", "")
|
||||||
|
tidal.country_code = auth_info.get("country_code", tidal.country_code)
|
||||||
|
self._save_cached_tokens()
|
||||||
|
|
||||||
|
async def start_device_auth(self) -> tuple[str, str]:
|
||||||
|
"""Start device authorization flow.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (device_code, verification_url) - User should visit the URL to authorize.
|
||||||
|
"""
|
||||||
|
if not hasattr(self.streamrip_client, 'session') or not self.streamrip_client.session:
|
||||||
|
self.streamrip_client.session = await self.streamrip_client.get_session()
|
||||||
|
|
||||||
|
device_code, verification_url = await self.streamrip_client._get_device_code()
|
||||||
|
return device_code, verification_url
|
||||||
|
|
||||||
|
async def check_device_auth(self, device_code: str) -> tuple[bool, Optional[str]]:
|
||||||
|
"""Check if user has completed device authorization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_code: The device code from start_device_auth()
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (success, error_message)
|
||||||
|
- (True, None) if auth completed successfully
|
||||||
|
- (False, "pending") if user hasn't authorized yet
|
||||||
|
- (False, error_message) if auth failed
|
||||||
|
"""
|
||||||
|
status, auth_info = await self.streamrip_client._get_auth_status(device_code)
|
||||||
|
|
||||||
|
if status == 0:
|
||||||
|
# Success - apply new tokens
|
||||||
|
self._apply_new_tokens(auth_info)
|
||||||
|
# Re-login with new tokens
|
||||||
|
self.streamrip_client.logged_in = False
|
||||||
|
try:
|
||||||
|
await self.streamrip_client.login()
|
||||||
|
self._save_cached_tokens()
|
||||||
|
return True, None
|
||||||
|
except Exception as e:
|
||||||
|
return False, f"Login after auth failed: {e}"
|
||||||
|
elif status == 2:
|
||||||
|
# Pending - user hasn't authorized yet
|
||||||
|
return False, "pending"
|
||||||
|
else:
|
||||||
|
# Failed
|
||||||
|
return False, "Authorization failed"
|
||||||
|
|
||||||
|
def _is_token_expiring_soon(self) -> bool:
|
||||||
|
"""Check if the token is about to expire within the buffer window."""
|
||||||
|
tidal = self.streamrip_config.session.tidal
|
||||||
|
token_expiry = getattr(tidal, "token_expiry", None)
|
||||||
|
if not token_expiry:
|
||||||
|
return True # No expiry info means we should refresh
|
||||||
|
try:
|
||||||
|
# token_expiry is typically an ISO timestamp string
|
||||||
|
if isinstance(token_expiry, str):
|
||||||
|
from datetime import datetime
|
||||||
|
expiry_dt = datetime.fromisoformat(token_expiry.replace('Z', '+00:00'))
|
||||||
|
expiry_ts = expiry_dt.timestamp()
|
||||||
|
else:
|
||||||
|
expiry_ts = float(token_expiry)
|
||||||
|
return expiry_ts < (time.time() + TIDAL_TOKEN_REFRESH_BUFFER)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Failed to parse token expiry '%s': %s", token_expiry, e)
|
||||||
|
return True # Err on the side of refreshing
|
||||||
|
|
||||||
|
def _is_session_stale(self) -> bool:
|
||||||
|
"""Check if the login session is too old and should be refreshed."""
|
||||||
|
if not self._last_login_time:
|
||||||
|
return True
|
||||||
|
session_age = time.time() - self._last_login_time
|
||||||
|
return session_age > TIDAL_SESSION_MAX_AGE
|
||||||
|
|
||||||
|
async def _force_fresh_login(self) -> bool:
|
||||||
|
"""Force a complete fresh login, ignoring logged_in state.
|
||||||
|
|
||||||
|
Returns True if login succeeded, False otherwise.
|
||||||
|
"""
|
||||||
|
# Reset the logged_in flag to force a fresh login
|
||||||
|
self.streamrip_client.logged_in = False
|
||||||
|
|
||||||
|
# Close existing session if present
|
||||||
|
if hasattr(self.streamrip_client, 'session') and self.streamrip_client.session:
|
||||||
|
try:
|
||||||
|
if not self.streamrip_client.session.closed:
|
||||||
|
await self.streamrip_client.session.close()
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Error closing old session: %s", e)
|
||||||
|
# Use object.__setattr__ to bypass type checking for session reset
|
||||||
|
try:
|
||||||
|
object.__setattr__(self.streamrip_client, 'session', None)
|
||||||
|
except Exception:
|
||||||
|
pass # Session will be recreated on next login
|
||||||
|
|
||||||
|
try:
|
||||||
|
logging.info("Forcing fresh Tidal login...")
|
||||||
|
await self.streamrip_client.login()
|
||||||
|
self._last_login_time = time.time()
|
||||||
|
self._save_cached_tokens()
|
||||||
|
logging.info("Fresh Tidal login successful")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Forced Tidal login failed: %s - device re-auth may be required", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _login_and_persist(self, force: bool = False) -> None:
|
||||||
|
"""Login to Tidal and persist any refreshed tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force: If True, force a fresh login even if already logged in.
|
||||||
|
|
||||||
|
This method now checks for:
|
||||||
|
1. Token expiry - refreshes if token is about to expire
|
||||||
|
2. Session age - refreshes if session is too old
|
||||||
|
3. logged_in state - logs in if not logged in
|
||||||
|
|
||||||
|
If refresh fails, logs a warning but does not raise.
|
||||||
|
"""
|
||||||
|
needs_login = force or not self.streamrip_client.logged_in
|
||||||
|
|
||||||
|
# Check if token is expiring soon
|
||||||
|
if not needs_login and self._is_token_expiring_soon():
|
||||||
|
logging.info("Tidal token expiring soon, will refresh")
|
||||||
|
needs_login = True
|
||||||
|
|
||||||
|
# Check if session is too old
|
||||||
|
if not needs_login and self._is_session_stale():
|
||||||
|
logging.info("Tidal session is stale, will refresh")
|
||||||
|
needs_login = True
|
||||||
|
|
||||||
|
if not needs_login:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Reset logged_in to ensure fresh login attempt
|
||||||
|
if force or self._is_token_expiring_soon():
|
||||||
|
self.streamrip_client.logged_in = False
|
||||||
|
|
||||||
|
await self.streamrip_client.login()
|
||||||
|
self._last_login_time = time.time()
|
||||||
|
# After login, tokens may have been refreshed - persist them
|
||||||
|
self._save_cached_tokens()
|
||||||
|
logging.info("Tidal login/refresh successful")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Tidal login/refresh failed: %s - device re-auth may be required", e)
|
||||||
|
# Don't mark as logged in on failure - let subsequent calls retry
|
||||||
|
|
||||||
async def rate_limited_request(self, func, *args, **kwargs):
|
async def rate_limited_request(self, func, *args, **kwargs):
|
||||||
|
"""Rate-limited wrapper that also ensures login before making requests."""
|
||||||
async with self.METADATA_SEMAPHORE:
|
async with self.METADATA_SEMAPHORE:
|
||||||
now = time.time()
|
now = time.time()
|
||||||
elapsed = now - self.LAST_METADATA_REQUEST
|
elapsed = now - self.LAST_METADATA_REQUEST
|
||||||
if elapsed < self.METADATA_RATE_LIMIT:
|
if elapsed < self.METADATA_RATE_LIMIT:
|
||||||
await asyncio.sleep(self.METADATA_RATE_LIMIT - elapsed)
|
await asyncio.sleep(self.METADATA_RATE_LIMIT - elapsed)
|
||||||
|
|
||||||
|
# Ensure we're logged in before making the request
|
||||||
|
try:
|
||||||
|
await self._login_and_persist()
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Pre-request login failed in rate_limited_request: %s", e)
|
||||||
|
|
||||||
result = await func(*args, **kwargs)
|
result = await func(*args, **kwargs)
|
||||||
self.LAST_METADATA_REQUEST = time.time()
|
self.LAST_METADATA_REQUEST = time.time()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _safe_api_call(
|
async def _safe_api_call(
|
||||||
self, func, *args, retries: int = 2, backoff: float = 0.5, **kwargs
|
self, func, *args, retries: int = 3, backoff: float = 0.5, **kwargs
|
||||||
):
|
):
|
||||||
"""Call an async API function with resilient retry behavior.
|
"""Call an async API function with resilient retry behavior.
|
||||||
|
|
||||||
@@ -103,18 +418,32 @@ class SRUtil:
|
|||||||
attempt a `login()` and retry up to `retries` times.
|
attempt a `login()` and retry up to `retries` times.
|
||||||
- On 400/429 responses (message contains '400' or '429'): retry with backoff
|
- On 400/429 responses (message contains '400' or '429'): retry with backoff
|
||||||
without triggering login (to avoid excessive logins).
|
without triggering login (to avoid excessive logins).
|
||||||
|
- On 401 (Unauthorized): force a fresh login and retry.
|
||||||
|
|
||||||
Returns the result or raises the last exception.
|
Returns the result or raises the last exception.
|
||||||
"""
|
"""
|
||||||
last_exc: Optional[Exception] = None
|
last_exc: Optional[Exception] = None
|
||||||
for attempt in range(retries):
|
for attempt in range(retries):
|
||||||
try:
|
try:
|
||||||
return await func(*args, **kwargs)
|
# Before each attempt, ensure we have a valid session
|
||||||
|
if attempt == 0:
|
||||||
|
# On first attempt, try to ensure logged in (checks token expiry)
|
||||||
|
# Wrapped in try/except so login failures don't block the API call
|
||||||
|
try:
|
||||||
|
await self._login_and_persist()
|
||||||
|
except Exception as login_err:
|
||||||
|
logging.warning("Pre-request login failed: %s (continuing anyway)", login_err)
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
# Track successful request
|
||||||
|
self._last_successful_request = time.time()
|
||||||
|
return result
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
# Probably missing/closed client internals: try re-login once
|
# Probably missing/closed client internals: try re-login once
|
||||||
last_exc = e
|
last_exc = e
|
||||||
|
logging.warning("AttributeError in API call (attempt %d/%d): %s", attempt + 1, retries, e)
|
||||||
try:
|
try:
|
||||||
await self.streamrip_client.login()
|
await self._force_fresh_login()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
continue
|
continue
|
||||||
@@ -144,6 +473,31 @@ class SRUtil:
|
|||||||
await asyncio.sleep(backoff * (2**attempt))
|
await asyncio.sleep(backoff * (2**attempt))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Treat 401 (Unauthorized) as an auth failure: force a fresh re-login then retry
|
||||||
|
is_401_error = (
|
||||||
|
(isinstance(e, aiohttp.ClientResponseError) and getattr(e, "status", None) == 401)
|
||||||
|
or "401" in msg
|
||||||
|
or "unauthorized" in msg.lower()
|
||||||
|
)
|
||||||
|
if is_401_error:
|
||||||
|
logging.warning(
|
||||||
|
"Received 401/Unauthorized from Tidal (attempt %d/%d). Forcing fresh re-login...",
|
||||||
|
attempt + 1,
|
||||||
|
retries,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# Use force=True to ensure we actually re-authenticate
|
||||||
|
login_success = await self._force_fresh_login()
|
||||||
|
if login_success:
|
||||||
|
logging.info("Forced re-login after 401 successful")
|
||||||
|
else:
|
||||||
|
logging.warning("Forced re-login after 401 failed - may need device re-auth")
|
||||||
|
except Exception as login_exc:
|
||||||
|
logging.warning("Forced login after 401 failed: %s", login_exc)
|
||||||
|
if attempt < retries - 1:
|
||||||
|
await asyncio.sleep(backoff * (2**attempt))
|
||||||
|
continue
|
||||||
|
|
||||||
# Connection related errors — try to re-login then retry
|
# Connection related errors — try to re-login then retry
|
||||||
if (
|
if (
|
||||||
isinstance(
|
isinstance(
|
||||||
@@ -159,7 +513,7 @@ class SRUtil:
|
|||||||
or "closed" in msg.lower()
|
or "closed" in msg.lower()
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
await self.streamrip_client.login()
|
await self._login_and_persist(force=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if attempt < retries - 1:
|
if attempt < retries - 1:
|
||||||
@@ -434,8 +788,6 @@ class SRUtil:
|
|||||||
|
|
||||||
async def get_albums_by_artist_id(self, artist_id: int) -> Optional[list | dict]:
|
async def get_albums_by_artist_id(self, artist_id: int) -> Optional[list | dict]:
|
||||||
"""Get albums by artist ID. Retry login only on authentication failure. Rate limit and retry on 400/429."""
|
"""Get albums by artist ID. Retry login only on authentication failure. Rate limit and retry on 400/429."""
|
||||||
import asyncio
|
|
||||||
|
|
||||||
artist_id_str: str = str(artist_id)
|
artist_id_str: str = str(artist_id)
|
||||||
albums_out: list[dict] = []
|
albums_out: list[dict] = []
|
||||||
max_retries = 4
|
max_retries = 4
|
||||||
@@ -585,26 +937,26 @@ class SRUtil:
|
|||||||
TODO: Reimplement using StreamRip
|
TODO: Reimplement using StreamRip
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# _safe_api_call already handles login, no need to call it here
|
||||||
search_res = await self._safe_api_call(
|
search_res = await self._safe_api_call(
|
||||||
self.streamrip_client.search,
|
self.streamrip_client.search,
|
||||||
media_type="track",
|
media_type="track",
|
||||||
query=f"{artist} - {song}",
|
query=f"{artist} - {song}",
|
||||||
retries=3,
|
retries=3,
|
||||||
)
|
)
|
||||||
logging.critical("Result: %s", search_res)
|
logging.debug("Search result: %s", search_res)
|
||||||
return (
|
return (
|
||||||
search_res[0].get("items")
|
search_res[0].get("items")
|
||||||
if search_res and isinstance(search_res, list)
|
if search_res and isinstance(search_res, list)
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
logging.warning("Search Exception: %s", str(e))
|
||||||
logging.critical("Search Exception: %s", str(e))
|
if n < 2: # Reduce max retries from 3 to 2
|
||||||
if n < 3:
|
|
||||||
n += 1
|
n += 1
|
||||||
|
await asyncio.sleep(0.5 * n) # Add backoff
|
||||||
return await self.get_tracks_by_artist_song(artist, song, n)
|
return await self.get_tracks_by_artist_song(artist, song, n)
|
||||||
return []
|
return []
|
||||||
# return []
|
|
||||||
|
|
||||||
async def get_stream_url_by_track_id(
|
async def get_stream_url_by_track_id(
|
||||||
self, track_id: int, quality: str = "FLAC"
|
self, track_id: int, quality: str = "FLAC"
|
||||||
@@ -655,7 +1007,6 @@ class SRUtil:
|
|||||||
"""
|
"""
|
||||||
for attempt in range(1, self.MAX_METADATA_RETRIES + 1):
|
for attempt in range(1, self.MAX_METADATA_RETRIES + 1):
|
||||||
try:
|
try:
|
||||||
await self._safe_api_call(self.streamrip_client.login, retries=1)
|
|
||||||
# Track metadata
|
# Track metadata
|
||||||
metadata = await self.rate_limited_request(
|
metadata = await self.rate_limited_request(
|
||||||
self.streamrip_client.get_metadata, str(track_id), "track"
|
self.streamrip_client.get_metadata, str(track_id), "track"
|
||||||
@@ -734,7 +1085,6 @@ class SRUtil:
|
|||||||
bool
|
bool
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
await self._safe_api_call(self.streamrip_client.login, retries=1)
|
|
||||||
track_url = await self.get_stream_url_by_track_id(track_id)
|
track_url = await self.get_stream_url_by_track_id(track_id)
|
||||||
if not track_url:
|
if not track_url:
|
||||||
return False
|
return False
|
||||||
|
|||||||
25
utils/yt_utils.py
Normal file
25
utils/yt_utils.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import hmac
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
|
||||||
|
VIDEO_PROXY_SECRET = os.environ.get("VIDEO_PROXY_SECRET", "").encode()
|
||||||
|
|
||||||
|
def sign_video_id(video_id: Optional[str|bool]) -> str:
|
||||||
|
"""Generate a signed token for a video ID."""
|
||||||
|
if not VIDEO_PROXY_SECRET or not video_id:
|
||||||
|
return "" # Return empty if no secret configured
|
||||||
|
|
||||||
|
timestamp = int(time.time() * 1000) # milliseconds to match JS Date.now()
|
||||||
|
payload = f"{video_id}:{timestamp}"
|
||||||
|
signature = hmac.new(
|
||||||
|
VIDEO_PROXY_SECRET,
|
||||||
|
payload.encode(),
|
||||||
|
hashlib.sha256
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
token_data = f"{payload}:{signature}"
|
||||||
|
# base64url encode (no padding, to match JS base64url)
|
||||||
|
return base64.urlsafe_b64encode(token_data.encode()).decode().rstrip("=")
|
||||||
Reference in New Issue
Block a user