diff --git a/base.py b/base.py index 1aaf3e1..d9cf1a3 100644 --- a/base.py +++ b/base.py @@ -5,16 +5,22 @@ import logging from typing import Any from fastapi import FastAPI +from fastapi.security import APIKeyHeader, APIKeyQuery from fastapi.middleware.cors import CORSMiddleware logger = logging.getLogger() logger.setLevel(logging.DEBUG) app = FastAPI() -util = importlib.import_module("util").Utilities() + constants = importlib.import_module("constants").Constants() +util = importlib.import_module("util").Utilities(app, constants) glob_state = importlib.import_module("state").State(app, util, constants) +api_key_query = APIKeyQuery(name=constants.API_KEY_NAME, auto_error=False) +api_key_header = APIKeyQuery(name=f"x-{constants.API_KEY_NAME}", auto_error=False) + + origins = [ "https://codey.lol", ] @@ -26,8 +32,6 @@ allow_methods=["POST"], allow_headers=["*"]) - - """ Blacklisted routes """ @@ -37,7 +41,7 @@ def disallow_get(): return util.get_blocked_response() @app.get("/{any}") -def disallow_get_any(var: Any): +def disallow_get_any(var: Any = None): return util.get_blocked_response() @app.post("/") @@ -56,6 +60,7 @@ Actionable Routes counter_endpoints = importlib.import_module("endpoints.counters").Counters(app, util, constants, glob_state) randmsg_endpoint = importlib.import_module("endpoints.rand_msg").RandMsg(app, util, constants, glob_state) transcription_endpoints = importlib.import_module("endpoints.transcriptions").Transcriptions(app, util, constants, glob_state) +ai_endpoints = importlib.import_module("endpoints.ai").AI(app, util, constants, glob_state) # Below also provides: /lyric_cache_list/ (in addition to /lyric_search/) lyric_search_endpoint = importlib.import_module("endpoints.lyric_search").LyricSearch(app, util, constants, glob_state) # Below provides numerous LastFM-fed endpoints diff --git a/endpoints/ai.py b/endpoints/ai.py new file mode 100644 index 0000000..174ed02 --- /dev/null +++ b/endpoints/ai.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3.12 + +import importlib +import logging +import regex + +from aiohttp import ClientSession, ClientTimeout + +from fastapi import FastAPI, Security, Request, HTTPException +from fastapi.security import APIKeyHeader, APIKeyQuery + + +from pydantic import BaseModel + + +api_key_header = APIKeyHeader(name="X-Authd-With") + + + +class AI(FastAPI): + """AI Endpoints""" + def __init__(self, app: FastAPI, my_util, constants, glob_state): # pylint: disable=super-init-not-called + self.app = app + self.util = my_util + self.constants = constants + self.glob_state = glob_state + self.url_clean_regex = regex.compile(r'^\/ai\/') + self.endpoints = { + "ai": self.ai_handler, + #tbd + } + + for endpoint, handler in self.endpoints.items(): + app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"]) + + async def ai_handler(self, request: Request): + """ + /ai/ + AI Request + """ + + if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')): + raise HTTPException(status_code=403, detail="Unauthorized") + + + """ + TODO: Implement Claude + Currently only routes to local LLM + """ + + local_llm_headers = { + 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}' + } + + forward_path = self.url_clean_regex.sub('', request.url.path) + + print(f"Original path: {request.url.path}; Forward path: {forward_path}") + print(f"Request data: {await request.json()}") + try: + async with ClientSession() as session: + async with await session.post(f'{self.constants.LOCAL_LLM_HOST}/{forward_path}', + json=await request.json(), + headers=local_llm_headers, + timeout=ClientTimeout(connect=15, sock_read=30)) as request: + await self.glob_state.increment_counter('ai_requests') + response = await request.json() + print(f"Response received: {response}") + return response + except Exception as e: # pylint: disable=broad-exception-caught + logging.error("Error: %s", e) + return { + 'err': True, + 'errorText': 'General Failure' + } \ No newline at end of file diff --git a/util.py b/util.py index 50b1b58..1e368f4 100644 --- a/util.py +++ b/util.py @@ -2,13 +2,20 @@ import logging -from fastapi import Response +from fastapi import FastAPI, Response, HTTPException, Security +from fastapi.security import APIKeyHeader, APIKeyQuery + +global api_key_query +global api_key_header class Utilities: - def __init__(self): + def __init__(self, app: FastAPI, constants): + self.constants = constants self.blocked_response_status = 422 self.blocked_response_content = None - pass + + self.api_key_query = APIKeyQuery(name=constants.API_KEY_NAME, auto_error=False) + self.api_key_header = APIKeyHeader(name=f"x-{constants.API_KEY_NAME}", auto_error=False) def get_blocked_response(self, path: str | None = None): logging.error("Rejected request: Blocked") @@ -19,5 +26,22 @@ class Utilities: logging.error("Rejected request: No such endpoint") raise HTTPException(detail="Unknown endpoint", status_code=404) + def check_key(self, path: str, key: str): + """ + Accepts path as an argument to allow fine tuning access for each API key, not currently in use. + """ + print(f"Testing with path: {path}, key: {key}") + + if not key or not key.startswith("Bearer "): + return False + + key = key.split("Bearer ", maxsplit=1)[1].strip() + + if not key in self.constants.API_KEYS: + print("Auth failed.") + return False + + print("Auth succeeded") + return True