without local llm

This commit is contained in:
codey 2024-10-02 20:54:34 -04:00
parent f20a325a1f
commit ed5eb36ebb
5 changed files with 224 additions and 70 deletions

19
base.py
View File

@ -2,20 +2,30 @@
import importlib import importlib
import logging import logging
import asyncio
from typing import Any from typing import Any
from fastapi import FastAPI, WebSocket from fastapi import FastAPI, WebSocket
from fastapi.security import APIKeyHeader, APIKeyQuery from fastapi.security import APIKeyHeader, APIKeyQuery
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi_utils.tasks import repeat_every
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.CRITICAL) logger.setLevel(logging.DEBUG)
loop = asyncio.get_event_loop()
app = FastAPI(title="codey.lol API", app = FastAPI(title="codey.lol API",
version="0.1a", version="0.1a",
contact={ contact={
'name': 'codey' 'name': 'codey'
}) },
loop=loop)
app.loop = loop
constants = importlib.import_module("constants").Constants() constants = importlib.import_module("constants").Constants()
util = importlib.import_module("util").Utilities(app, constants) util = importlib.import_module("util").Utilities(app, constants)
@ -74,6 +84,11 @@ xc_endpoints = importlib.import_module("endpoints.xc").XC(app, util, constants,
# Below: CAH endpoint(s) # Below: CAH endpoint(s)
cah_endpoints = importlib.import_module("endpoints.cah").CAH(app, util, constants, glob_state) cah_endpoints = importlib.import_module("endpoints.cah").CAH(app, util, constants, glob_state)
@app.on_event("startup")
@repeat_every(seconds=10)
async def cah_tasks() -> None:
return await cah_endpoints.periodicals()
""" """

View File

@ -6,12 +6,14 @@ class CAHClient:
platform: str, platform: str,
csid: str, csid: str,
connected_at: int, connected_at: int,
players: list): players: list,
games: list):
self.resource: str = resource self.resource: str = resource
self.platform: str = platform self.platform: str = platform
self.csid: str = csid self.csid: str = csid
self.connected_at: int = connected_at self.connected_at: int = connected_at
self.players: list = players self.players: list = players
self.games: list = games
def __iter__(self): def __iter__(self):
return [value for value in self.__dict__.values() if isinstance(value, int) or isinstance(value, float)].__iter__() return [value for value in self.__dict__.values() if isinstance(value, int) or isinstance(value, float)].__iter__()

View File

@ -18,6 +18,14 @@ class ConnectionManager:
if connection.get('csid') == csid: if connection.get('csid') == csid:
return connection return connection
def get_connection_by_resource_label(self, resource: str):
for connection in self.active_connections:
try:
if connection.get('client').get('resource') == resource:
return connection
except:
continue
async def send_client_and_game_lists(self, state, websocket: WebSocket): async def send_client_and_game_lists(self, state, websocket: WebSocket):
clients = [] clients = []
games = [game.__dict__ for game in state.games] games = [game.__dict__ for game in state.games]
@ -85,6 +93,7 @@ class ConnectionManager:
disconnected = self.get_connection_by_ws(websocket) disconnected = self.get_connection_by_ws(websocket)
disconnected_client = disconnected.get('client') disconnected_client = disconnected.get('client')
disconnected_resource = disconnected_client.resource disconnected_resource = disconnected_client.resource
disconnected_games = [str(game.id) for game in disconnected_client.games]
await self.broadcast({ await self.broadcast({
"event": "client_disconnected", "event": "client_disconnected",
"ts": int(time.time()), "ts": int(time.time()),
@ -92,6 +101,7 @@ class ConnectionManager:
"disconnected_resource": disconnected_resource, "disconnected_resource": disconnected_resource,
} }
}) })
await state.remove_resource(disconnected_games, disconnected_resource)
self.active_connections.pop(websocket) self.active_connections.pop(websocket)

View File

@ -25,9 +25,10 @@ class AI(FastAPI):
self.util = my_util self.util = my_util
self.constants = constants self.constants = constants
self.glob_state = glob_state self.glob_state = glob_state
self.url_clean_regex = regex.compile(r'^\/ai\/openai\/') self.url_clean_regex = regex.compile(r'^\/ai\/(openai|base)\/')
self.endpoints = { self.endpoints = {
"ai/openai": self.ai_openai_handler, # "ai/openai": self.ai_openai_handler,
# "ai/base": self.ai_handler,
"ai/song": self.ai_song_handler "ai/song": self.ai_song_handler
#tbd #tbd
} }
@ -35,40 +36,70 @@ class AI(FastAPI):
for endpoint, handler in self.endpoints.items(): for endpoint, handler in self.endpoints.items():
app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"]) app.add_api_route(f"/{endpoint}/{{any:path}}", handler, methods=["POST"])
async def ai_openai_handler(self, request: Request): # async def ai_handler(self, request: Request):
""" # """
/ai/openai/ # /ai/base/
AI Request # AI BASE Request
""" # """
if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')): # if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
raise HTTPException(status_code=403, detail="Unauthorized") # raise HTTPException(status_code=403, detail="Unauthorized")
""" # local_llm_headers = {
TODO: Implement Claude # 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
Currently only routes to local LLM # }
""" # forward_path = self.url_clean_regex.sub('', request.url.path)
# try:
# async with ClientSession() as session:
# async with await session.post(f'{self.constants.LOCAL_LLM_BASE}/{forward_path}',
# json=await request.json(),
# headers=local_llm_headers,
# timeout=ClientTimeout(connect=15, sock_read=30)) as out_request:
# await self.glob_state.increment_counter('ai_requests')
# response = await out_request.json()
# return response
# except Exception as e: # pylint: disable=broad-exception-caught
# logging.error("Error: %s", e)
# return {
# 'err': True,
# 'errorText': 'General Failure'
# }
local_llm_headers = { # async def ai_openai_handler(self, request: Request):
'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}' # """
} # /ai/openai/
forward_path = self.url_clean_regex.sub('', request.url.path) # AI Request
try: # """
async with ClientSession() as session:
async with await session.post(f'{self.constants.LOCAL_LLM_HOST}/{forward_path}', # if not self.util.check_key(request.url.path, request.headers.get('X-Authd-With')):
json=await request.json(), # raise HTTPException(status_code=403, detail="Unauthorized")
headers=local_llm_headers,
timeout=ClientTimeout(connect=15, sock_read=30)) as request:
await self.glob_state.increment_counter('ai_requests') # """
response = await request.json() # TODO: Implement Claude
return response # Currently only routes to local LLM
except Exception as e: # pylint: disable=broad-exception-caught # """
logging.error("Error: %s", e)
return { # local_llm_headers = {
'err': True, # 'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}'
'errorText': 'General Failure' # }
} # forward_path = self.url_clean_regex.sub('', request.url.path)
# try:
# async with ClientSession() as session:
# async with await session.post(f'{self.constants.LOCAL_LLM_HOST}/{forward_path}',
# json=await request.json(),
# headers=local_llm_headers,
# timeout=ClientTimeout(connect=15, sock_read=30)) as out_request:
# await self.glob_state.increment_counter('ai_requests')
# response = await out_request.json()
# return response
# except Exception as e: # pylint: disable=broad-exception-caught
# logging.error("Error: %s", e)
# return {
# 'err': True,
# 'errorText': 'General Failure'
# }
async def ai_song_handler(self, data: ValidAISongRequest, request: Request): async def ai_song_handler(self, data: ValidAISongRequest, request: Request):
""" """
@ -76,33 +107,54 @@ class AI(FastAPI):
AI (Song Info) Request [Public] AI (Song Info) Request [Public]
""" """
ai_prompt = "You are a helpful assistant who will provide tidbits of info on songs the user may listen to."
ai_question = f"I am going to listen to the song \"{data.s}\" by \"{data.a}\"." ai_question = f"I am going to listen to the song \"{data.s}\" by \"{data.a}\"."
local_llm_headers = { local_llm_headers = {
'Authorization': f'Bearer {self.constants.LOCAL_LLM_KEY}' 'x-api-key': self.constants.LOCAL_LLM_KEY,
'anthropic-version': '2023-06-01',
'content-type': 'application/json',
} }
ai_req_data = {
'max_context_length': 16784, request_data = {
'max_length': 256, 'model': 'claude-3-haiku-20240307',
'temperature': 0.2, 'max_tokens': 512,
'quiet': 1, 'temperature': 0.6,
'bypass_eos': False, 'system': ai_prompt,
'trim_stop': True, 'messages': [
'sampler_order': [6,0,1,3,4,2,5], {
'memory': "You are a helpful assistant who will provide only totally accurate tidbits of info on songs the user may listen to. You do not include information about which album a song was released on, or when it was released, and do not mention that you are not including this information in your response. If the input provided is not a song you are aware of, simply state that.", "role": "user",
'stop': ['### Inst', '### Resp'], "content": ai_question.strip(),
'prompt': ai_question }
]
} }
# ai_req_data = {
# 'max_context_length': 16784,
# 'max_length': 256,
# 'temperature': 0.2,
# 'quiet': 1,
# 'bypass_eos': False,
# 'trim_stop': True,
# 'sampler_order': [6,0,1,3,4,2,5],
# 'memory': "You are a helpful assistant who will provide only totally accurate tidbits of info on songs the user may listen to. You do not include information about which album a song was released on, or when it was released, and do not mention that you are not including this information in your response. If the input provided is not a song you are aware of, simply state that.",
# 'stop': ['### Inst', '### Resp'],
# 'prompt': ai_question
# }
try: try:
async with ClientSession() as session: async with ClientSession() as session:
async with await session.post(f'{self.constants.LOCAL_LLM_BASE}/generate', async with await session.post('https://api.anthropic.com/v1/messages',
json=ai_req_data, json=request_data,
headers=local_llm_headers, headers=local_llm_headers,
timeout=ClientTimeout(connect=15, sock_read=30)) as request: timeout=ClientTimeout(connect=15, sock_read=30)) as request:
await self.glob_state.increment_counter('ai_requests') await self.glob_state.increment_counter('claude_ai_requests')
response = await request.json() response = await request.json()
print(f"Response: {response}")
result = { result = {
'resp': response.get('results')[0].get('text').strip() 'resp': response.get('content')[0].get('text').strip()
} }
return result return result
except Exception as e: # pylint: disable=broad-exception-caught except Exception as e: # pylint: disable=broad-exception-caught

View File

@ -1,10 +1,11 @@
#!/usr/bin/env python3.12 #!/usr/bin/env python3.12
from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks
from fastapi_utils.tasks import repeat_every from fastapi_utils.tasks import repeat_every
import time import time
import uuid import uuid
import json import json
import logging
import asyncio import asyncio
import traceback import traceback
import random import random
@ -13,6 +14,42 @@ from cah.websocket_conn import ConnectionManager
class CAH(FastAPI): class CAH(FastAPI):
"""CAH Endpoint(s)""" """CAH Endpoint(s)"""
# TASKS
async def send_heartbeats(self) -> None:
try:
while True:
logging.critical("Heartbeat!")
self.connection_manager.broadcast({
"event": "heartbeat",
"ts": int(time.time())
})
await asyncio.sleep(10)
except:
print(traceback.format_exc())
async def clean_stale_games(self) -> None:
try:
logging.critical("Looking for stale games...")
for game in self.games:
print(f"Checking {game}...")
if not game.players:
logging.critical(f"{game.id} seems pretty stale! {game.resources}")
self.games.remove(game)
return
except:
print(traceback.format_exc())
async def periodicals(self) -> None:
asyncio.get_event_loop().create_task(self.send_heartbeats())
asyncio.get_event_loop().create_task(self.clean_stale_games())
# END TASKS
def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called
self.app = app self.app = app
self.util = util self.util = util
@ -39,19 +76,9 @@ class CAH(FastAPI):
for endpoint, handler in self.endpoints.items(): for endpoint, handler in self.endpoints.items():
app.add_api_route(f"/{endpoint}/", handler, methods=["POST"]) app.add_api_route(f"/{endpoint}/", handler, methods=["POST"])
asyncio.get_event_loop().create_task(self.send_heartbeats())
async def send_heartbeats(self):
while True:
print("Heartbeat!")
await self.connection_manager.broadcast({
"event": "heartbeat",
"ts": int(time.time())
})
await asyncio.sleep(5)
# heartbeats = app.loop.create_task(self.send_heartbeats())
# asyncio.get_event_loop().run_until_complete(heartbeats)
async def remove_player(self, game: str, player: str): async def remove_player(self, game: str, player: str):
try: try:
@ -60,6 +87,8 @@ class CAH(FastAPI):
if __game.id == game: if __game.id == game:
_game = __game _game = __game
print(f"Got game!!!\n{_game}\nPlayers: {_game.players}") print(f"Got game!!!\n{_game}\nPlayers: {_game.players}")
other_players_still_share_resource = False
for idx, _player in enumerate(_game.players): for idx, _player in enumerate(_game.players):
if _player.get('handle') == player: if _player.get('handle') == player:
_game.players.pop(idx) _game.players.pop(idx)
@ -70,6 +99,39 @@ class CAH(FastAPI):
'player': _player 'player': _player
} }
}) # Change to broadcast to current game members only }) # Change to broadcast to current game members only
# else:
# if _player.get('related_resource') == _player.related_resource:
# other_players_still_share_resource = True
# if not other_players_still_share_resource:
# _game.resources.remove(_player.get('related_resource'))
except:
print(traceback.format_exc())
return {
'err': True,
'errorText': 'Server error'
}
async def remove_resource(self, games: list, resource: str):
try:
_game = None
for resource_game in games:
for __game in self.games:
if __game.id == resource_game:
_game = __game
print(f"Got game!!!\n{_game}\nResources: {_game.resources}")
for idx, _resource in enumerate(_game.resources):
if _resource == resource:
_game.resources.pop(idx)
await self.connection_manager.broadcast({
'event': 'resource_left',
'ts': int(time.time()),
'data': {
'resource': _resource,
}
}) # Change to broadcast to current game members only
resource_obj = self.connection_manager.get_connection_by_resource_label(resource)
# for player in resource_obj.players:
# await self.remove_player(player.current_game, player)
except: except:
print(traceback.format_exc()) print(traceback.format_exc())
return { return {
@ -104,13 +166,15 @@ class CAH(FastAPI):
'player': player.__dict__, 'player': player.__dict__,
} }
}) # Change to broadcast to current game members only }) # Change to broadcast to current game members only
if not player.related_resource in joined_game.resources:
joined_game.resources.append(player.related_resource)
return joined_game return joined_game
async def cah_handler(self, websocket: WebSocket): async def cah_handler(self, websocket: WebSocket):
"""/cah WebSocket""" """/cah WebSocket"""
await self.connection_manager.connect(websocket) await self.connection_manager.connect(websocket)
@ -198,6 +262,7 @@ class CAH(FastAPI):
await self.connection_manager.disconnect(self, websocket) await self.connection_manager.disconnect(self, websocket)
def get_game_by_id(self, _id: str): def get_game_by_id(self, _id: str):
for game in self.games: for game in self.games:
if game.id == _id: if game.id == _id:
@ -281,6 +346,7 @@ class CAH(FastAPI):
csid=csid, csid=csid,
connected_at=int(time.time()), connected_at=int(time.time()),
players=[], players=[],
games=[],
) )
await self.connection_manager.handshake_complete(self, websocket, csid, client) await self.connection_manager.handshake_complete(self, websocket, csid, client)
@ -304,7 +370,7 @@ class CAH(FastAPI):
"err": True, "err": True,
"errorText": "Unauthorized", "errorText": "Unauthorized",
}) })
if not data.get('rounds') or not str(data.get('rounds')).isnumeric(): if not data.get('rounds') or not str(data.get('rounds')).isnumeric() or not data.get('creator_handle'):
return await websocket.send_json({ return await websocket.send_json({
"event": "create_game_response", "event": "create_game_response",
"ts": int(time.time()), "ts": int(time.time()),
@ -327,10 +393,18 @@ class CAH(FastAPI):
client = self.connection_manager.get_connection_by_ws(websocket).get('client') client = self.connection_manager.get_connection_by_ws(websocket).get('client')
rounds = int(data.get('rounds')) rounds = int(data.get('rounds'))
game_uuid = str(uuid.uuid4()) game_uuid = str(uuid.uuid4())
creator_handle = data.get('creator_handle')
creator = CAHPlayer(id=str(uuid.uuid4()),
current_game=game_uuid,
platform=client.platform,
related_resource=client.resource,
joined_at=int(time.time()),
handle=creator_handle,
)
game = CAHGame(id=game_uuid, game = CAHGame(id=game_uuid,
rounds=rounds, rounds=rounds,
resources=[client.resource,], resources=[client.resource,],
players=[], players=[creator.__dict__,],
created_at=int(time.time()), created_at=int(time.time()),
state=-1, state=-1,
started_at=0, started_at=0,
@ -350,5 +424,6 @@ class CAH(FastAPI):
'game': game.__dict__, 'game': game.__dict__,
} }
}) })
client.games.append(game)
self.games.append(game) self.games.append(game)
await self.connection_manager.send_client_and_game_lists(self, websocket) await self.connection_manager.send_client_and_game_lists(self, websocket)