api/endpoints/ai.py
2025-02-16 08:50:53 -05:00

91 lines
3.4 KiB
Python

#!/usr/bin/env python3.12
import logging
import regex
from regex import Pattern
from typing import Union
from aiohttp import ClientSession, ClientTimeout
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
class AI(FastAPI):
    """AI Endpoints

    Registers the ``/ai/openai`` and ``/ai/base`` routes on the supplied
    application. Both handlers authenticate the caller, forward the request's
    JSON body to the local LLM backend and relay the backend's JSON reply.
    """

    def __init__(self, app: FastAPI,
                 my_util, constants):
        """
        Store collaborators and register the AI routes on ``app``.

        Args:
            app: Application the routes are added to.
            my_util: Helper object; its ``check_key(path, key)`` gates every
                request handled here.
            constants: Settings object; ``LOCAL_LLM_KEY``, ``LOCAL_LLM_BASE``
                and ``LOCAL_LLM_HOST`` are read from it at request time.
        """
        # NOTE(review): FastAPI.__init__ is never called, so this instance is
        # presumably only a container for route handlers and is not meant to
        # be served as an application itself -- confirm before relying on it.
        self.app: FastAPI = app
        self.util = my_util
        self.constants = constants
        # Strips a leading "/ai/openai/" or "/ai/base/" before forwarding.
        # NOTE(review): the pattern requires a trailing slash, while the routes
        # below are registered without one; for a path of exactly "/ai/openai"
        # the pattern does not match and the path is forwarded unchanged.
        # Confirm whether a "{path:path}" style route was intended.
        self.url_clean_regex: Pattern = regex.compile(r'^\/ai\/(openai|base)\/')
        self.endpoints: dict = {
            "ai/openai": self.ai_openai_handler,
            "ai/base": self.ai_handler,
            # tbd
        }

        for endpoint, handler in self.endpoints.items():
            app.add_api_route(f"/{endpoint}", handler, methods=["GET", "POST"],
                              include_in_schema=False)

    async def _forward_to_llm(self, request: Request,
                              base_constant: str) -> JSONResponse:
        """
        Authenticate ``request`` and proxy its JSON body to the local LLM.

        Args:
            request: Incoming request; its ``X-Authd-With`` header is checked
                with ``util.check_key`` and its JSON body is forwarded.
            base_constant: Name of the attribute on ``self.constants`` holding
                the upstream base URL. It is resolved inside the ``try`` so a
                missing setting yields the same generic 500 as any other
                upstream failure.

        Returns:
            The upstream JSON payload, or a 500 response with
            ``{'err': True, 'errorText': 'General Failure'}`` if anything
            fails while reading the body or talking to the upstream.

        Raises:
            HTTPException: 403 when the key check fails.
        """
        if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
            raise HTTPException(status_code=403, detail="Unauthorized")

        local_llm_headers = {
            'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
        }
        forward_path = self.url_clean_regex.sub('', request.url.path)
        try:
            upstream_base = getattr(self.constants, base_constant)
            async with ClientSession() as session:
                # session.post() is itself an async context manager; entering
                # it directly guarantees the response is released on exit.
                async with session.post(f'{upstream_base}/{forward_path}',
                                        json=await request.json(),
                                        headers=local_llm_headers,
                                        timeout=ClientTimeout(connect=15, sock_read=30)) as out_request:
                    response = await out_request.json()
                    return JSONResponse(content=response)
        except Exception as e:
            # Boundary handler: callers only ever see a generic failure, so
            # keep the traceback in the log for diagnosis.
            logging.exception("Error: %s", e)
            return JSONResponse(status_code=500, content={
                'err': True,
                'errorText': 'General Failure'
            })

    async def ai_handler(self, request: Request) -> JSONResponse:
        """
        /ai/base
        AI BASE Request
        (Requires key)

        Forwards the request to the upstream at ``constants.LOCAL_LLM_BASE``.
        """
        return await self._forward_to_llm(request, 'LOCAL_LLM_BASE')

    async def ai_openai_handler(self, request: Request) -> JSONResponse:
        """
        /ai/openai
        AI Request
        (Requires key)

        Forwards the request to the upstream at ``constants.LOCAL_LLM_HOST``.
        """
        # TODO: Implement Claude
        # Currently only routes to local LLM
        return await self._forward_to_llm(request, 'LOCAL_LLM_HOST')