From ed5eb36ebb0ca316ead1719a4bb4ffdb463b47c8 Mon Sep 17 00:00:00 2001 From: codey Date: Wed, 2 Oct 2024 20:54:34 -0400 Subject: [PATCH] without local llm --- base.py | 19 +++++- cah/constructors.py | 4 +- cah/websocket_conn.py | 10 +++ endpoints/ai.py | 148 ++++++++++++++++++++++++++++-------------- endpoints/cah.py | 113 ++++++++++++++++++++++++++------ 5 files changed, 224 insertions(+), 70 deletions(-) diff --git a/base.py b/base.py index 6c98196..f48e4c4 100644 --- a/base.py +++ b/base.py @@ -2,20 +2,30 @@ import importlib import logging +import asyncio from typing import Any from fastapi import FastAPI, WebSocket from fastapi.security import APIKeyHeader, APIKeyQuery from fastapi.middleware.cors import CORSMiddleware +from fastapi_utils.tasks import repeat_every + logger = logging.getLogger() -logger.setLevel(logging.CRITICAL) +logger.setLevel(logging.DEBUG) + +loop = asyncio.get_event_loop() app = FastAPI(title="codey.lol API", version="0.1a", contact={ 'name': 'codey' - }) + }, + loop=loop) + +app.loop = loop + + constants = importlib.import_module("constants").Constants() util = importlib.import_module("util").Utilities(app, constants) @@ -74,6 +84,11 @@ xc_endpoints = importlib.import_module("endpoints.xc").XC(app, util, constants, # Below: CAH endpoint(s) cah_endpoints = importlib.import_module("endpoints.cah").CAH(app, util, constants, glob_state) +@app.on_event("startup") +@repeat_every(seconds=10) +async def cah_tasks() -> None: + return await cah_endpoints.periodicals() + """ diff --git a/cah/constructors.py b/cah/constructors.py index 4a08a0c..f7a3c74 100644 --- a/cah/constructors.py +++ b/cah/constructors.py @@ -6,12 +6,14 @@ class CAHClient: platform: str, csid: str, connected_at: int, - players: list): + players: list, + games: list): self.resource: str = resource self.platform: str = platform self.csid: str = csid self.connected_at: int = connected_at self.players: list = players + self.games: list = games def __iter__(self): return [value for value in self.__dict__.values() if isinstance(value, int) or isinstance(value, float)].__iter__() diff --git a/cah/websocket_conn.py b/cah/websocket_conn.py index 2a86a6c..7b1795a 100644 --- a/cah/websocket_conn.py +++ b/cah/websocket_conn.py @@ -18,6 +18,14 @@ class ConnectionManager: if connection.get('csid') == csid: return connection + def get_connection_by_resource_label(self, resource: str): + for connection in self.active_connections: + try: + if connection.get('client').get('resource') == resource: + return connection + except: + continue + async def send_client_and_game_lists(self, state, websocket: WebSocket): clients = [] games = [game.__dict__ for game in state.games] @@ -85,6 +93,7 @@ class ConnectionManager: disconnected = self.get_connection_by_ws(websocket) disconnected_client = disconnected.get('client') disconnected_resource = disconnected_client.resource + disconnected_games = [str(game.id) for game in disconnected_client.games] await self.broadcast({ "event": "client_disconnected", "ts": int(time.time()), @@ -92,6 +101,7 @@ class ConnectionManager: "disconnected_resource": disconnected_resource, } }) + await state.remove_resource(disconnected_games, disconnected_resource) self.active_connections.pop(websocket) diff --git a/endpoints/ai.py b/endpoints/ai.py index 033b7f5..d6641c6 100644 --- a/endpoints/ai.py +++ b/endpoints/ai.py @@ -25,50 +25,81 @@ class AI(FastAPI): self.util = my_util self.constants = constants self.glob_state = glob_state - self.url_clean_regex = regex.compile(r'^\/ai\/openai\/') + self.url_clean_regex = regex.compile(r'^\/ai\/(openai|base)\/') self.endpoints = { - "ai/openai": self.ai_openai_handler, + # "ai/openai": self.ai_openai_handler, + # "ai/base": self.ai_handler, "ai/song": self.ai_song_handler #tbd } for endpoint, handler in self.endpoints.items(): app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"]) + + # async def ai_handler(self, request: Request): + # """ + # /ai/base/ + # AI BASE Request + # """ + + # if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')): + # raise HTTPException(status_code=403, detail="Unauthorized") + + + # local_llm_headers = { + # 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}' + # } + # forward_path = self.url_clean_regex.sub('', request.url.path) + # try: + # async with ClientSession() as session: + # async with await session.post(f'{self.constants.LOCAL_LLM_BASE}/{forward_path}', + # json=await request.json(), + # headers=local_llm_headers, + # timeout=ClientTimeout(connect=15, sock_read=30)) as out_request: + # await self.glob_state.increment_counter('ai_requests') + # response = await out_request.json() + # return response + # except Exception as e: # pylint: disable=broad-exception-caught + # logging.error("Error: %s", e) + # return { + # 'err': True, + # 'errorText': 'General Failure' + # } - async def ai_openai_handler(self, request: Request): - """ - /ai/openai/ - AI Request - """ + # async def ai_openai_handler(self, request: Request): + # """ + # /ai/openai/ + # AI Request + # """ - if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')): - raise HTTPException(status_code=403, detail="Unauthorized") + # if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')): + # raise HTTPException(status_code=403, detail="Unauthorized") - """ - TODO: Implement Claude - Currently only routes to local LLM - """ + # """ + # TODO: Implement Claude + # Currently only routes to local LLM + # """ - local_llm_headers = { - 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}' - } - forward_path = self.url_clean_regex.sub('', request.url.path) - try: - async with ClientSession() as session: - async with await session.post(f'{self.constants.LOCAL_LLM_HOST}/{forward_path}', - json=await request.json(), - headers=local_llm_headers, - timeout=ClientTimeout(connect=15, sock_read=30)) as request: - await self.glob_state.increment_counter('ai_requests') - response = await request.json() - return response - except Exception as e: # pylint: disable=broad-exception-caught - logging.error("Error: %s", e) - return { - 'err': True, - 'errorText': 'General Failure' - } + # local_llm_headers = { + # 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}' + # } + # forward_path = self.url_clean_regex.sub('', request.url.path) + # try: + # async with ClientSession() as session: + # async with await session.post(f'{self.constants.LOCAL_LLM_HOST}/{forward_path}', + # json=await request.json(), + # headers=local_llm_headers, + # timeout=ClientTimeout(connect=15, sock_read=30)) as out_request: + # await self.glob_state.increment_counter('ai_requests') + # response = await out_request.json() + # return response + # except Exception as e: # pylint: disable=broad-exception-caught + # logging.error("Error: %s", e) + # return { + # 'err': True, + # 'errorText': 'General Failure' + # } async def ai_song_handler(self, data: ValidAISongRequest, request: Request): """ @@ -76,33 +107,54 @@ class AI(FastAPI): AI (Song Info) Request [Public] """ + ai_prompt = "You are a helpful assistant who will provide tidbits of info on songs the user may listen to." + ai_question = f"I am going to listen to the song \"{data.s}\" by \"{data.a}\"." + + local_llm_headers = { - 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}' + 'x-api-key': self.constants.LOCAL_LLM_KEY, + 'anthropic-version': '2023-06-01', + 'content-type': 'application/json', } - ai_req_data = { - 'max_context_length': 16784, - 'max_length': 256, - 'temperature': 0.2, - 'quiet': 1, - 'bypass_eos': False, - 'trim_stop': True, - 'sampler_order': [6,0,1,3,4,2,5], - 'memory': "You are a helpful assistant who will provide only totally accurate tidbits of info on songs the user may listen to. You do not include information about which album a song was released on, or when it was released, and do not mention that you are not including this information in your response. If the input provided is not a song you are aware of, simply state that.", - 'stop': ['### Inst', '### Resp'], - 'prompt': ai_question + + request_data = { + 'model': 'claude-3-haiku-20240307', + 'max_tokens': 512, + 'temperature': 0.6, + 'system': ai_prompt, + 'messages': [ + { + "role": "user", + "content": ai_question.strip(), + } + ] } + + # ai_req_data = { + # 'max_context_length': 16784, + # 'max_length': 256, + # 'temperature': 0.2, + # 'quiet': 1, + # 'bypass_eos': False, + # 'trim_stop': True, + # 'sampler_order': [6,0,1,3,4,2,5], + # 'memory': "You are a helpful assistant who will provide only totally accurate tidbits of info on songs the user may listen to. You do not include information about which album a song was released on, or when it was released, and do not mention that you are not including this information in your response. If the input provided is not a song you are aware of, simply state that.", + # 'stop': ['### Inst', '### Resp'], + # 'prompt': ai_question + # } try: async with ClientSession() as session: - async with await session.post(f'{self.constants.LOCAL_LLM_BASE}/generate', - json=ai_req_data, + async with await session.post('https://api.anthropic.com/v1/messages', + json=request_data, headers=local_llm_headers, timeout=ClientTimeout(connect=15, sock_read=30)) as request: - await self.glob_state.increment_counter('ai_requests') + await self.glob_state.increment_counter('claude_ai_requests') response = await request.json() + print(f"Response: {response}") result = { - 'resp': response.get('results')[0].get('text').strip() + 'resp': response.get('content')[0].get('text').strip() } return result except Exception as e: # pylint: disable=broad-exception-caught diff --git a/endpoints/cah.py b/endpoints/cah.py index 176a63a..fa1f45f 100644 --- a/endpoints/cah.py +++ b/endpoints/cah.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3.12 -from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks from fastapi_utils.tasks import repeat_every import time import uuid import json +import logging import asyncio import traceback import random @@ -12,7 +13,43 @@ from cah.constructors import CAHClient, CAHPlayer, CAHGame from cah.websocket_conn import ConnectionManager class CAH(FastAPI): - """CAH Endpoint(s)""" + """CAH Endpoint(s)""" + + + # TASKS + + async def send_heartbeats(self) -> None: + try: + while True: + logging.critical("Heartbeat!") + self.connection_manager.broadcast({ + "event": "heartbeat", + "ts": int(time.time()) + }) + await asyncio.sleep(10) + except: + print(traceback.format_exc()) + + + async def clean_stale_games(self) -> None: + try: + logging.critical("Looking for stale games...") + for game in self.games: + print(f"Checking {game}...") + if not game.players: + logging.critical(f"{game.id} seems pretty stale! {game.resources}") + self.games.remove(game) + return + except: + print(traceback.format_exc()) + + async def periodicals(self) -> None: + asyncio.get_event_loop().create_task(self.send_heartbeats()) + asyncio.get_event_loop().create_task(self.clean_stale_games()) + + + # END TASKS + def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called self.app = app self.util = util @@ -38,21 +75,11 @@ class CAH(FastAPI): for endpoint, handler in self.endpoints.items(): app.add_api_route(f"/{endpoint}/", handler, methods=["POST"]) - - asyncio.get_event_loop().create_task(self.send_heartbeats()) - - async def send_heartbeats(self): - while True: - print("Heartbeat!") - await self.connection_manager.broadcast({ - "event": "heartbeat", - "ts": int(time.time()) - }) - await asyncio.sleep(5) + # heartbeats = app.loop.create_task(self.send_heartbeats()) + # asyncio.get_event_loop().run_until_complete(heartbeats) - async def remove_player(self, game: str, player: str): try: _game = None @@ -60,6 +87,8 @@ class CAH(FastAPI): if __game.id == game: _game = __game print(f"Got game!!!\n{_game}\nPlayers: {_game.players}") + + other_players_still_share_resource = False for idx, _player in enumerate(_game.players): if _player.get('handle') == player: _game.players.pop(idx) @@ -70,6 +99,39 @@ class CAH(FastAPI): 'player': _player } }) # Change to broadcast to current game members only + # else: + # if _player.get('related_resource') == _player.related_resource: + # other_players_still_share_resource = True + # if not other_players_still_share_resource: + # _game.resources.remove(_player.get('related_resource')) + except: + print(traceback.format_exc()) + return { + 'err': True, + 'errorText': 'Server error' + } + + async def remove_resource(self, games: list, resource: str): + try: + _game = None + for resource_game in games: + for __game in self.games: + if __game.id == resource_game: + _game = __game + print(f"Got game!!!\n{_game}\nResources: {_game.resources}") + for idx, _resource in enumerate(_game.resources): + if _resource == resource: + _game.resources.pop(idx) + await self.connection_manager.broadcast({ + 'event': 'resource_left', + 'ts': int(time.time()), + 'data': { + 'resource': _resource, + } + }) # Change to broadcast to current game members only + resource_obj = self.connection_manager.get_connection_by_resource_label(resource) + # for player in resource_obj.players: + # await self.remove_player(player.current_game, player) except: print(traceback.format_exc()) return { @@ -104,10 +166,12 @@ class CAH(FastAPI): 'player': player.__dict__, } }) # Change to broadcast to current game members only + + if not player.related_resource in joined_game.resources: + joined_game.resources.append(player.related_resource) + return joined_game - - - + @@ -198,6 +262,7 @@ class CAH(FastAPI): await self.connection_manager.disconnect(self, websocket) + def get_game_by_id(self, _id: str): for game in self.games: if game.id == _id: @@ -281,6 +346,7 @@ class CAH(FastAPI): csid=csid, connected_at=int(time.time()), players=[], + games=[], ) await self.connection_manager.handshake_complete(self, websocket, csid, client) @@ -304,7 +370,7 @@ class CAH(FastAPI): "err": True, "errorText": "Unauthorized", }) - if not data.get('rounds') or not str(data.get('rounds')).isnumeric(): + if not data.get('rounds') or not str(data.get('rounds')).isnumeric() or not data.get('creator_handle'): return await websocket.send_json({ "event": "create_game_response", "ts": int(time.time()), @@ -327,10 +393,18 @@ class CAH(FastAPI): client = self.connection_manager.get_connection_by_ws(websocket).get('client') rounds = int(data.get('rounds')) game_uuid = str(uuid.uuid4()) + creator_handle = data.get('creator_handle') + creator = CAHPlayer(id=str(uuid.uuid4()), + current_game=game_uuid, + platform=client.platform, + related_resource=client.resource, + joined_at=int(time.time()), + handle=creator_handle, + ) game = CAHGame(id=game_uuid, rounds=rounds, resources=[client.resource,], - players=[], + players=[creator.__dict__,], created_at=int(time.time()), state=-1, started_at=0, @@ -350,5 +424,6 @@ class CAH(FastAPI): 'game': game.__dict__, } }) + client.games.append(game) self.games.append(game) await self.connection_manager.send_client_and_game_lists(self, websocket)