small improvements re: #33

This commit is contained in:
2025-07-01 10:34:03 -04:00
parent 0fe081597e
commit 1991e5b31b
2 changed files with 55 additions and 12 deletions

View File

@ -14,6 +14,7 @@ from .constructors import (
from utils import radio_util
from typing import Optional
from fastapi import FastAPI, BackgroundTasks, Request, Response, HTTPException
from starlette.concurrency import run_in_threadpool
from fastapi.responses import RedirectResponse, JSONResponse
@ -57,7 +58,7 @@ class Radio(FastAPI):
async def on_start(self) -> None:
logging.info("radio: Initializing")
self.loop.run_in_executor(None, self.radio_util.load_playlist)
await run_in_threadpool(self.radio_util.load_playlist)
async def radio_skip(
self, data: ValidRadioNextRequest, request: Request
@ -316,7 +317,7 @@ class Radio(FastAPI):
if len(self.radio_util.active_playlist) > 1:
self.radio_util.active_playlist.append(next) # Push to end of playlist
else:
self.loop.run_in_executor(None, self.radio_util.load_playlist)
await run_in_threadpool(self.radio_util.load_playlist)
self.radio_util.now_playing = next
next["start"] = time_started

View File

@ -56,13 +56,13 @@ class RadioUtil:
"deathcore",
"edm",
"electronic",
"hard rock",
"rock",
"ska",
"post punk",
"post-punk",
"pop punk",
"pop-punk",
# "hard rock",
# "rock",
# "ska",
# "post punk",
# "post-punk",
# "pop punk",
# "pop-punk",
]
self.active_playlist: list[dict] = []
self.playlist_loaded: bool = False
@ -297,6 +297,36 @@ class RadioUtil:
traceback.print_exc()
return False
def get_genres(self, input_artists: list[str]) -> dict:
"""
Retrieve genres for given list of artists
Batch equivalent of get_genre
Args:
input_artists (list): The artists to query
Returns:
dict[str, str]
"""
time_start: float = time.time()
artist_genre: dict[str, str] = {}
query: str = (
"SELECT genre FROM artist_genre WHERE artist LIKE ? COLLATE NOCASE"
)
with sqlite3.connect(self.artist_genre_db_path) as _db:
_db.row_factory = sqlite3.Row
for artist in input_artists:
params: tuple[str] = (f"%%{artist}%%",)
_cursor = _db.execute(query, params)
res = _cursor.fetchone()
if not res:
artist_genre[artist] = "N/A"
continue
artist_genre[artist] = res["genre"]
time_end: float = time.time()
logging.info(f"Time taken: {time_end - time_start}")
return artist_genre
def get_genre(self, artist: str) -> str:
"""
Retrieve Genre for given Artist
@ -347,9 +377,7 @@ class RadioUtil:
"artist": double_space.sub(" ", r["artist"]).strip(),
"song": double_space.sub(" ", r["song"]).strip(),
"album": double_space.sub(" ", r["album"]).strip(),
"genre": self.get_genre(
double_space.sub(" ", r["artist"]).strip()
),
"genre": "N/A",
"artistsong": double_space.sub(
" ", r["artistdashsong"]
).strip(),
@ -364,6 +392,20 @@ class RadioUtil:
len(self.active_playlist),
)
logging.info(
"Adding genre data..."
)
artist_genre = self.get_genres([
str(r.get('artist')) for r in self.active_playlist])
for item in self.active_playlist:
artist = double_space.sub(" ", item["artist"]).strip()
item['genre'] = artist_genre[artist]
logging.info("Genre data added.")
random.shuffle(self.active_playlist)
"""Dedupe"""