diff --git a/base.py b/base.py index 158298a..1aaf3e1 100644 --- a/base.py +++ b/base.py @@ -13,7 +13,7 @@ logger.setLevel(logging.DEBUG) app = FastAPI() util = importlib.import_module("util").Utilities() constants = importlib.import_module("constants").Constants() - +glob_state = importlib.import_module("state").State(app, util, constants) origins = [ "https://codey.lol", @@ -25,6 +25,9 @@ allow_credentials=True, allow_methods=["POST"], allow_headers=["*"]) + + + """ Blacklisted routes """ @@ -50,16 +53,17 @@ End Blacklisted Routes """ Actionable Routes """ - -randmsg_endpoint = importlib.import_module("endpoints.rand_msg").RandMsg(app, util, constants) +counter_endpoints = importlib.import_module("endpoints.counters").Counters(app, util, constants, glob_state) +randmsg_endpoint = importlib.import_module("endpoints.rand_msg").RandMsg(app, util, constants, glob_state) +transcription_endpoints = importlib.import_module("endpoints.transcriptions").Transcriptions(app, util, constants, glob_state) # Below also provides: /lyric_cache_list/ (in addition to /lyric_search/) -lyric_search_endpoint = importlib.import_module("endpoints.lyric_search").LyricSearch(app, util, constants) +lyric_search_endpoint = importlib.import_module("endpoints.lyric_search").LyricSearch(app, util, constants, glob_state) # Below provides numerous LastFM-fed endpoints -lastfm_endpoints = importlib.import_module("endpoints.lastfm").LastFM(app, util, constants) +lastfm_endpoints = importlib.import_module("endpoints.lastfm").LastFM(app, util, constants, glob_state) # Below: YT endpoint(s) -yt_endpoints = importlib.import_module("endpoints.yt").YT(app, util, constants) -# Below: Transcription endpoints -transcription_endpoints = importlib.import_module("endpoints.transcriptions").Transcriptions(app, util, constants) +yt_endpoints = importlib.import_module("endpoints.yt").YT(app, util, constants, glob_state) + + """ diff --git a/endpoints/counters.py b/endpoints/counters.py new file mode 100644 index 0000000..0693215 --- /dev/null +++ b/endpoints/counters.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3.12 + +#!/usr/bin/env python3.12 + +import importlib +from fastapi import FastAPI +from pydantic import BaseModel + +class ValidCounterIncrementRequest(BaseModel): + """ + - **counter**: counter to update + """ + + counter: str + + +class ValidCounterRetrievalRequest(BaseModel): + """ + - **counter**: counter to retrieve (if none is provided, all counters are returned) + """ + + counter: str = "all" + + +class Counters(FastAPI): + """Counter Endpoints""" + def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called + self.app = app + self.util = util + self.constants = constants + self.glob_state = glob_state + + self.endpoints = { + "counters/get": self.get_counter_handler, + "counters/increment": self.increment_counter_handler + #tbd + } + + for endpoint, handler in self.endpoints.items(): + app.add_api_route(f"/{endpoint}/", handler, methods=["POST"]) + + async def get_counter_handler(self, data: ValidCounterRetrievalRequest): + """ + /get/ + Get current counter value + """ + + counter = data.counter + if not counter == 'all': + count = await self.glob_state.get_counter(counter) + else: + count = await self.glob_state.get_all_counters() + return { + 'counter': counter, + 'count': count + + } + + async def increment_counter_handler(self, data: ValidCounterIncrementRequest): + """ + /increment/ + Increment counter value (requires PUT KEY) + """ + + return { + + } + + diff --git a/endpoints/lastfm.py b/endpoints/lastfm.py index fd69d90..b89135e 100644 --- a/endpoints/lastfm.py +++ b/endpoints/lastfm.py @@ -22,10 +22,11 @@ class ValidAlbumDetailRequest(BaseModel): class LastFM(FastAPI): """Last.FM Endpoints""" - def __init__(self, app: FastAPI, util, constants): # pylint: disable=super-init-not-called + def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called self.app = app self.util = util self.constants = constants + self.glob_state = glob_state self.lastfm = importlib.import_module("lastfm_wrapper").LastFM() self.endpoints = { diff --git a/endpoints/lyric_search.py b/endpoints/lyric_search.py index c1c0616..29b42b8 100644 --- a/endpoints/lyric_search.py +++ b/endpoints/lyric_search.py @@ -37,10 +37,11 @@ class ValidLyricRequest(BaseModel): class LyricSearch(FastAPI): """Lyric Search Endpoint""" - def __init__(self, app: FastAPI, util, constants): # pylint: disable=super-init-not-called + def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called self.app = app self.util = util self.constants = constants + self.glob_state = glob_state self.lyrics_engine = importlib.import_module("lyrics_engine").LyricsEngine() self.endpoint_name = "lyric_search" @@ -87,7 +88,9 @@ class LyricSearch(FastAPI): src = data.src.upper() if not src in self.acceptable_request_sources: - raise HTTPException(detail="Invalid request source", status_code=403) + raise HTTPException(detail="Invalid request source", status_code=403) + + await self.glob_state.increment_counter('lyric_requests') search_artist = data.a search_song = data.s @@ -128,6 +131,7 @@ class LyricSearch(FastAPI): recipient='anyone') if not search_worker or not 'l' in search_worker.keys(): + await self.glob_state.increment_counter('failedlyric_requests') return { 'err': True, 'errorText': 'Sources exhausted, lyrics not located.' @@ -141,4 +145,5 @@ class LyricSearch(FastAPI): 'lyrics': regex.sub(r"\s/\s", "
", " ".join(search_worker['l'])), 'from_cache': search_worker['method'].strip().lower().startswith("local cache"), 'src': search_worker['method'] if add_extras else None, + 'reqn': await self.glob_state.get_counter('lyric_requests') } diff --git a/endpoints/rand_msg.py b/endpoints/rand_msg.py index 9954e11..08653ee 100644 --- a/endpoints/rand_msg.py +++ b/endpoints/rand_msg.py @@ -9,11 +9,12 @@ from fastapi import FastAPI class RandMsg(FastAPI): """Random Message Endpoint""" - def __init__(self, app: FastAPI, util, constants): # pylint: disable=super-init-not-called + def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called self.app = app self.util = util self.constants = constants - + self.glob_state = glob_state + self.endpoint_name = "randmsg" app.add_api_route(f"/{self.endpoint_name}/", self.randmsg_handler, methods=["POST"]) diff --git a/endpoints/transcriptions.py b/endpoints/transcriptions.py index 2105564..ae57fb4 100644 --- a/endpoints/transcriptions.py +++ b/endpoints/transcriptions.py @@ -23,10 +23,11 @@ class ValidShowEpisodeLineRequest(BaseModel): class Transcriptions(FastAPI): """Transcription Endpoints""" - def __init__(self, app: FastAPI, util, constants): # pylint: disable=super-init-not-called + def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called self.app = app self.util = util self.constants = constants + self.glob_state = glob_state self.endpoints = { "transcriptions/get_episodes": self.get_episodes_handler, @@ -79,6 +80,7 @@ class Transcriptions(FastAPI): 'err': True, 'errorText': 'Unknown error.' } + await self.glob_state.increment_counter('transcript_list_requests') async with sqlite3.connect(database=db_path, timeout=1) as _db: async with _db.execute(db_query) as _cursor: result = await _cursor.fetchall() @@ -114,7 +116,8 @@ class Transcriptions(FastAPI): 'err': True, 'errorText': 'Unknown error' } - + + await self.glob_state.increment_counter('transcript_requests') async with sqlite3.connect(database=db_path, timeout=1) as _db: params = (episode_id,) async with _db.execute(db_query, params) as _cursor: diff --git a/endpoints/yt.py b/endpoints/yt.py index 9c88bb1..cf1c30b 100644 --- a/endpoints/yt.py +++ b/endpoints/yt.py @@ -14,10 +14,11 @@ class ValidYTSearchRequest(BaseModel): class YT(FastAPI): """YT Endpoints""" - def __init__(self, app: FastAPI, util, constants): # pylint: disable=super-init-not-called + def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called self.app = app self.util = util self.constants = constants + self.glob_state = glob_state self.ytsearch = importlib.import_module("youtube_search_async").YoutubeSearch() self.endpoints = { diff --git a/state.py b/state.py new file mode 100644 index 0000000..79488dd --- /dev/null +++ b/state.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3.12 + +"""Global State Storage/Counters""" + +import logging +import os +import aiosqlite as sqlite3 + +from fastapi import FastAPI +from fastapi_utils.tasks import repeat_every + + +class State(FastAPI): + def __init__(self, app: FastAPI, util, constants): + self.counter_db_path = os.path.join("/", "var", "lib", "singerdbs", "stats.db") + self.counters = { + str(counter): 0 for counter in constants.AVAILABLE_COUNTERS + } + self.counters_initialized = False + logging.debug("[State] Counters: %s", self.counters) + + @app.on_event("startup") + async def get_counters(): + logging.info("[State] Initializing counters...") + async with sqlite3.connect(self.counter_db_path, timeout=2) as _db: + _query = "SELECT ai_requests, lyric_requests, transcript_list_requests, transcript_requests, lyrichistory_requests, \ + failedlyric_requests, misc_failures, claude_ai_requests FROM counters LIMIT 1" + async with _db.execute(_query) as _cursor: + _result = await _cursor.fetchone() + (ai_requests, + lyric_requests, + transcript_list_requests, + transcript_requests, + lyrichistory_requests, + failedlyric_requests, + misc_failures, + claude_ai_requests) = _result + self.counters = { + 'ai_requests': ai_requests, + 'lyric_requests': lyric_requests, + 'transcript_list_requests': transcript_list_requests, + 'transcript_requests': transcript_requests, + 'lyrichistory_requests': lyrichistory_requests, + 'failedlyric_requests': failedlyric_requests, + 'misc_failures': misc_failures, + 'claude_ai_requests': claude_ai_requests + } + self.counters_initialized = True + logging.info("Counters loaded from db: %s", self.counters) + + + @app.on_event("startup") + @repeat_every(seconds=10) + async def update_db(): + if self.counters_initialized == False: + logging.debug("[State] TICK: Counters not yet initialized") + return + + ai_requests = self.counters.get('ai_requests') + lyric_requests = self.counters.get('lyric_requests') + transcript_list_requests = self.counters.get('transcript_list_requests') + transcript_requests = self.counters.get('transcript_requests') + lyrichistory_requests = self.counters.get('lyrichistory_requests') + failedlyric_requests = self.counters.get('failedlyric_requests') + claude_ai_requests = self.counters.get('claude_ai_requests') + + async with sqlite3.connect(self.counter_db_path, timeout=2) as _db: + _query = "UPDATE counters SET ai_requests = ?, lyric_requests = ?, transcript_list_requests = ?, \ + transcript_requests = ?, lyrichistory_requests = ?, failedlyric_requests = ?, \ + claude_ai_requests = ?" + + _params = (ai_requests, + lyric_requests, + transcript_list_requests, + transcript_requests, + lyrichistory_requests, + failedlyric_requests, + claude_ai_requests) + async with _db.execute(_query, _params) as _cursor: + if _cursor.rowcount != 1: + logging.error("Failed to update DB") + return + await _db.commit() + logging.debug("[State] Updated DB") + + + + + + + async def increment_counter(self, counter: str): + if not(counter in self.counters.keys()): + raise BaseException("[State] Counter %s does not exist", counter) + + self.counters[counter] += 1 + return True + + async def get_counter(self, counter: str): + if not(counter in self.counters.keys()): + raise BaseException("[State] Counter %s does not exist", counter) + + return self.counters[counter] + + async def get_all_counters(self): + return self.counters + + + +