diff --git a/endpoints/radio.py b/endpoints/radio.py index 06e589a..171d501 100644 --- a/endpoints/radio.py +++ b/endpoints/radio.py @@ -14,6 +14,7 @@ from .constructors import ( from utils import radio_util from typing import Optional from fastapi import FastAPI, BackgroundTasks, Request, Response, HTTPException +from starlette.concurrency import run_in_threadpool from fastapi.responses import RedirectResponse, JSONResponse @@ -57,7 +58,7 @@ class Radio(FastAPI): async def on_start(self) -> None: logging.info("radio: Initializing") - self.loop.run_in_executor(None, self.radio_util.load_playlist) + await run_in_threadpool(self.radio_util.load_playlist) async def radio_skip( self, data: ValidRadioNextRequest, request: Request @@ -316,7 +317,7 @@ class Radio(FastAPI): if len(self.radio_util.active_playlist) > 1: self.radio_util.active_playlist.append(next) # Push to end of playlist else: - self.loop.run_in_executor(None, self.radio_util.load_playlist) + await run_in_threadpool(self.radio_util.load_playlist) self.radio_util.now_playing = next next["start"] = time_started diff --git a/utils/radio_util.py b/utils/radio_util.py index dfb64ab..3b17d52 100644 --- a/utils/radio_util.py +++ b/utils/radio_util.py @@ -56,13 +56,13 @@ class RadioUtil: "deathcore", "edm", "electronic", - "hard rock", - "rock", - "ska", - "post punk", - "post-punk", - "pop punk", - "pop-punk", + # "hard rock", + # "rock", + # "ska", + # "post punk", + # "post-punk", + # "pop punk", + # "pop-punk", ] self.active_playlist: list[dict] = [] self.playlist_loaded: bool = False @@ -296,6 +296,36 @@ class RadioUtil: logging.info("Failed to store artist/genre pairs: %s", str(e)) traceback.print_exc() return False + + def get_genres(self, input_artists: list[str]) -> dict: + """ + Retrieve genres for given list of artists + Batch equivalent of get_genre + Args: + input_artists (list): The artists to query + + Returns: + dict[str, str] + """ + time_start: float = time.time() + artist_genre: dict[str, str] = {} + query: str = ( + "SELECT genre FROM artist_genre WHERE artist LIKE ? COLLATE NOCASE" + ) + with sqlite3.connect(self.artist_genre_db_path) as _db: + _db.row_factory = sqlite3.Row + for artist in input_artists: + params: tuple[str] = (f"%%{artist}%%",) + _cursor = _db.execute(query, params) + res = _cursor.fetchone() + if not res: + artist_genre[artist] = "N/A" + continue + artist_genre[artist] = res["genre"] + time_end: float = time.time() + logging.info(f"Time taken: {time_end - time_start}") + return artist_genre + def get_genre(self, artist: str) -> str: """ @@ -347,9 +377,7 @@ class RadioUtil: "artist": double_space.sub(" ", r["artist"]).strip(), "song": double_space.sub(" ", r["song"]).strip(), "album": double_space.sub(" ", r["album"]).strip(), - "genre": self.get_genre( - double_space.sub(" ", r["artist"]).strip() - ), + "genre": "N/A", "artistsong": double_space.sub( " ", r["artistdashsong"] ).strip(), @@ -364,6 +392,20 @@ class RadioUtil: len(self.active_playlist), ) + logging.info( + "Adding genre data..." + ) + + artist_genre = self.get_genres([ + str(r.get('artist')) for r in self.active_playlist]) + + + for item in self.active_playlist: + artist = double_space.sub(" ", item["artist"]).strip() + item['genre'] = artist_genre[artist] + + logging.info("Genre data added.") + random.shuffle(self.active_playlist) """Dedupe"""