without local llm

This commit is contained in:
codey 2024-10-02 20:54:34 -04:00
parent f20a325a1f
commit ed5eb36ebb
5 changed files with 224 additions and 70 deletions

19
base.py
View File

@ -2,20 +2,30 @@
import importlib
import logging
import asyncio
from typing import Any
from fastapi import FastAPI, WebSocket
from fastapi.security import APIKeyHeader, APIKeyQuery
from fastapi.middleware.cors import CORSMiddleware
from fastapi_utils.tasks import repeat_every
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)
logger.setLevel(logging.DEBUG)
loop = asyncio.get_event_loop()
app = FastAPI(title="codey.lol API",
version="0.1a",
contact={
'name': 'codey'
})
},
loop=loop)
app.loop = loop
constants = importlib.import_module("constants").Constants()
util = importlib.import_module("util").Utilities(app, constants)
@ -74,6 +84,11 @@ xc_endpoints = importlib.import_module("endpoints.xc").XC(app, util, constants,
# Below: CAH endpoint(s)
cah_endpoints = importlib.import_module("endpoints.cah").CAH(app, util, constants, glob_state)
@app.on_event("startup")
@repeat_every(seconds=10)
async def cah_tasks() -> None:
return await cah_endpoints.periodicals()
"""

View File

@ -6,12 +6,14 @@ class CAHClient:
platform: str,
csid: str,
connected_at: int,
players: list):
players: list,
games: list):
self.resource: str = resource
self.platform: str = platform
self.csid: str = csid
self.connected_at: int = connected_at
self.players: list = players
self.games: list = games
def __iter__(self):
return [value for value in self.__dict__.values() if isinstance(value, int) or isinstance(value, float)].__iter__()

View File

@ -18,6 +18,14 @@ class ConnectionManager:
if connection.get('csid') == csid:
return connection
def get_connection_by_resource_label(self, resource: str):
for connection in self.active_connections:
try:
if connection.get('client').get('resource') == resource:
return connection
except:
continue
async def send_client_and_game_lists(self, state, websocket: WebSocket):
clients = []
games = [game.__dict__ for game in state.games]
@ -85,6 +93,7 @@ class ConnectionManager:
disconnected = self.get_connection_by_ws(websocket)
disconnected_client = disconnected.get('client')
disconnected_resource = disconnected_client.resource
disconnected_games = [str(game.id) for game in disconnected_client.games]
await self.broadcast({
"event": "client_disconnected",
"ts": int(time.time()),
@ -92,6 +101,7 @@ class ConnectionManager:
"disconnected_resource": disconnected_resource,
}
})
await state.remove_resource(disconnected_games, disconnected_resource)
self.active_connections.pop(websocket)

View File

@ -25,9 +25,10 @@ class AI(FastAPI):
self.util = my_util
self.constants = constants
self.glob_state = glob_state
self.url_clean_regex = regex.compile(r'^\/ai\/openai\/')
self.url_clean_regex = regex.compile(r'^\/ai\/(openai|base)\/')
self.endpoints = {
"ai/openai": self.ai_openai_handler,
# "ai/openai": self.ai_openai_handler,
# "ai/base": self.ai_handler,
"ai/song": self.ai_song_handler
#tbd
}
@ -35,40 +36,70 @@ class AI(FastAPI):
for endpoint, handler in self.endpoints.items():
app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"])
async def ai_openai_handler(self, request: Request):
"""
/ai/openai/
AI Request
"""
# async def ai_handler(self, request: Request):
# """
# /ai/base/
# AI BASE Request
# """
if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
raise HTTPException(status_code=403, detail="Unauthorized")
# if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
# raise HTTPException(status_code=403, detail="Unauthorized")
"""
TODO: Implement Claude
Currently only routes to local LLM
"""
# local_llm_headers = {
# 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
# }
# forward_path = self.url_clean_regex.sub('', request.url.path)
# try:
# async with ClientSession() as session:
# async with await session.post(f'{self.constants.LOCAL_LLM_BASE}/{forward_path}',
# json=await request.json(),
# headers=local_llm_headers,
# timeout=ClientTimeout(connect=15, sock_read=30)) as out_request:
# await self.glob_state.increment_counter('ai_requests')
# response = await out_request.json()
# return response
# except Exception as e: # pylint: disable=broad-exception-caught
# logging.error("Error: %s", e)
# return {
# 'err': True,
# 'errorText': 'General Failure'
# }
local_llm_headers = {
'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
}
forward_path = self.url_clean_regex.sub('', request.url.path)
try:
async with ClientSession() as session:
async with await session.post(f'{self.constants.LOCAL_LLM_HOST}/{forward_path}',
json=await request.json(),
headers=local_llm_headers,
timeout=ClientTimeout(connect=15, sock_read=30)) as request:
await self.glob_state.increment_counter('ai_requests')
response = await request.json()
return response
except Exception as e: # pylint: disable=broad-exception-caught
logging.error("Error: %s", e)
return {
'err': True,
'errorText': 'General Failure'
}
# async def ai_openai_handler(self, request: Request):
# """
# /ai/openai/
# AI Request
# """
# if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
# raise HTTPException(status_code=403, detail="Unauthorized")
# """
# TODO: Implement Claude
# Currently only routes to local LLM
# """
# local_llm_headers = {
# 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
# }
# forward_path = self.url_clean_regex.sub('', request.url.path)
# try:
# async with ClientSession() as session:
# async with await session.post(f'{self.constants.LOCAL_LLM_HOST}/{forward_path}',
# json=await request.json(),
# headers=local_llm_headers,
# timeout=ClientTimeout(connect=15, sock_read=30)) as out_request:
# await self.glob_state.increment_counter('ai_requests')
# response = await out_request.json()
# return response
# except Exception as e: # pylint: disable=broad-exception-caught
# logging.error("Error: %s", e)
# return {
# 'err': True,
# 'errorText': 'General Failure'
# }
async def ai_song_handler(self, data: ValidAISongRequest, request: Request):
"""
@ -76,33 +107,54 @@ class AI(FastAPI):
AI (Song Info) Request [Public]
"""
ai_prompt = "You are a helpful assistant who will provide tidbits of info on songs the user may listen to."
ai_question = f"I am going to listen to the song \"{data.s}\" by \"{data.a}\"."
local_llm_headers = {
'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
'x-api-key': self.constants.LOCAL_LLM_KEY,
'anthropic-version': '2023-06-01',
'content-type': 'application/json',
}
ai_req_data = {
'max_context_length': 16784,
'max_length': 256,
'temperature': 0.2,
'quiet': 1,
'bypass_eos': False,
'trim_stop': True,
'sampler_order': [6,0,1,3,4,2,5],
'memory': "You are a helpful assistant who will provide only totally accurate tidbits of info on songs the user may listen to. You do not include information about which album a song was released on, or when it was released, and do not mention that you are not including this information in your response. If the input provided is not a song you are aware of, simply state that.",
'stop': ['### Inst', '### Resp'],
'prompt': ai_question
request_data = {
'model': 'claude-3-haiku-20240307',
'max_tokens': 512,
'temperature': 0.6,
'system': ai_prompt,
'messages': [
{
"role": "user",
"content": ai_question.strip(),
}
]
}
# ai_req_data = {
# 'max_context_length': 16784,
# 'max_length': 256,
# 'temperature': 0.2,
# 'quiet': 1,
# 'bypass_eos': False,
# 'trim_stop': True,
# 'sampler_order': [6,0,1,3,4,2,5],
# 'memory': "You are a helpful assistant who will provide only totally accurate tidbits of info on songs the user may listen to. You do not include information about which album a song was released on, or when it was released, and do not mention that you are not including this information in your response. If the input provided is not a song you are aware of, simply state that.",
# 'stop': ['### Inst', '### Resp'],
# 'prompt': ai_question
# }
try:
async with ClientSession() as session:
async with await session.post(f'{self.constants.LOCAL_LLM_BASE}/generate',
json=ai_req_data,
async with await session.post('https://api.anthropic.com/v1/messages',
json=request_data,
headers=local_llm_headers,
timeout=ClientTimeout(connect=15, sock_read=30)) as request:
await self.glob_state.increment_counter('ai_requests')
await self.glob_state.increment_counter('claude_ai_requests')
response = await request.json()
print(f"Response: {response}")
result = {
'resp': response.get('results')[0].get('text').strip()
'resp': response.get('content')[0].get('text').strip()
}
return result
except Exception as e: # pylint: disable=broad-exception-caught

View File

@ -1,10 +1,11 @@
#!/usr/bin/env python3.12
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks
from fastapi_utils.tasks import repeat_every
import time
import uuid
import json
import logging
import asyncio
import traceback
import random
@ -13,6 +14,42 @@ from cah.websocket_conn import ConnectionManager
class CAH(FastAPI):
"""CAH Endpoint(s)"""
# TASKS
async def send_heartbeats(self) -> None:
try:
while True:
logging.critical("Heartbeat!")
self.connection_manager.broadcast({
"event": "heartbeat",
"ts": int(time.time())
})
await asyncio.sleep(10)
except:
print(traceback.format_exc())
async def clean_stale_games(self) -> None:
try:
logging.critical("Looking for stale games...")
for game in self.games:
print(f"Checking {game}...")
if not game.players:
logging.critical(f"{game.id} seems pretty stale! {game.resources}")
self.games.remove(game)
return
except:
print(traceback.format_exc())
async def periodicals(self) -> None:
asyncio.get_event_loop().create_task(self.send_heartbeats())
asyncio.get_event_loop().create_task(self.clean_stale_games())
# END TASKS
def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called
self.app = app
self.util = util
@ -39,19 +76,9 @@ class CAH(FastAPI):
for endpoint, handler in self.endpoints.items():
app.add_api_route(f"/{endpoint}/", handler, methods=["POST"])
asyncio.get_event_loop().create_task(self.send_heartbeats())
async def send_heartbeats(self):
while True:
print("Heartbeat!")
await self.connection_manager.broadcast({
"event": "heartbeat",
"ts": int(time.time())
})
await asyncio.sleep(5)
# heartbeats = app.loop.create_task(self.send_heartbeats())
# asyncio.get_event_loop().run_until_complete(heartbeats)
async def remove_player(self, game: str, player: str):
try:
@ -60,6 +87,8 @@ class CAH(FastAPI):
if __game.id == game:
_game = __game
print(f"Got game!!!\n{_game}\nPlayers: {_game.players}")
other_players_still_share_resource = False
for idx, _player in enumerate(_game.players):
if _player.get('handle') == player:
_game.players.pop(idx)
@ -70,6 +99,39 @@ class CAH(FastAPI):
'player': _player
}
}) # Change to broadcast to current game members only
# else:
# if _player.get('related_resource') == _player.related_resource:
# other_players_still_share_resource = True
# if not other_players_still_share_resource:
# _game.resources.remove(_player.get('related_resource'))
except:
print(traceback.format_exc())
return {
'err': True,
'errorText': 'Server error'
}
async def remove_resource(self, games: list, resource: str):
try:
_game = None
for resource_game in games:
for __game in self.games:
if __game.id == resource_game:
_game = __game
print(f"Got game!!!\n{_game}\nResources: {_game.resources}")
for idx, _resource in enumerate(_game.resources):
if _resource == resource:
_game.resources.pop(idx)
await self.connection_manager.broadcast({
'event': 'resource_left',
'ts': int(time.time()),
'data': {
'resource': _resource,
}
}) # Change to broadcast to current game members only
resource_obj = self.connection_manager.get_connection_by_resource_label(resource)
# for player in resource_obj.players:
# await self.remove_player(player.current_game, player)
except:
print(traceback.format_exc())
return {
@ -104,13 +166,15 @@ class CAH(FastAPI):
'player': player.__dict__,
}
}) # Change to broadcast to current game members only
if not player.related_resource in joined_game.resources:
joined_game.resources.append(player.related_resource)
return joined_game
async def cah_handler(self, websocket: WebSocket):
"""/cah WebSocket"""
await self.connection_manager.connect(websocket)
@ -198,6 +262,7 @@ class CAH(FastAPI):
await self.connection_manager.disconnect(self, websocket)
def get_game_by_id(self, _id: str):
for game in self.games:
if game.id == _id:
@ -281,6 +346,7 @@ class CAH(FastAPI):
csid=csid,
connected_at=int(time.time()),
players=[],
games=[],
)
await self.connection_manager.handshake_complete(self, websocket, csid, client)
@ -304,7 +370,7 @@ class CAH(FastAPI):
"err": True,
"errorText": "Unauthorized",
})
if not data.get('rounds') or not str(data.get('rounds')).isnumeric():
if not data.get('rounds') or not str(data.get('rounds')).isnumeric() or not data.get('creator_handle'):
return await websocket.send_json({
"event": "create_game_response",
"ts": int(time.time()),
@ -327,10 +393,18 @@ class CAH(FastAPI):
client = self.connection_manager.get_connection_by_ws(websocket).get('client')
rounds = int(data.get('rounds'))
game_uuid = str(uuid.uuid4())
creator_handle = data.get('creator_handle')
creator = CAHPlayer(id=str(uuid.uuid4()),
current_game=game_uuid,
platform=client.platform,
related_resource=client.resource,
joined_at=int(time.time()),
handle=creator_handle,
)
game = CAHGame(id=game_uuid,
rounds=rounds,
resources=[client.resource,],
players=[],
players=[creator.__dict__,],
created_at=int(time.time()),
state=-1,
started_at=0,
@ -350,5 +424,6 @@ class CAH(FastAPI):
'game': game.__dict__,
}
})
client.games.append(game)
self.games.append(game)
await self.connection_manager.send_client_and_game_lists(self, websocket)