Add global state module
This commit is contained in:
		
							
								
								
									
										20
									
								
								base.py
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								base.py
									
									
									
									
									
								
							| @@ -13,7 +13,7 @@ logger.setLevel(logging.DEBUG) | |||||||
| app = FastAPI() | app = FastAPI() | ||||||
| util = importlib.import_module("util").Utilities() | util = importlib.import_module("util").Utilities() | ||||||
| constants = importlib.import_module("constants").Constants() | constants = importlib.import_module("constants").Constants() | ||||||
|  | glob_state = importlib.import_module("state").State(app, util, constants) | ||||||
|  |  | ||||||
| origins = [ | origins = [ | ||||||
|     "https://codey.lol", |     "https://codey.lol", | ||||||
| @@ -25,6 +25,9 @@ allow_credentials=True, | |||||||
| allow_methods=["POST"], | allow_methods=["POST"], | ||||||
| allow_headers=["*"]) | allow_headers=["*"]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| """ | """ | ||||||
| Blacklisted routes | Blacklisted routes | ||||||
| """ | """ | ||||||
| @@ -50,16 +53,17 @@ End Blacklisted Routes | |||||||
| """ | """ | ||||||
| Actionable Routes | Actionable Routes | ||||||
| """ | """ | ||||||
|  | counter_endpoints = importlib.import_module("endpoints.counters").Counters(app, util, constants, glob_state) | ||||||
| randmsg_endpoint = importlib.import_module("endpoints.rand_msg").RandMsg(app, util, constants) | randmsg_endpoint = importlib.import_module("endpoints.rand_msg").RandMsg(app, util, constants, glob_state) | ||||||
|  | transcription_endpoints = importlib.import_module("endpoints.transcriptions").Transcriptions(app, util, constants, glob_state) | ||||||
| # Below also provides: /lyric_cache_list/ (in addition to /lyric_search/) | # Below also provides: /lyric_cache_list/ (in addition to /lyric_search/) | ||||||
| lyric_search_endpoint = importlib.import_module("endpoints.lyric_search").LyricSearch(app, util, constants) | lyric_search_endpoint = importlib.import_module("endpoints.lyric_search").LyricSearch(app, util, constants, glob_state) | ||||||
| # Below provides numerous LastFM-fed endpoints | # Below provides numerous LastFM-fed endpoints | ||||||
| lastfm_endpoints = importlib.import_module("endpoints.lastfm").LastFM(app, util, constants) | lastfm_endpoints = importlib.import_module("endpoints.lastfm").LastFM(app, util, constants, glob_state) | ||||||
| # Below: YT endpoint(s) | # Below: YT endpoint(s) | ||||||
| yt_endpoints = importlib.import_module("endpoints.yt").YT(app, util, constants) | yt_endpoints = importlib.import_module("endpoints.yt").YT(app, util, constants, glob_state) | ||||||
| # Below: Transcription endpoints |  | ||||||
| transcription_endpoints = importlib.import_module("endpoints.transcriptions").Transcriptions(app, util, constants) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| """ | """ | ||||||
|   | |||||||
							
								
								
									
										69
									
								
								endpoints/counters.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								endpoints/counters.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | |||||||
|  | #!/usr/bin/env python3.12 | ||||||
|  |  | ||||||
|  | #!/usr/bin/env python3.12 | ||||||
|  |  | ||||||
|  | import importlib | ||||||
|  | from fastapi import FastAPI | ||||||
|  | from pydantic import BaseModel | ||||||
|  |  | ||||||
|  | class ValidCounterIncrementRequest(BaseModel): | ||||||
|  |     """ | ||||||
|  |     - **counter**: counter to update | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     counter: str | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ValidCounterRetrievalRequest(BaseModel): | ||||||
|  |     """ | ||||||
|  |     - **counter**: counter to retrieve (if none is provided, all counters are returned) | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     counter: str = "all" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Counters(FastAPI): | ||||||
|  |     """Counter Endpoints"""    | ||||||
|  |     def __init__(self, app: FastAPI, util, constants, glob_state):  # pylint: disable=super-init-not-called | ||||||
|  |         self.app = app | ||||||
|  |         self.util = util | ||||||
|  |         self.constants = constants | ||||||
|  |         self.glob_state = glob_state | ||||||
|  |  | ||||||
|  |         self.endpoints = { | ||||||
|  |             "counters/get": self.get_counter_handler, | ||||||
|  |             "counters/increment": self.increment_counter_handler | ||||||
|  |             #tbd | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         for endpoint, handler in self.endpoints.items(): | ||||||
|  |             app.add_api_route(f"/{endpoint}/", handler, methods=["POST"]) | ||||||
|  |      | ||||||
|  |     async def get_counter_handler(self, data: ValidCounterRetrievalRequest): | ||||||
|  |         """ | ||||||
|  |         /get/ | ||||||
|  |         Get current counter value | ||||||
|  |         """ | ||||||
|  |          | ||||||
|  |         counter = data.counter | ||||||
|  |         if not counter == 'all': | ||||||
|  |             count = await self.glob_state.get_counter(counter) | ||||||
|  |         else: | ||||||
|  |             count = await self.glob_state.get_all_counters() | ||||||
|  |         return { | ||||||
|  |             'counter': counter, | ||||||
|  |             'count': count | ||||||
|  |  | ||||||
|  |         } | ||||||
|  |      | ||||||
|  |     async def increment_counter_handler(self, data: ValidCounterIncrementRequest): | ||||||
|  |         """ | ||||||
|  |         /increment/ | ||||||
|  |         Increment counter value (requires PUT KEY) | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         return { | ||||||
|  |  | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -22,10 +22,11 @@ class ValidAlbumDetailRequest(BaseModel): | |||||||
|  |  | ||||||
| class LastFM(FastAPI): | class LastFM(FastAPI): | ||||||
|     """Last.FM Endpoints"""    |     """Last.FM Endpoints"""    | ||||||
|     def __init__(self, app: FastAPI, util, constants):  # pylint: disable=super-init-not-called |     def __init__(self, app: FastAPI, util, constants, glob_state):  # pylint: disable=super-init-not-called | ||||||
|         self.app = app |         self.app = app | ||||||
|         self.util = util |         self.util = util | ||||||
|         self.constants = constants |         self.constants = constants | ||||||
|  |         self.glob_state = glob_state | ||||||
|         self.lastfm = importlib.import_module("lastfm_wrapper").LastFM() |         self.lastfm = importlib.import_module("lastfm_wrapper").LastFM() | ||||||
|  |  | ||||||
|         self.endpoints = { |         self.endpoints = { | ||||||
|   | |||||||
| @@ -37,10 +37,11 @@ class ValidLyricRequest(BaseModel): | |||||||
|  |  | ||||||
| class LyricSearch(FastAPI): | class LyricSearch(FastAPI): | ||||||
|     """Lyric Search Endpoint""" |     """Lyric Search Endpoint""" | ||||||
|     def __init__(self, app: FastAPI, util, constants): # pylint: disable=super-init-not-called |     def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called | ||||||
|         self.app = app |         self.app = app | ||||||
|         self.util = util |         self.util = util | ||||||
|         self.constants = constants |         self.constants = constants | ||||||
|  |         self.glob_state = glob_state | ||||||
|         self.lyrics_engine = importlib.import_module("lyrics_engine").LyricsEngine() |         self.lyrics_engine = importlib.import_module("lyrics_engine").LyricsEngine() | ||||||
|  |  | ||||||
|         self.endpoint_name = "lyric_search" |         self.endpoint_name = "lyric_search" | ||||||
| @@ -87,7 +88,9 @@ class LyricSearch(FastAPI): | |||||||
|  |  | ||||||
|         src = data.src.upper() |         src = data.src.upper() | ||||||
|         if not src in self.acceptable_request_sources: |         if not src in self.acceptable_request_sources: | ||||||
|             raise HTTPException(detail="Invalid request source", status_code=403)            |             raise HTTPException(detail="Invalid request source", status_code=403) | ||||||
|  |  | ||||||
|  |         await self.glob_state.increment_counter('lyric_requests')            | ||||||
|  |  | ||||||
|         search_artist = data.a |         search_artist = data.a | ||||||
|         search_song = data.s |         search_song = data.s | ||||||
| @@ -128,6 +131,7 @@ class LyricSearch(FastAPI): | |||||||
|         recipient='anyone') |         recipient='anyone') | ||||||
|  |  | ||||||
|         if not search_worker or not 'l' in search_worker.keys(): |         if not search_worker or not 'l' in search_worker.keys(): | ||||||
|  |             await self.glob_state.increment_counter('failedlyric_requests') | ||||||
|             return { |             return { | ||||||
|                 'err': True, |                 'err': True, | ||||||
|                 'errorText': 'Sources exhausted, lyrics not located.' |                 'errorText': 'Sources exhausted, lyrics not located.' | ||||||
| @@ -141,4 +145,5 @@ class LyricSearch(FastAPI): | |||||||
|             'lyrics': regex.sub(r"\s/\s", "<br>", " ".join(search_worker['l'])), |             'lyrics': regex.sub(r"\s/\s", "<br>", " ".join(search_worker['l'])), | ||||||
|             'from_cache': search_worker['method'].strip().lower().startswith("local cache"), |             'from_cache': search_worker['method'].strip().lower().startswith("local cache"), | ||||||
|             'src': search_worker['method'] if add_extras else None, |             'src': search_worker['method'] if add_extras else None, | ||||||
|  |             'reqn': await self.glob_state.get_counter('lyric_requests') | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -9,11 +9,12 @@ from fastapi import FastAPI | |||||||
|  |  | ||||||
| class RandMsg(FastAPI): | class RandMsg(FastAPI): | ||||||
|     """Random Message Endpoint"""    |     """Random Message Endpoint"""    | ||||||
|     def __init__(self, app: FastAPI, util, constants):  # pylint: disable=super-init-not-called |     def __init__(self, app: FastAPI, util, constants, glob_state):  # pylint: disable=super-init-not-called | ||||||
|         self.app = app |         self.app = app | ||||||
|         self.util = util |         self.util = util | ||||||
|         self.constants = constants |         self.constants = constants | ||||||
|  |         self.glob_state = glob_state | ||||||
|  |          | ||||||
|         self.endpoint_name = "randmsg" |         self.endpoint_name = "randmsg" | ||||||
|  |  | ||||||
|         app.add_api_route(f"/{self.endpoint_name}/", self.randmsg_handler, methods=["POST"]) |         app.add_api_route(f"/{self.endpoint_name}/", self.randmsg_handler, methods=["POST"]) | ||||||
|   | |||||||
| @@ -23,10 +23,11 @@ class ValidShowEpisodeLineRequest(BaseModel): | |||||||
|  |  | ||||||
| class Transcriptions(FastAPI): | class Transcriptions(FastAPI): | ||||||
|     """Transcription Endpoints"""    |     """Transcription Endpoints"""    | ||||||
|     def __init__(self, app: FastAPI, util, constants):  # pylint: disable=super-init-not-called |     def __init__(self, app: FastAPI, util, constants, glob_state):  # pylint: disable=super-init-not-called | ||||||
|         self.app = app |         self.app = app | ||||||
|         self.util = util |         self.util = util | ||||||
|         self.constants = constants |         self.constants = constants | ||||||
|  |         self.glob_state = glob_state | ||||||
|  |  | ||||||
|         self.endpoints = { |         self.endpoints = { | ||||||
|             "transcriptions/get_episodes": self.get_episodes_handler, |             "transcriptions/get_episodes": self.get_episodes_handler, | ||||||
| @@ -79,6 +80,7 @@ class Transcriptions(FastAPI): | |||||||
|                     'err': True, |                     'err': True, | ||||||
|                     'errorText': 'Unknown error.' |                     'errorText': 'Unknown error.' | ||||||
|                 } |                 } | ||||||
|  |         await self.glob_state.increment_counter('transcript_list_requests') | ||||||
|         async with sqlite3.connect(database=db_path, timeout=1) as _db: |         async with sqlite3.connect(database=db_path, timeout=1) as _db: | ||||||
|             async with _db.execute(db_query) as _cursor: |             async with _db.execute(db_query) as _cursor: | ||||||
|                 result = await _cursor.fetchall() |                 result = await _cursor.fetchall() | ||||||
| @@ -114,7 +116,8 @@ class Transcriptions(FastAPI): | |||||||
|                     'err': True, |                     'err': True, | ||||||
|                     'errorText': 'Unknown error' |                     'errorText': 'Unknown error' | ||||||
|                 }         |                 }         | ||||||
|              |          | ||||||
|  |         await self.glob_state.increment_counter('transcript_requests') | ||||||
|         async with sqlite3.connect(database=db_path, timeout=1) as _db: |         async with sqlite3.connect(database=db_path, timeout=1) as _db: | ||||||
|             params = (episode_id,) |             params = (episode_id,) | ||||||
|             async with _db.execute(db_query, params) as _cursor: |             async with _db.execute(db_query, params) as _cursor: | ||||||
|   | |||||||
| @@ -14,10 +14,11 @@ class ValidYTSearchRequest(BaseModel): | |||||||
|  |  | ||||||
| class YT(FastAPI): | class YT(FastAPI): | ||||||
|     """YT Endpoints"""    |     """YT Endpoints"""    | ||||||
|     def __init__(self, app: FastAPI, util, constants):  # pylint: disable=super-init-not-called |     def __init__(self, app: FastAPI, util, constants, glob_state):  # pylint: disable=super-init-not-called | ||||||
|         self.app = app |         self.app = app | ||||||
|         self.util = util |         self.util = util | ||||||
|         self.constants = constants |         self.constants = constants | ||||||
|  |         self.glob_state = glob_state | ||||||
|         self.ytsearch = importlib.import_module("youtube_search_async").YoutubeSearch() |         self.ytsearch = importlib.import_module("youtube_search_async").YoutubeSearch() | ||||||
|  |  | ||||||
|         self.endpoints = { |         self.endpoints = { | ||||||
|   | |||||||
							
								
								
									
										109
									
								
								state.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								state.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,109 @@ | |||||||
|  | #!/usr/bin/env python3.12 | ||||||
|  |  | ||||||
|  | """Global State Storage/Counters""" | ||||||
|  |  | ||||||
|  | import logging | ||||||
|  | import os | ||||||
|  | import aiosqlite as sqlite3 | ||||||
|  |  | ||||||
|  | from fastapi import FastAPI | ||||||
|  | from fastapi_utils.tasks import repeat_every | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class State(FastAPI): | ||||||
|  |     def __init__(self, app: FastAPI, util, constants): | ||||||
|  |         self.counter_db_path = os.path.join("/", "var", "lib", "singerdbs", "stats.db") | ||||||
|  |         self.counters = { | ||||||
|  |             str(counter): 0 for counter in constants.AVAILABLE_COUNTERS | ||||||
|  |         } | ||||||
|  |         self.counters_initialized = False | ||||||
|  |         logging.debug("[State] Counters: %s", self.counters) | ||||||
|  |  | ||||||
|  |         @app.on_event("startup") | ||||||
|  |         async def get_counters(): | ||||||
|  |             logging.info("[State] Initializing counters...") | ||||||
|  |             async with sqlite3.connect(self.counter_db_path, timeout=2) as _db: | ||||||
|  |                 _query = "SELECT ai_requests, lyric_requests, transcript_list_requests, transcript_requests, lyrichistory_requests, \ | ||||||
|  |                     failedlyric_requests, misc_failures, claude_ai_requests FROM counters LIMIT 1" | ||||||
|  |                 async with _db.execute(_query) as _cursor: | ||||||
|  |                     _result = await _cursor.fetchone() | ||||||
|  |                     (ai_requests, | ||||||
|  |                      lyric_requests, | ||||||
|  |                      transcript_list_requests, | ||||||
|  |                      transcript_requests, | ||||||
|  |                      lyrichistory_requests, | ||||||
|  |                      failedlyric_requests, | ||||||
|  |                      misc_failures, | ||||||
|  |                      claude_ai_requests) = _result | ||||||
|  |                 self.counters = { | ||||||
|  |                     'ai_requests': ai_requests, | ||||||
|  |                     'lyric_requests': lyric_requests, | ||||||
|  |                     'transcript_list_requests': transcript_list_requests, | ||||||
|  |                     'transcript_requests': transcript_requests, | ||||||
|  |                     'lyrichistory_requests': lyrichistory_requests, | ||||||
|  |                     'failedlyric_requests': failedlyric_requests, | ||||||
|  |                     'misc_failures': misc_failures, | ||||||
|  |                     'claude_ai_requests': claude_ai_requests | ||||||
|  |                 }  | ||||||
|  |                 self.counters_initialized = True | ||||||
|  |                 logging.info("Counters loaded from db: %s", self.counters)                    | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         @app.on_event("startup") | ||||||
|  |         @repeat_every(seconds=10) | ||||||
|  |         async def update_db(): | ||||||
|  |             if self.counters_initialized == False: | ||||||
|  |                 logging.debug("[State] TICK: Counters not yet initialized") | ||||||
|  |                 return | ||||||
|  |              | ||||||
|  |             ai_requests = self.counters.get('ai_requests') | ||||||
|  |             lyric_requests = self.counters.get('lyric_requests') | ||||||
|  |             transcript_list_requests = self.counters.get('transcript_list_requests') | ||||||
|  |             transcript_requests = self.counters.get('transcript_requests') | ||||||
|  |             lyrichistory_requests = self.counters.get('lyrichistory_requests') | ||||||
|  |             failedlyric_requests = self.counters.get('failedlyric_requests') | ||||||
|  |             claude_ai_requests = self.counters.get('claude_ai_requests')             | ||||||
|  |              | ||||||
|  |             async with sqlite3.connect(self.counter_db_path, timeout=2) as _db: | ||||||
|  |                 _query = "UPDATE counters SET ai_requests = ?, lyric_requests = ?, transcript_list_requests = ?, \ | ||||||
|  |                     transcript_requests = ?, lyrichistory_requests = ?, failedlyric_requests = ?, \ | ||||||
|  |                         claude_ai_requests = ?" | ||||||
|  |  | ||||||
|  |                 _params = (ai_requests, | ||||||
|  |                            lyric_requests, | ||||||
|  |                            transcript_list_requests, | ||||||
|  |                            transcript_requests, | ||||||
|  |                            lyrichistory_requests, | ||||||
|  |                            failedlyric_requests, | ||||||
|  |                            claude_ai_requests) | ||||||
|  |                 async with _db.execute(_query, _params) as _cursor: | ||||||
|  |                     if _cursor.rowcount != 1: | ||||||
|  |                         logging.error("Failed to update DB") | ||||||
|  |                         return | ||||||
|  |                     await _db.commit() | ||||||
|  |                     logging.debug("[State] Updated DB") | ||||||
|  |                      | ||||||
|  |  | ||||||
|  |              | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     async def increment_counter(self, counter: str): | ||||||
|  |         if not(counter in self.counters.keys()): | ||||||
|  |             raise BaseException("[State] Counter %s does not exist", counter) | ||||||
|  |          | ||||||
|  |         self.counters[counter] += 1 | ||||||
|  |         return True | ||||||
|  |      | ||||||
|  |     async def get_counter(self, counter: str): | ||||||
|  |         if not(counter in self.counters.keys()): | ||||||
|  |             raise BaseException("[State] Counter %s does not exist", counter) | ||||||
|  |  | ||||||
|  |         return self.counters[counter] | ||||||
|  |  | ||||||
|  |     async def get_all_counters(self): | ||||||
|  |         return self.counters             | ||||||
|  |  | ||||||
|  |      | ||||||
|  |  | ||||||
|  |  | ||||||
		Reference in New Issue
	
	Block a user