misc covid changes
This commit is contained in:
@ -9,6 +9,14 @@ from fastapi import FastAPI, Security, Request, HTTPException
|
||||
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ValidAISongRequest(BaseModel):
|
||||
"""
|
||||
- **a**: artist
|
||||
- **s**: track title
|
||||
"""
|
||||
|
||||
a: str
|
||||
s: str
|
||||
|
||||
class AI(FastAPI):
|
||||
"""AI Endpoints"""
|
||||
@ -17,18 +25,19 @@ class AI(FastAPI):
|
||||
self.util = my_util
|
||||
self.constants = constants
|
||||
self.glob_state = glob_state
|
||||
self.url_clean_regex = regex.compile(r'^\/ai\/')
|
||||
self.url_clean_regex = regex.compile(r'^\/ai\/openai\/')
|
||||
self.endpoints = {
|
||||
"ai": self.ai_handler,
|
||||
"ai/openai": self.ai_openai_handler,
|
||||
"ai/song": self.ai_song_handler
|
||||
#tbd
|
||||
}
|
||||
|
||||
for endpoint, handler in self.endpoints.items():
|
||||
app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"])
|
||||
|
||||
async def ai_handler(self, request: Request):
|
||||
async def ai_openai_handler(self, request: Request):
|
||||
"""
|
||||
/ai/
|
||||
/ai/openai/
|
||||
AI Request
|
||||
"""
|
||||
|
||||
@ -59,4 +68,53 @@ class AI(FastAPI):
|
||||
return {
|
||||
'err': True,
|
||||
'errorText': 'General Failure'
|
||||
}
|
||||
}
|
||||
|
||||
async def ai_song_handler(self, data: ValidAISongRequest, request: Request):
|
||||
"""
|
||||
/ai/song/
|
||||
AI (Song Info) Request [Public]
|
||||
"""
|
||||
|
||||
ai_question = f"I am going to listen to a song titled \"{data.s}\" by \"{data.a}\"."
|
||||
|
||||
local_llm_headers = {
|
||||
'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
|
||||
}
|
||||
ai_req_data = {
|
||||
'max_context_length': 16784,
|
||||
'max_length': 256,
|
||||
'temperature': 0.1,
|
||||
'min_p': 0.1,
|
||||
'quiet': 0,
|
||||
'rep_pen': 1.0,
|
||||
'rep_pen_range': 600,
|
||||
'rep_pen_slope': 0.1,
|
||||
'bypass_eos': False,
|
||||
'trim_stop': True,
|
||||
'top_k': 90,
|
||||
'top_p': 1,
|
||||
'sampler_order': [6,0,1,3,4,2,5],
|
||||
'smoothing_factor': 0.06,
|
||||
'memory': "You are a helpful assistant who will provide only totally accurate tidbits of info on songs the user may listen to. You do not include information about which album a song was released on, or when it was released, and do not mention that you are not including this information in your response. No small talk, no introductions, just give info on the song!",
|
||||
'stop': ['### Inst', '### Resp'],
|
||||
'prompt': ai_question
|
||||
}
|
||||
try:
|
||||
async with ClientSession() as session:
|
||||
async with await session.post(f'{self.constants.LOCAL_LLM_BASE}/generate',
|
||||
json=ai_req_data,
|
||||
headers=local_llm_headers,
|
||||
timeout=ClientTimeout(connect=15, sock_read=30)) as request:
|
||||
await self.glob_state.increment_counter('ai_requests')
|
||||
response = await request.json()
|
||||
result = {
|
||||
'resp': response.get('results')[0].get('text').strip()
|
||||
}
|
||||
return result
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logging.error("Error: %s", e)
|
||||
return {
|
||||
'err': True,
|
||||
'errorText': 'General Failure'
|
||||
}
|
Reference in New Issue
Block a user