2024-08-14 22:43:20 -04:00
#!/usr/bin/env python3.12
import importlib
import logging
import regex
from aiohttp import ClientSession , ClientTimeout
from fastapi import FastAPI , Security , Request , HTTPException
from fastapi . security import APIKeyHeader , APIKeyQuery
from pydantic import BaseModel
2024-09-04 20:30:11 -04:00
class ValidAISongRequest(BaseModel):
    """
    - **a**: artist
    - **s**: track title
    """
    # Request body model for the song-info endpoint (AI.ai_song_handler).
    # NOTE: the docstring above is kept short and in Markdown on purpose --
    # pydantic/FastAPI surface a model's docstring as its schema description
    # in the generated API docs.
    # The one-letter field names are the JSON keys callers send; both values
    # are interpolated verbatim into the prompt built by the handler.
    a: str  # artist name
    s: str  # track title
2024-08-14 22:43:20 -04:00
class AI(FastAPI):
    """AI Endpoints"""

    def __init__(self, app: FastAPI, my_util, constants, glob_state):  # pylint: disable=super-init-not-called
        """
        Store the shared collaborators and register the AI routes on `app`.

        FastAPI.__init__ is intentionally not invoked (see the pylint
        directive on the signature); routes are added to the `app` passed in,
        not to this instance.

        Args:
            app: FastAPI application the POST routes are registered on.
            my_util: helper object; must provide `check_key(path, key)`.
            constants: settings object; this class reads `LOCAL_LLM_KEY`,
                `LOCAL_LLM_HOST` and `LOCAL_LLM_BASE` from it.
            glob_state: shared state; must provide the awaitable
                `increment_counter(name)`.
        """
        self.app = app
        self.util = my_util
        self.constants = constants
        self.glob_state = glob_state
        # Strips the public "/ai/openai/" prefix so the remainder of the path
        # can be forwarded to the local LLM unchanged.
        self.url_clean_regex = regex.compile(r'^\/ai\/openai\/')
        self.endpoints = {
            "ai/openai": self.ai_openai_handler,
            "ai/song": self.ai_song_handler,
            # tbd
        }

        # Every endpoint accepts an arbitrary trailing path ("{any:path}").
        for endpoint, handler in self.endpoints.items():
            app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"])

    def _local_llm_headers(self) -> dict:
        """Return the bearer-token headers sent with every local LLM request."""
        return {
            'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}',
        }

    @staticmethod
    def _general_failure() -> dict:
        """Return the generic error payload both handlers send on any failure."""
        return {
            'err': True,
            'errorText': 'General Failure',
        }

    async def _post_to_local_llm(self, url: str, payload, headers: dict):
        """
        POST `payload` as JSON to `url` and return the decoded JSON reply.

        The 'ai_requests' counter is incremented once the upstream response
        headers have arrived and before the body is decoded, matching the
        order the handlers used previously. Any aiohttp/JSON error propagates
        to the caller.
        """
        async with ClientSession() as session:
            # `session.post(...)` is itself an async context manager; naming
            # the result `llm_response` avoids clobbering the handler's
            # inbound `request`.
            async with session.post(url,
                                    json=payload,
                                    headers=headers,
                                    timeout=ClientTimeout(connect=15, sock_read=30)) as llm_response:
                await self.glob_state.increment_counter('ai_requests')
                return await llm_response.json()

    # Handler notes (kept out of the docstring: FastAPI publishes an endpoint's
    # docstring as its OpenAPI description):
    #   * Requires a valid key in the X-Authd-With header, else HTTP 403.
    #   * Forwards the JSON body to LOCAL_LLM_HOST/<path after /ai/openai/>.
    #   * Returns the upstream JSON as-is, or the generic failure dict.
    #   * No return annotation on purpose -- FastAPI would interpret one as a
    #     response model and validate the passthrough payload against it.
    async def ai_openai_handler(self, request: Request):
        """
        /ai/openai/
        AI Request
        """
        if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
            raise HTTPException(status_code=403, detail="Unauthorized")

        # TODO: Implement Claude
        # Currently only routes to local LLM

        # Built outside the try block so a misconfigured key still surfaces as
        # a server error instead of being masked as "General Failure".
        local_llm_headers = self._local_llm_headers()
        forward_path = self.url_clean_regex.sub('', request.url.path)
        try:
            return await self._post_to_local_llm(
                f'{self.constants.LOCAL_LLM_HOST}/{forward_path}',
                await request.json(),
                local_llm_headers,
            )
        except Exception as e:  # pylint: disable=broad-exception-caught
            # logging.exception keeps the traceback alongside the message.
            logging.exception("Error: %s", e)
            return self._general_failure()

    # Handler notes (see above for why these are not in the docstring):
    #   * Public: no key check is performed.
    #   * Builds a fixed generation request around the artist/title supplied in
    #     `data` and POSTs it to LOCAL_LLM_BASE/generate.
    #   * Returns {'resp': <stripped text of the first result>} or the generic
    #     failure dict (also when the reply lacks results[0].text).
    #   * `request` is unused but is part of the signature FastAPI injects.
    #   NOTE(review): this handler reads LOCAL_LLM_BASE while ai_openai_handler
    #   reads LOCAL_LLM_HOST -- confirm both constants are defined as intended.
    async def ai_song_handler(self, data: ValidAISongRequest, request: Request):  # pylint: disable=unused-argument
        """
        /ai/song/
        AI (Song Info) Request [Public]
        """
        ai_question = f'I am going to listen to a song titled "{data.s}" by "{data.a}".'
        local_llm_headers = self._local_llm_headers()
        # Sampler settings are sent verbatim to the backend's /generate route.
        ai_req_data = {
            'max_context_length': 16784,
            'max_length': 256,
            'temperature': 0.1,
            'min_p': 0.1,
            'quiet': 0,
            'rep_pen': 1.0,
            'rep_pen_range': 600,
            'rep_pen_slope': 0.1,
            'bypass_eos': False,
            'trim_stop': True,
            'top_k': 90,
            'top_p': 1,
            'sampler_order': [6, 0, 1, 3, 4, 2, 5],
            'smoothing_factor': 0.06,
            'memory': "You are a helpful assistant who will provide only totally accurate tidbits of info on songs the user may listen to. You do not include information about which album a song was released on, or when it was released, and do not mention that you are not including this information in your response. No small talk, no introductions, just give info on the song!",
            'stop': ['### Inst', '### Resp'],
            'prompt': ai_question,
        }
        try:
            response = await self._post_to_local_llm(
                f'{self.constants.LOCAL_LLM_BASE}/generate',
                ai_req_data,
                local_llm_headers,
            )
            # A reply without results[0].text raises here and is reported as a
            # general failure by the handler below.
            return {
                'resp': response.get('results')[0].get('text').strip(),
            }
        except Exception as e:  # pylint: disable=broad-exception-caught
            logging.exception("Error: %s", e)
            return self._general_failure()