2024-08-14 22:43:20 -04:00
#!/usr/bin/env python3.12
import importlib
import logging
import regex
from aiohttp import ClientSession , ClientTimeout
from fastapi import FastAPI , Security , Request , HTTPException
from fastapi . security import APIKeyHeader , APIKeyQuery
from pydantic import BaseModel
2024-09-04 20:30:11 -04:00
class ValidAISongRequest(BaseModel):
    """
    - **a**: artist
    - **s**: track title
    """
    # Request body model: two required string fields, deliberately terse
    # single-letter keys as listed in the docstring above.
    # NOTE(review): the docstring is presumably surfaced in the generated
    # API schema for this model -- confirm before rewording it.
    a: str  # artist name
    s: str  # track title
2024-08-14 22:43:20 -04:00
class AI(FastAPI):
    """AI Endpoints"""

    # Overview
    # --------
    # Registers the AI POST routes on the FastAPI ``app`` handed to
    # ``__init__`` and relays each request to the local LLM backend described
    # by ``constants`` (LOCAL_LLM_KEY, LOCAL_LLM_HOST, LOCAL_LLM_BASE).
    #
    # Handler docstrings are kept short on purpose: FastAPI uses an endpoint's
    # docstring as the route description in the generated OpenAPI docs.
    # Handlers also carry no return annotation, since FastAPI would treat one
    # as a response model.

    def __init__(self, app: FastAPI, my_util, constants, glob_state) -> None:  # pylint: disable=super-init-not-called
        # FastAPI.__init__ is not invoked (see the pylint marker above); this
        # object only attaches routes to the ``app`` it is given.
        #
        # Args:
        #   app:        application the routes are registered on.
        #   my_util:    helper object; must provide ``check_key(path, key)``.
        #   constants:  object exposing the LOCAL_LLM_* settings.
        #   glob_state: object providing the awaitable ``increment_counter(name)``.
        self.app = app
        self.util = my_util
        self.constants = constants
        self.glob_state = glob_state
        # Strips the public route prefix so the remainder of the path can be
        # forwarded to the backend unchanged.
        self.url_clean_regex = regex.compile(r'^\/ai\/openai\/')
        self.endpoints = {
            "ai/openai": self.ai_openai_handler,
            "ai/song": self.ai_song_handler,
            # tbd
        }

        # Every endpoint accepts POST on its prefix plus an arbitrary sub-path.
        for endpoint, handler in self.endpoints.items():
            app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"])

    @staticmethod
    def _general_failure() -> dict:
        """Build the generic error payload returned when a backend call fails."""
        # A fresh dict per call so no response object is shared between requests.
        return {
            'err': True,
            'errorText': 'General Failure',
        }

    def _llm_headers(self) -> dict:
        """Build the bearer-auth headers expected by the local LLM backend."""
        return {
            'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}',
        }

    async def _post_to_llm(self, url: str, payload):
        """POST ``payload`` as JSON to ``url`` and return the decoded JSON reply.

        Any connection, timeout or decoding error propagates to the caller,
        which is responsible for turning it into a client-facing response.
        """
        async with ClientSession() as session:
            # The upstream response gets its own name so it never shadows the
            # incoming FastAPI ``request`` of the calling handler.
            async with session.post(url,
                                    json=payload,
                                    headers=self._llm_headers(),
                                    timeout=ClientTimeout(connect=15, sock_read=30)) as llm_response:
                # Counted as soon as the backend responds, before the body is decoded.
                await self.glob_state.increment_counter('ai_requests')
                return await llm_response.json()

    async def ai_openai_handler(self, request: Request):
        """
        /ai/openai/
        AI Request
        """
        # Key-protected pass-through: forwards the caller's JSON body to the
        # local LLM under the same sub-path and returns the backend's JSON.
        # Raises HTTPException(403) when the key check fails; any failure while
        # talking to the backend is logged and reported as the generic error
        # payload instead of an exception.
        if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
            raise HTTPException(status_code=403, detail="Unauthorized")

        # TODO: Implement Claude
        # Currently only routes to local LLM
        forward_path = self.url_clean_regex.sub('', request.url.path)
        try:
            return await self._post_to_llm(
                f'{self.constants.LOCAL_LLM_HOST}/{forward_path}',
                await request.json(),
            )
        except Exception as e:  # pylint: disable=broad-exception-caught
            logging.error("Error: %s", e)
            return self._general_failure()

    async def ai_song_handler(self, data: ValidAISongRequest, request: Request):  # pylint: disable=unused-argument
        """
        /ai/song/
        AI (Song Info) Request [Public]
        """
        # Public endpoint (no key check): asks the local LLM for a short blurb
        # about the song ``data.s`` by ``data.a`` and returns
        # ``{'resp': <stripped text of the first result>}``. Any failure,
        # including an unexpected reply shape, is logged and reported as the
        # generic error payload. ``request`` is unused but kept in the
        # signature so the route's interface stays unchanged.
        ai_question = f"I am going to listen to the song \"{data.s}\" by \"{data.a}\"."
        ai_req_data = {
            'max_context_length': 16784,
            'max_length': 256,
            'temperature': 0.2,
            'quiet': 1,
            'bypass_eos': False,
            'trim_stop': True,
            'sampler_order': [6, 0, 1, 3, 4, 2, 5],
            'memory': "You are a helpful assistant who will provide only totally accurate tidbits of info on songs the user may listen to. You do not include information about which album a song was released on, or when it was released, and do not mention that you are not including this information in your response. If the input provided is not a song you are aware of, simply state that.",
            'stop': ['### Inst', '### Resp'],
            'prompt': ai_question,
        }
        try:
            response = await self._post_to_llm(
                f'{self.constants.LOCAL_LLM_BASE}/generate',
                ai_req_data,
            )
            # A missing 'results' or 'text' key raises here and is handled below.
            return {
                'resp': response.get('results')[0].get('text').strip(),
            }
        except Exception as e:  # pylint: disable=broad-exception-caught
            logging.error("Error: %s", e)
            return self._general_failure()