This commit is contained in:
codey 2024-08-14 22:47:48 -04:00
parent ebc9460b8d
commit 56ae0071fa
3 changed files with 1 additions and 22 deletions

View File

@ -17,9 +17,6 @@ constants = importlib.import_module("constants").Constants()
util = importlib.import_module("util").Utilities(app, constants) util = importlib.import_module("util").Utilities(app, constants)
glob_state = importlib.import_module("state").State(app, util, constants) glob_state = importlib.import_module("state").State(app, util, constants)
api_key_query = APIKeyQuery(name=constants.API_KEY_NAME, auto_error=False)
api_key_header = APIKeyQuery(name=f"x-{constants.API_KEY_NAME}", auto_error=False)
origins = [ origins = [
"https://codey.lol", "https://codey.lol",
@ -40,7 +37,7 @@ Blacklisted routes
def disallow_get(): def disallow_get():
return util.get_blocked_response() return util.get_blocked_response()
@app.get("/{any}") @app.get("/{any:path}")
def disallow_get_any(var: Any = None): def disallow_get_any(var: Any = None):
return util.get_blocked_response() return util.get_blocked_response()

View File

@ -5,18 +5,11 @@ import logging
import regex import regex
from aiohttp import ClientSession, ClientTimeout from aiohttp import ClientSession, ClientTimeout
from fastapi import FastAPI, Security, Request, HTTPException from fastapi import FastAPI, Security, Request, HTTPException
from fastapi.security import APIKeyHeader, APIKeyQuery from fastapi.security import APIKeyHeader, APIKeyQuery
from pydantic import BaseModel from pydantic import BaseModel
api_key_header = APIKeyHeader(name="X-Authd-With")
class AI(FastAPI): class AI(FastAPI):
"""AI Endpoints""" """AI Endpoints"""
def __init__(self, app: FastAPI, my_util, constants, glob_state): # pylint: disable=super-init-not-called def __init__(self, app: FastAPI, my_util, constants, glob_state): # pylint: disable=super-init-not-called
@ -51,11 +44,7 @@ class AI(FastAPI):
local_llm_headers = { local_llm_headers = {
'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}' 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
} }
forward_path = self.url_clean_regex.sub('', request.url.path) forward_path = self.url_clean_regex.sub('', request.url.path)
print(f"Original path: {request.url.path}; Forward path: {forward_path}")
print(f"Request data: {await request.json()}")
try: try:
async with ClientSession() as session: async with ClientSession() as session:
async with await session.post(f'{self.constants.LOCAL_LLM_HOST}/{forward_path}', async with await session.post(f'{self.constants.LOCAL_LLM_HOST}/{forward_path}',
@ -64,7 +53,6 @@ class AI(FastAPI):
timeout=ClientTimeout(connect=15, sock_read=30)) as request: timeout=ClientTimeout(connect=15, sock_read=30)) as request:
await self.glob_state.increment_counter('ai_requests') await self.glob_state.increment_counter('ai_requests')
response = await request.json() response = await request.json()
print(f"Response received: {response}")
return response return response
except Exception as e: # pylint: disable=broad-exception-caught except Exception as e: # pylint: disable=broad-exception-caught
logging.error("Error: %s", e) logging.error("Error: %s", e)

View File

@ -3,19 +3,13 @@
import logging import logging
from fastapi import FastAPI, Response, HTTPException, Security from fastapi import FastAPI, Response, HTTPException, Security
from fastapi.security import APIKeyHeader, APIKeyQuery
global api_key_query
global api_key_header
class Utilities: class Utilities:
def __init__(self, app: FastAPI, constants): def __init__(self, app: FastAPI, constants):
self.constants = constants self.constants = constants
self.blocked_response_status = 422 self.blocked_response_status = 422
self.blocked_response_content = None self.blocked_response_content = None
self.api_key_query = APIKeyQuery(name=constants.API_KEY_NAME, auto_error=False)
self.api_key_header = APIKeyHeader(name=f"x-{constants.API_KEY_NAME}", auto_error=False)
def get_blocked_response(self, path: str | None = None): def get_blocked_response(self, path: str | None = None):
logging.error("Rejected request: Blocked") logging.error("Rejected request: Blocked")