api/endpoints/ai.py
2025-02-14 16:07:24 -05:00

191 lines
7.4 KiB
Python

#!/usr/bin/env python3.12
# pylint: disable=bare-except, broad-exception-caught, invalid-name
import logging
import traceback
import regex
from regex import Pattern
from typing import Union
from aiohttp import ClientSession, ClientTimeout
from fastapi import FastAPI, Request, HTTPException, BackgroundTasks
from .constructors import ValidHookSongRequest, ValidAISongRequest
class AI(FastAPI):
    """
    AI Endpoints.

    Registers (on the app passed to __init__):
      /ai/openai, /ai/base -> key-protected proxies to the local LLM
      /ai/song             -> public song-info request answered by Claude
      /ai/hook             -> same as /ai/song, but answered asynchronously via webhook
    """

    # Claude model used for song info; also reported in the webhook embed footer.
    CLAUDE_MODEL = 'claude-3-haiku-20240307'

    def __init__(self, app: FastAPI, my_util, constants):  # pylint: disable=super-init-not-called
        """
        Args:
            app: The FastAPI application the routes are registered on.
            my_util: Helper object; must provide ``check_key(path, key)``.
            constants: Config object providing LOCAL_LLM_KEY, LOCAL_LLM_BASE,
                LOCAL_LLM_HOST and CLAUDE_API_KEY.
        """
        self.app = app
        self.util = my_util
        self.constants = constants
        # Strips the /ai/openai/ or /ai/base/ prefix before forwarding upstream.
        self.url_clean_regex: Pattern = regex.compile(r'^\/ai\/(openai|base)\/')
        self.endpoints: dict = {
            "ai/openai": self.ai_openai_handler,
            "ai/base": self.ai_handler,
            "ai/song": self.ai_song_handler,
            "ai/hook": self.ai_hook_handler,
            #tbd
        }
        # Routes are attached to the supplied app, not to this instance.
        for endpoint, handler in self.endpoints.items():
            app.add_api_route(f"/{endpoint}", handler, methods=["GET", "POST"],
                              include_in_schema=False)

    async def respond_via_webhook(self, data: ValidHookSongRequest, originalRequest: Request) -> bool:
        """
        Respond via Webhook.

        Runs as a background task scheduled by ai_hook_handler. It asks
        ai_song_handler for song info, then POSTs the answer as a
        Discord-style embed to ``data.hook``. It is best-effort: failures are
        logged and reported as False instead of being raised.

        Args:
            data: Hook request; ``data.hook`` is the webhook URL, and the
                remaining fields are handed to ai_song_handler.
            originalRequest: The originating request, passed through to
                ai_song_handler.

        Returns:
            True if the webhook POST succeeded, otherwise False.
        """
        try:
            logging.debug("Request received: %s", data)
            if not data.hook:
                return False
            # The song handler does not need (or get) the webhook URL.
            song_request = data.copy()
            del song_request.hook
            song_result = await self.ai_song_handler(song_request, originalRequest)
            ai_text = song_result.get('resp')
            if not ai_text:
                # ai_song_handler signals failure with {'err': ...} and no 'resp'.
                logging.critical("NO RESP!")
                return False
            hook_data = {
                'username': 'Claude',
                "embeds": [{
                    "title": "Claude's Feedback",
                    "description": ai_text,
                    "footer": {
                        "text": f"Current model: {self.CLAUDE_MODEL}",
                    }
                }]
            }
            async with ClientSession() as session:
                async with session.post(data.hook, json=hook_data,
                                        timeout=ClientTimeout(connect=5, sock_read=5),
                                        headers={'content-type': 'application/json; charset=utf-8'},
                                        ) as hook_response:
                    hook_response.raise_for_status()
            return True
        except Exception:
            # Deliberately not a bare `except:`. asyncio.CancelledError derives
            # from BaseException and must propagate so the task can be cancelled.
            logging.exception("Webhook response failed")
            return False

    async def _forward_to_local_llm(self, request: Request, base_url: str):
        """
        POST the request's JSON body to the local LLM under ``base_url`` and
        return the parsed JSON reply. Shared by /ai/base and /ai/openai.

        Any failure (invalid request body, network error, non-JSON reply) is
        logged and returned as ``{'err': True, 'errorText': 'General Failure'}``.
        """
        local_llm_headers = {
            'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
        }
        # NOTE(review): the routes registered in __init__ are the exact paths
        # /ai/openai and /ai/base. The prefix regex requires a trailing '/', so
        # it does not match them as written. forward_path is then the whole path
        # (e.g. '/ai/base'), giving '<base_url>//ai/base'. Confirm the intended
        # upstream path.
        forward_path = self.url_clean_regex.sub('', request.url.path)
        try:
            async with ClientSession() as session:
                async with session.post(f'{base_url}/{forward_path}',
                                        json=await request.json(),
                                        headers=local_llm_headers,
                                        timeout=ClientTimeout(connect=15, sock_read=30)) as out_request:
                    return await out_request.json()
        except Exception as e:  # pylint: disable=broad-exception-caught
            logging.error("Error: %s", e)
            return {
                'err': True,
                'errorText': 'General Failure'
            }

    async def ai_handler(self, request: Request):
        """
        /ai/base
        AI BASE Request
        (Requires key)

        Proxies the JSON body to constants.LOCAL_LLM_BASE.

        Raises:
            HTTPException: 403 if the X-Authd-With key is rejected.
        """
        if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
            raise HTTPException(status_code=403, detail="Unauthorized")
        return await self._forward_to_local_llm(request, self.constants.LOCAL_LLM_BASE)

    async def ai_openai_handler(self, request: Request):
        """
        /ai/openai
        AI Request
        (Requires key)

        Proxies the JSON body to constants.LOCAL_LLM_HOST.

        Raises:
            HTTPException: 403 if the X-Authd-With key is rejected.
        """
        if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
            raise HTTPException(status_code=403, detail="Unauthorized")
        # TODO: Implement Claude. Currently this only routes to the local LLM.
        return await self._forward_to_local_llm(request, self.constants.LOCAL_LLM_HOST)

    async def ai_hook_handler(self, data: ValidHookSongRequest, request: Request, background_tasks: BackgroundTasks):
        """
        AI Hook Handler (/ai/hook).

        Schedules respond_via_webhook as a background task and acknowledges
        right away. The actual answer is delivered to ``data.hook``.
        """
        background_tasks.add_task(self.respond_via_webhook, data, request)
        return {
            'success': True,
        }

    async def ai_song_handler(self, data: Union[ValidAISongRequest, ValidHookSongRequest], request: Request):
        """
        /ai/song
        AI (Song Info) Request [Public]

        Asks Claude (Anthropic Messages API) for short notes on song
        ``data.s`` by ``data.a``. ``request`` is unused; it is kept for the
        route signature.

        Returns:
            ``{'resp': <text>}`` on success;
            ``{'resp': '<type> error (<message>)'}`` if the API returns an
            error object; ``{'err': True, 'errorText': 'General Failure'}``
            on any other failure.
        """
        ai_prompt = "You are a helpful assistant who will provide tidbits of info on songs the user may listen to."
        ai_question = f"I am going to listen to the song \"{data.s}\" by \"{data.a}\"."
        claude_headers = {
            'x-api-key': self.constants.CLAUDE_API_KEY,
            'anthropic-version': '2023-06-01',
            'content-type': 'application/json',
        }
        request_data = {
            'model': self.CLAUDE_MODEL,
            'max_tokens': 512,
            'temperature': 0.6,
            'system': ai_prompt,
            'messages': [
                {
                    "role": "user",
                    "content": ai_question.strip(),
                }
            ]
        }
        try:
            async with ClientSession() as session:
                async with session.post('https://api.anthropic.com/v1/messages',
                                        json=request_data,
                                        headers=claude_headers,
                                        timeout=ClientTimeout(connect=15, sock_read=30)) as aiohttp_request:
                    response = await aiohttp_request.json()
                    logging.debug("Response: %s",
                                  response)
                    if response.get('type') == 'error':
                        error_type = response.get('error').get('type')
                        error_message = response.get('error').get('message')
                        result = {
                            'resp': f"{error_type} error ({error_message})"
                        }
                    else:
                        # Successful replies carry a list of content blocks; use the first one's text.
                        result = {
                            'resp': response.get('content')[0].get('text').strip()
                        }
                    return result
        except Exception as e:  # pylint: disable=broad-exception-caught
            logging.error("Error: %s", e)
            return {
                'err': True,
                'errorText': 'General Failure'
            }