From 58a3f90b4c6e1c0864f89724c14dfc1cc77c8734 Mon Sep 17 00:00:00 2001 From: codey Date: Wed, 4 Sep 2024 20:30:11 -0400 Subject: [PATCH] misc covid changes --- endpoints/ai.py | 68 +++++++++++++++++++++++++++++++++++++++++++++---- state.py | 3 +++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/endpoints/ai.py b/endpoints/ai.py index 4cb7879..567d984 100644 --- a/endpoints/ai.py +++ b/endpoints/ai.py @@ -9,6 +9,14 @@ from fastapi import FastAPI, Security, Request, HTTPException from fastapi.security import APIKeyHeader, APIKeyQuery from pydantic import BaseModel +class ValidAISongRequest(BaseModel): + """ + - **a**: artist + - **s**: track title + """ + + a: str + s: str class AI(FastAPI): """AI Endpoints""" @@ -17,18 +25,19 @@ class AI(FastAPI): self.util = my_util self.constants = constants self.glob_state = glob_state - self.url_clean_regex = regex.compile(r'^\/ai\/') + self.url_clean_regex = regex.compile(r'^\/ai\/openai\/') self.endpoints = { - "ai": self.ai_handler, + "ai/openai": self.ai_openai_handler, + "ai/song": self.ai_song_handler #tbd } for endpoint, handler in self.endpoints.items(): app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"]) - async def ai_handler(self, request: Request): + async def ai_openai_handler(self, request: Request): """ - /ai/ + /ai/openai/ AI Request """ @@ -59,4 +68,53 @@ class AI(FastAPI): return { 'err': True, 'errorText': 'General Failure' - } \ No newline at end of file + } + + async def ai_song_handler(self, data: ValidAISongRequest, request: Request): + """ + /ai/song/ + AI (Song Info) Request [Public] + """ + + ai_question = f"I am going to listen to a song titled \"{data.s}\" by \"{data.a}\"." + + local_llm_headers = { + 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}' + } + ai_req_data = { + 'max_context_length': 16784, + 'max_length': 256, + 'temperature': 0.1, + 'min_p': 0.1, + 'quiet': 0, + 'rep_pen': 1.0, + 'rep_pen_range': 600, + 'rep_pen_slope': 0.1, + 'bypass_eos': False, + 'trim_stop': True, + 'top_k': 90, + 'top_p': 1, + 'sampler_order': [6,0,1,3,4,2,5], + 'smoothing_factor': 0.06, + 'memory': "You are a helpful assistant who will provide only totally accurate tidbits of info on songs the user may listen to. You do not include information about which album a song was released on, or when it was released, and do not mention that you are not including this information in your response. No small talk, no introductions, just give info on the song!", + 'stop': ['### Inst', '### Resp'], + 'prompt': ai_question + } + try: + async with ClientSession() as session: + async with await session.post(f'{self.constants.LOCAL_LLM_BASE}/generate', + json=ai_req_data, + headers=local_llm_headers, + timeout=ClientTimeout(connect=15, sock_read=30)) as request: + await self.glob_state.increment_counter('ai_requests') + response = await request.json() + result = { + 'resp': response.get('results')[0].get('text').strip() + } + return result + except Exception as e: # pylint: disable=broad-exception-caught + logging.error("Error: %s", e) + return { + 'err': True, + 'errorText': 'General Failure' + } \ No newline at end of file diff --git a/state.py b/state.py index 79488dd..cf05355 100644 --- a/state.py +++ b/state.py @@ -25,6 +25,7 @@ class State(FastAPI): async with sqlite3.connect(self.counter_db_path, timeout=2) as _db: _query = "SELECT ai_requests, lyric_requests, transcript_list_requests, transcript_requests, lyrichistory_requests, \ failedlyric_requests, misc_failures, claude_ai_requests FROM counters LIMIT 1" + await _db.executescript("pragma journal_mode = WAL; pragma synchronous = normal; pragma temp_store = memory; pragma mmap_size = 30000000000;") async with _db.execute(_query) as _cursor: _result = await _cursor.fetchone() (ai_requests, @@ -76,6 +77,8 @@ class State(FastAPI): lyrichistory_requests, failedlyric_requests, claude_ai_requests) + + await _db.executescript("pragma journal_mode = WAL; pragma synchronous = normal; pragma temp_store = memory; pragma mmap_size = 30000000000;") async with _db.execute(_query, _params) as _cursor: if _cursor.rowcount != 1: logging.error("Failed to update DB")