2024-08-14 22:43:20 -04:00
#!/usr/bin/env python3.12
import importlib
import logging
import regex
from aiohttp import ClientSession , ClientTimeout
from fastapi import FastAPI , Security , Request , HTTPException
from fastapi . security import APIKeyHeader , APIKeyQuery
from pydantic import BaseModel
2024-09-04 20:30:11 -04:00
class ValidAISongRequest(BaseModel):
    """
    Request body for the /ai/song/ endpoint.

    - **a**: artist
    - **s**: track title
    """

    # Artist name, interpolated into the prompt as: by "<a>"
    a: str
    # Track title, interpolated into the prompt as: the song "<s>"
    s: str
2024-08-14 22:43:20 -04:00
class AI(FastAPI):
    """AI Endpoints"""

    def __init__(self, app: FastAPI, my_util, constants, glob_state):  # pylint: disable=super-init-not-called
        self.app = app
        self.util = my_util
        self.constants = constants
        self.glob_state = glob_state

        # Strips a leading "/ai/openai/" or "/ai/base/" from the incoming path;
        # whatever remains is appended to the upstream LLM base URL.
        self.url_clean_regex = regex.compile(r'^\/ai\/(openai|base)\/')

        self.endpoints = {
            "ai/openai": self.ai_openai_handler,
            "ai/base": self.ai_handler,
            "ai/song": self.ai_song_handler,
            # tbd
        }
        # Every endpoint is registered as POST "/<endpoint>/{any:path}" so that
        # arbitrary trailing path segments reach the handler.
        for endpoint, handler in self.endpoints.items():
            app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"])

    async def _forward_to_local_llm(self, request: Request, base_url: str):
        """
        Proxy the JSON body of ``request`` to the local LLM.

        The "/ai/openai/" or "/ai/base/" prefix is stripped from the incoming
        path and the remainder is appended to ``base_url``.  The local LLM key
        is sent as a Bearer token and the 'ai_requests' counter is incremented
        once the upstream request has been issued.

        :param request: incoming FastAPI request (its body must be JSON)
        :param base_url: upstream base URL to forward to
        :return: the upstream JSON response, or
                 ``{'err': True, 'errorText': 'General Failure'}`` on any error
                 (invalid request JSON, network failure, non-JSON reply, ...)
        """
        local_llm_headers = {
            'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
        }
        forward_path = self.url_clean_regex.sub('', request.url.path)
        try:
            async with ClientSession() as session:
                async with session.post(f'{base_url}/{forward_path}',
                                        json=await request.json(),
                                        headers=local_llm_headers,
                                        timeout=ClientTimeout(connect=15, sock_read=30)) as out_request:
                    await self.glob_state.increment_counter('ai_requests')
                    return await out_request.json()
        except Exception as e:  # pylint: disable=broad-exception-caught
            logging.error("Error: %s", e)
            return {
                'err': True,
                'errorText': 'General Failure'
            }

    async def ai_handler(self, request: Request):
        """
        /ai/base/
        AI BASE Request
        (Requires key)
        """
        if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
            raise HTTPException(status_code=403, detail="Unauthorized")
        return await self._forward_to_local_llm(request, self.constants.LOCAL_LLM_BASE)

    async def ai_openai_handler(self, request: Request):
        """
        /ai/openai/
        AI Request
        (Requires key)
        """
        if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
            raise HTTPException(status_code=403, detail="Unauthorized")
        # TODO: Implement Claude. Currently only routes to local LLM.
        # NOTE: forwards to LOCAL_LLM_HOST (not LOCAL_LLM_BASE as /ai/base/ does).
        return await self._forward_to_local_llm(request, self.constants.LOCAL_LLM_HOST)

    async def ai_song_handler(self, data: ValidAISongRequest, request: Request):  # pylint: disable=unused-argument
        """
        /ai/song/
        AI (Song Info) Request [Public]
        """
        ai_prompt = "You are a helpful assistant who will provide tidbits of info on songs the user may listen to."
        ai_question = f"I am going to listen to the song \"{data.s}\" by \"{data.a}\"."

        claude_headers = {
            'x-api-key': self.constants.CLAUDE_API_KEY,
            'anthropic-version': '2023-06-01',
            'content-type': 'application/json',
        }
        request_data = {
            'model': 'claude-3-haiku-20240307',
            'max_tokens': 512,
            'temperature': 0.6,
            'system': ai_prompt,
            'messages': [
                {
                    "role": "user",
                    "content": ai_question.strip(),
                }
            ]
        }

        try:
            async with ClientSession() as session:
                # Named claude_response so it no longer shadows the injected
                # FastAPI `request` parameter.
                async with session.post('https://api.anthropic.com/v1/messages',
                                        json=request_data,
                                        headers=claude_headers,
                                        timeout=ClientTimeout(connect=15, sock_read=30)) as claude_response:
                    await self.glob_state.increment_counter('claude_ai_requests')
                    response = await claude_response.json()
            logging.debug("Response: %s", response)
            if response.get('type') == 'error':
                # Tolerate a missing/empty 'error' object instead of raising
                # AttributeError and collapsing into 'General Failure'.
                error = response.get('error') or {}
                error_type = error.get('type')
                error_message = error.get('message')
                return {
                    'resp': f"{error_type} error ({error_message})"
                }
            # Missing/empty 'content' raises here and is reported as
            # 'General Failure' by the handler below.
            return {
                'resp': response.get('content')[0].get('text').strip()
            }
        except Exception as e:  # pylint: disable=broad-exception-caught
            logging.error("Error: %s", e)
            return {
                'err': True,
                'errorText': 'General Failure'
            }
# async def ai_song_handler(self, data: ValidAISongRequest, request: Request):
# """
# /ai/song/
# AI (Song Info) Request [Public]
# """
# ai_question = f"I am going to listen to the song \"{data.s}\" by \"{data.a}\"."
# local_llm_headers = {
# 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
# }
# ai_req_data = {
# 'max_context_length': 8192,
2024-11-17 13:41:20 -05:00
# 'max_length': 512,
# 'temperature': 0,
# 'n': 1,
# 'top_k': 30,
# 'top_a': 0,
# 'top_p': 0,
# 'typical': 0,
# 'mirostat': 0,
# 'use_default_badwordsids': False,
# 'rep_pen': 1.0,
# 'rep_pen_range': 320,
# 'rep_pen_slope': 0.05,
2024-11-14 14:37:32 -05:00
# 'quiet': 1,
# 'bypass_eos': False,
2024-11-17 13:41:20 -05:00
# # 'trim_stop': True,
2024-11-14 14:37:32 -05:00
# 'sampler_order': [6,0,1,3,4,2,5],
2024-11-17 13:41:20 -05:00
# 'memory': "You are a helpful assistant who will provide ONLY TOTALLY ACCURATE tidbits of info on songs the user may listen to. You do not include information about which album a song was released on, or when it was released, and do not mention that you are not including this information in your response. If the input provided is not a song you are aware of, simply state that. Begin your output at your own response.",
2024-11-14 14:37:32 -05:00
# 'stop': ['### Inst', '### Resp'],
# 'prompt': ai_question
# }
# try:
# async with ClientSession() as session:
# async with await session.post(f'{self.constants.LOCAL_LLM_BASE}/generate',
# json=ai_req_data,
# headers=local_llm_headers,
# timeout=ClientTimeout(connect=15, sock_read=30)) as request:
# await self.glob_state.increment_counter('ai_requests')
# response = await request.json()
# result = {
# 'resp': response.get('results')[0].get('text').strip()
# }
# return result
# except Exception as e: # pylint: disable=broad-exception-caught
# logging.error("Error: %s", e)
# return {
# 'err': True,
# 'errorText': 'General Failure'
# }