This commit is contained in:
2025-01-11 20:59:10 -05:00
parent 85a0d6bc62
commit 3c57f13557
18 changed files with 464 additions and 365 deletions

View File

@@ -1,25 +1,39 @@
#!/usr/bin/env python3.12
# pylint: disable=bare-except, broad-exception-caught, invalid-name
import importlib
import logging
import traceback
import regex
from aiohttp import ClientSession, ClientTimeout
from fastapi import FastAPI, Security, Request, HTTPException
from fastapi.security import APIKeyHeader, APIKeyQuery
from fastapi import FastAPI, Request, HTTPException, BackgroundTasks
from pydantic import BaseModel
class ValidAISongRequest(BaseModel):
"""
Request body for the public /ai/song/ endpoint (see ai_song_handler).

- **a**: artist
- **s**: track title
"""
a: str
s: str
class ValidHookSongRequest(BaseModel):
"""
Request body for the /ai/hook/ endpoint (see ai_hook_handler).

- **a**: artist
- **s**: track title
- **hook**: hook to return — the webhook URL that respond_via_webhook
  POSTs the generated feedback to
"""
a: str
s: str
# Webhook URL used by respond_via_webhook; optional, defaults to "".
hook: str | None = ""
# pylint: enable=bad-indentation
class AI(FastAPI):
"""AI Endpoints"""
"""AI Endpoints"""
def __init__(self, app: FastAPI, my_util, constants, glob_state): # pylint: disable=super-init-not-called
self.app = app
self.util = my_util
@@ -29,12 +43,48 @@ class AI(FastAPI):
self.endpoints = {
"ai/openai": self.ai_openai_handler,
"ai/base": self.ai_handler,
"ai/song": self.ai_song_handler
"ai/song": self.ai_song_handler,
"ai/hook": self.ai_hook_handler,
#tbd
}
}
for endpoint, handler in self.endpoints.items():
app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"])
app.add_api_route(f"/{endpoint}/openai/", handler, methods=["POST"])
async def respond_via_webhook(self, data: ValidHookSongRequest, originalRequest: Request):
    """
    Generate Claude song feedback and deliver it to the caller-supplied webhook.

    Runs as a background task scheduled by ai_hook_handler.

    Args:
        data: validated hook request carrying artist (a), title (s) and the
            destination webhook URL (hook).
        originalRequest: the incoming FastAPI Request, forwarded to
            ai_song_handler.

    Returns:
        True if the webhook POST succeeded, False on any failure
        (previously returned None on the no-response path; False keeps the
        result uniformly boolean — callers only use truthiness).
    """
    try:
        logging.debug("Request received: %s", data)
        data2 = data.copy()
        # ai_song_handler only reads .a / .s; drop the hook URL from the copy.
        del data2.hook
        response = await self.ai_song_handler(data2, originalRequest)
        if not response.get('resp'):
            logging.critical("NO RESP!")
            return False
        response = response.get('resp')
        hook_data = {
            'username': 'Claude',
            "embeds": [{
                "title": "Claude's Feedback",
                "description": response,
                "footer": {
                    "text": "Current model: claude-3-haiku-20240307",
                }
            }]
        }
        # Routine trace of the outbound payload — was logging.critical, which
        # mislabels a normal request as a critical event.
        logging.debug("Request: %s", data)
        async with ClientSession() as session:
            async with session.post(data.hook, json=hook_data,
                                    timeout=ClientTimeout(connect=5, sock_read=5),
                                    headers={
                                        'content-type': 'application/json; charset=utf-8',
                                    }) as request:
                # aiohttp's raise_for_status() is synchronous and returns None;
                # the original `await request.raise_for_status()` raised
                # TypeError. Also check status BEFORE consuming the body.
                request.raise_for_status()
                logging.debug("Returned: %s", await request.json())
        return True
    except Exception:  # pylint: disable=broad-exception-caught
        traceback.print_exc()
        return False
async def ai_handler(self, request: Request):
    """
    /ai/base/
    AI BASE Request
    (Requires key)

    Proxies the incoming JSON body to the local LLM backend
    (LOCAL_LLM_BASE) and returns the backend's JSON response verbatim.

    Raises:
        HTTPException: 403 when the X-Authd-With key fails check_key.

    Returns:
        The backend's JSON response, or {'err': True, 'errorText':
        'General Failure'} on any transport/parse error.
    """
    if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
        raise HTTPException(status_code=403, detail="Unauthorized")
    local_llm_headers = {
        'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
    }
    # Strip the routing prefix so the remainder maps onto the backend path.
    forward_path = self.url_clean_regex.sub('', request.url.path)
    try:
        # Diff residue had this whole request block duplicated, which would
        # have POSTed twice; deduplicated to a single forward.
        async with ClientSession() as session:
            async with session.post(f'{self.constants.LOCAL_LLM_BASE}/{forward_path}',
                                    json=await request.json(),
                                    headers=local_llm_headers,
                                    timeout=ClientTimeout(connect=15, sock_read=30)) as out_request:
                await self.glob_state.increment_counter('ai_requests')
                return await out_request.json()
    except Exception as e:  # pylint: disable=broad-exception-caught
        logging.error("Error: %s", e)
        return {
            'err': True,
            'errorText': 'General Failure'
        }
async def ai_openai_handler(self, request: Request):
    """
    /ai/openai/
    AI Request
    (Requires key)

    Proxies the incoming JSON body to the local LLM host and returns its
    JSON response verbatim.

    Raises:
        HTTPException: 403 when the X-Authd-With key fails check_key.

    Returns:
        The backend's JSON response, or {'err': True, 'errorText':
        'General Failure'} on any transport/parse error.
    """
    if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
        raise HTTPException(status_code=403, detail="Unauthorized")
    # TODO: Implement Claude — currently only routes to the local LLM.
    # (Was a bare triple-quoted string, i.e. a no-op expression statement,
    # not a comment.)
    local_llm_headers = {
        'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
    }
    forward_path = self.url_clean_regex.sub('', request.url.path)
    # NOTE(review): this handler targets LOCAL_LLM_HOST while ai_handler
    # targets LOCAL_LLM_BASE — confirm the asymmetry is intentional.
    try:
        # Deduplicated: diff residue repeated this request block twice.
        async with ClientSession() as session:
            async with session.post(f'{self.constants.LOCAL_LLM_HOST}/{forward_path}',
                                    json=await request.json(),
                                    headers=local_llm_headers,
                                    timeout=ClientTimeout(connect=15, sock_read=30)) as out_request:
                await self.glob_state.increment_counter('ai_requests')
                return await out_request.json()
    except Exception as e:  # pylint: disable=broad-exception-caught
        logging.error("Error: %s", e)
        return {
            'err': True,
            'errorText': 'General Failure'
        }
"""
CLAUDE BELOW, COMMENTED
"""
async def ai_hook_handler(self, data: ValidHookSongRequest, request: Request, background_tasks: BackgroundTasks):
    """
    /ai/hook/
    Queue Claude song feedback for asynchronous delivery to the webhook
    named in *data.hook*, then acknowledge immediately.

    The actual generation and delivery happen in respond_via_webhook,
    scheduled here as a FastAPI background task.
    """
    background_tasks.add_task(self.respond_via_webhook, data, request)
    ack = {
        'success': True,
    }
    return ack
async def ai_song_handler(self, data: ValidAISongRequest, request: Request):
    """
    /ai/song/
    AI (Song Info) Request [Public]

    Asks Claude (claude-3-haiku-20240307) for a short blurb about the song
    named in *data* via the Anthropic Messages API.

    Returns:
        {'resp': <text>} on success,
        {'resp': '<type> error (<message>)'} when the API returns an error
        object, or {'err': True, 'errorText': 'General Failure'} on any
        transport/parse failure.
    """
    ai_prompt = "You are a helpful assistant who will provide tidbits of info on songs the user may listen to."
    ai_question = f"I am going to listen to the song \"{data.s}\" by \"{data.a}\"."
    local_llm_headers = {
        'x-api-key': self.constants.CLAUDE_API_KEY,
        'anthropic-version': '2023-06-01',
        # NOTE(review): one header line here was elided by the diff view
        # (presumably 'content-type': 'application/json') — confirm against
        # the full file. aiohttp sets that header itself when json= is used.
    }
    request_data = {
        'model': 'claude-3-haiku-20240307',
        'max_tokens': 512,
        'temperature': 0.6,
        'system': ai_prompt,
        'messages': [
            {
                "role": "user",
                "content": ai_question.strip(),
            }
        ]
    }
    try:
        # Deduplicated: diff residue repeated both request_data and this
        # request block. Response variable renamed from `request`, which
        # shadowed the handler's Request parameter, to `out_request` for
        # consistency with the other handlers. The commented-out legacy
        # local-LLM implementation that followed has been removed (it
        # remains in version control history).
        async with ClientSession() as session:
            async with session.post('https://api.anthropic.com/v1/messages',
                                    json=request_data,
                                    headers=local_llm_headers,
                                    timeout=ClientTimeout(connect=15, sock_read=30)) as out_request:
                await self.glob_state.increment_counter('claude_ai_requests')
                response = await out_request.json()
                logging.debug("Response: %s", response)
                if response.get('type') == 'error':
                    error_type = response.get('error').get('type')
                    error_message = response.get('error').get('message')
                    return {
                        'resp': f"{error_type} error ({error_message})"
                    }
                return {
                    'resp': response.get('content')[0].get('text').strip()
                }
    except Exception as e:  # pylint: disable=broad-exception-caught
        logging.error("Error: %s", e)
        return {
            'err': True,
            'errorText': 'General Failure'
        }