Resolves #15, addl unrelated changes

This commit is contained in:
codey 2024-08-14 22:43:20 -04:00
parent 05e99718f8
commit ebc9460b8d
3 changed files with 110 additions and 7 deletions

13
base.py
View File

@ -5,16 +5,22 @@ import logging
from typing import Any from typing import Any
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.security import APIKeyHeader, APIKeyQuery
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
app = FastAPI() app = FastAPI()
util = importlib.import_module("util").Utilities()
constants = importlib.import_module("constants").Constants() constants = importlib.import_module("constants").Constants()
util = importlib.import_module("util").Utilities(app, constants)
glob_state = importlib.import_module("state").State(app, util, constants) glob_state = importlib.import_module("state").State(app, util, constants)
api_key_query = APIKeyQuery(name=constants.API_KEY_NAME, auto_error=False)
api_key_header = APIKeyQuery(name=f"x-{constants.API_KEY_NAME}", auto_error=False)
origins = [ origins = [
"https://codey.lol", "https://codey.lol",
] ]
@ -26,8 +32,6 @@ allow_methods=["POST"],
allow_headers=["*"]) allow_headers=["*"])
""" """
Blacklisted routes Blacklisted routes
""" """
@ -37,7 +41,7 @@ def disallow_get():
return util.get_blocked_response() return util.get_blocked_response()
@app.get("/{any}") @app.get("/{any}")
def disallow_get_any(var: Any): def disallow_get_any(var: Any = None):
return util.get_blocked_response() return util.get_blocked_response()
@app.post("/") @app.post("/")
@ -56,6 +60,7 @@ Actionable Routes
counter_endpoints = importlib.import_module("endpoints.counters").Counters(app, util, constants, glob_state) counter_endpoints = importlib.import_module("endpoints.counters").Counters(app, util, constants, glob_state)
randmsg_endpoint = importlib.import_module("endpoints.rand_msg").RandMsg(app, util, constants, glob_state) randmsg_endpoint = importlib.import_module("endpoints.rand_msg").RandMsg(app, util, constants, glob_state)
transcription_endpoints = importlib.import_module("endpoints.transcriptions").Transcriptions(app, util, constants, glob_state) transcription_endpoints = importlib.import_module("endpoints.transcriptions").Transcriptions(app, util, constants, glob_state)
ai_endpoints = importlib.import_module("endpoints.ai").AI(app, util, constants, glob_state)
# Below also provides: /lyric_cache_list/ (in addition to /lyric_search/) # Below also provides: /lyric_cache_list/ (in addition to /lyric_search/)
lyric_search_endpoint = importlib.import_module("endpoints.lyric_search").LyricSearch(app, util, constants, glob_state) lyric_search_endpoint = importlib.import_module("endpoints.lyric_search").LyricSearch(app, util, constants, glob_state)
# Below provides numerous LastFM-fed endpoints # Below provides numerous LastFM-fed endpoints

74
endpoints/ai.py Normal file
View File

@ -0,0 +1,74 @@
#!/usr/bin/env python3.12
import importlib
import logging
import regex
from aiohttp import ClientSession, ClientTimeout
from fastapi import FastAPI, Security, Request, HTTPException
from fastapi.security import APIKeyHeader, APIKeyQuery
from pydantic import BaseModel
api_key_header = APIKeyHeader(name="X-Authd-With")
class AI(FastAPI):
"""AI Endpoints"""
def __init__(self, app: FastAPI, my_util, constants, glob_state): # pylint: disable=super-init-not-called
self.app = app
self.util = my_util
self.constants = constants
self.glob_state = glob_state
self.url_clean_regex = regex.compile(r'^\/ai\/')
self.endpoints = {
"ai": self.ai_handler,
#tbd
}
for endpoint, handler in self.endpoints.items():
app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"])
async def ai_handler(self, request: Request):
"""
/ai/
AI Request
"""
if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
raise HTTPException(status_code=403, detail="Unauthorized")
"""
TODO: Implement Claude
Currently only routes to local LLM
"""
local_llm_headers = {
'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
}
forward_path = self.url_clean_regex.sub('', request.url.path)
print(f"Original path: {request.url.path}; Forward path: {forward_path}")
print(f"Request data: {await request.json()}")
try:
async with ClientSession() as session:
async with await session.post(f'{self.constants.LOCAL_LLM_HOST}/{forward_path}',
json=await request.json(),
headers=local_llm_headers,
timeout=ClientTimeout(connect=15, sock_read=30)) as request:
await self.glob_state.increment_counter('ai_requests')
response = await request.json()
print(f"Response received: {response}")
return response
except Exception as e: # pylint: disable=broad-exception-caught
logging.error("Error: %s", e)
return {
'err': True,
'errorText': 'General Failure'
}

30
util.py
View File

@ -2,13 +2,20 @@
import logging import logging
from fastapi import Response from fastapi import FastAPI, Response, HTTPException, Security
from fastapi.security import APIKeyHeader, APIKeyQuery
global api_key_query
global api_key_header
class Utilities: class Utilities:
def __init__(self): def __init__(self, app: FastAPI, constants):
self.constants = constants
self.blocked_response_status = 422 self.blocked_response_status = 422
self.blocked_response_content = None self.blocked_response_content = None
pass
self.api_key_query = APIKeyQuery(name=constants.API_KEY_NAME, auto_error=False)
self.api_key_header = APIKeyHeader(name=f"x-{constants.API_KEY_NAME}", auto_error=False)
def get_blocked_response(self, path: str | None = None): def get_blocked_response(self, path: str | None = None):
logging.error("Rejected request: Blocked") logging.error("Rejected request: Blocked")
@ -19,5 +26,22 @@ class Utilities:
logging.error("Rejected request: No such endpoint") logging.error("Rejected request: No such endpoint")
raise HTTPException(detail="Unknown endpoint", status_code=404) raise HTTPException(detail="Unknown endpoint", status_code=404)
def check_key(self, path: str, key: str):
"""
Accepts path as an argument to allow fine tuning access for each API key, not currently in use.
"""
print(f"Testing with path: {path}, key: {key}")
if not key or not key.startswith("Bearer "):
return False
key = key.split("Bearer ", maxsplit=1)[1].strip()
if not key in self.constants.API_KEYS:
print("Auth failed.")
return False
print("Auth succeeded")
return True