From 5d68d132ae18e4a76c81b8e7be98eb8ec6b5acc4 Mon Sep 17 00:00:00 2001 From: codey Date: Tue, 17 Sep 2024 23:07:16 -0400 Subject: [PATCH] misc / begin CAH work, much cleanup to do / conceptualizing so far only --- base.py | 6 ++- cah/constructors.py | 12 +++++ cah/websocket_conn.py | 55 +++++++++++++++++++++++ endpoints/cah.py | 102 ++++++++++++++++++++++++++++++++++++++++++ endpoints/counters.py | 1 - endpoints/ws_test.py | 34 ++++++++++++++ 6 files changed, 207 insertions(+), 3 deletions(-) create mode 100644 cah/constructors.py create mode 100644 cah/websocket_conn.py create mode 100644 endpoints/cah.py create mode 100644 endpoints/ws_test.py diff --git a/base.py b/base.py index 40c04fa..3077414 100644 --- a/base.py +++ b/base.py @@ -4,7 +4,7 @@ import importlib import logging from typing import Any -from fastapi import FastAPI +from fastapi import FastAPI, WebSocket from fastapi.security import APIKeyHeader, APIKeyQuery from fastapi.middleware.cors import CORSMiddleware @@ -33,6 +33,7 @@ allow_methods=["POST"], allow_headers=["*"]) + """ Blacklisted routes """ @@ -70,7 +71,8 @@ lastfm_endpoints = importlib.import_module("endpoints.lastfm").LastFM(app, util, yt_endpoints = importlib.import_module("endpoints.yt").YT(app, util, constants, glob_state) # Below: XC endpoint(s) xc_endpoints = importlib.import_module("endpoints.xc").XC(app, util, constants, glob_state) - +# Below: CAH endpoint(s) +cah_endpoints = importlib.import_module("endpoints.cah").CAH(app, util, constants, glob_state) diff --git a/cah/constructors.py b/cah/constructors.py new file mode 100644 index 0000000..8511acd --- /dev/null +++ b/cah/constructors.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3.12 + +class CAHClient: + def __init__(self, + resource: str, + platform: str, + csid: str, + connected_at: int): + self.resource: str = resource + self.platform: str = platform + self.csid: str = csid + self.connected_at: int = connected_at \ No newline at end of file diff --git a/cah/websocket_conn.py b/cah/websocket_conn.py new file mode 100644 index 0000000..2303837 --- /dev/null +++ b/cah/websocket_conn.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3.12 + +import json +import time + +from fastapi import WebSocket +from cah.constructors import CAHClient + +class ConnectionManager: + def __init__(self): + self.active_connections: dict = {} + + def get_connection_by_ws(self, websocket: WebSocket) -> WebSocket: + return self.active_connections.get(websocket) + + def get_connection_by_csid(self, csid: str) -> WebSocket: + for connection in self.active_connections: + if connection.get('csid') == csid: + return connection + + async def connect(self, websocket: WebSocket): + await websocket.accept() + self.active_connections[websocket] = { + 'client': None, + 'websocket': websocket, + } + + + async def handshake_complete(self, + websocket: WebSocket, + csid: str, + handshakedClient: CAHClient): + self.active_connections[websocket] = { + 'websocket': websocket, + 'csid': csid, + 'client': handshakedClient, + } + + await self.broadcast({ + "event": "client_connected", + "ts": str(time.time()), + "data": { + "connected_resource": handshakedClient.resource, + } + }) + + def disconnect(self, websocket: WebSocket, csid: str = None): + self.active_connections.pop(websocket) + + async def send(self, message: str, websocket: WebSocket): + await websocket.send_json(message) + + async def broadcast(self, message: str): + for connection in self.active_connections: + await connection.send_json(message) \ No newline at end of file diff --git a/endpoints/cah.py b/endpoints/cah.py new file mode 100644 index 0000000..7118795 --- /dev/null +++ b/endpoints/cah.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3.12 +from fastapi import FastAPI, WebSocket, WebSocketDisconnect + +import time +import uuid +import json +import random +from cah.constructors import CAHClient +from cah.websocket_conn import ConnectionManager + +class CAH(FastAPI): + """CAH Endpoint(s)""" + def __init__(self, app: FastAPI, util, constants, glob_state): # pylint: disable=super-init-not-called + self.app = app + self.util = util + self.constants = constants + self.glob_state = glob_state + + self.ws_endpoints = { + "cah": self.cah_handler, + #tbd + } + + self.endpoints = { + #tbd if any non-WS endpoints + } + + self.connection_manager = ConnectionManager() + + + for endpoint, handler in self.ws_endpoints.items(): + print(f"Adding websocket route: {endpoint} @ {handler}") + app.add_api_websocket_route(f"/{endpoint}/", handler) + + for endpoint, handler in self.endpoints.items(): + app.add_api_route(f"/{endpoint}/", handler, methods=["POST"]) + + async def cah_handler(self, websocket: WebSocket): + """/cah WebSocket""" + await self.connection_manager.connect(websocket) + await websocket.send_json({ + "event": "connection_established", + "ts": int(time.time()), + }) + + try: + while True: + data = await websocket.receive_json() + if data.get('event') == 'handshake': + await self.cah_handshake(websocket, + data) + except WebSocketDisconnect: + disconnected = self.connection_manager.get_connection_by_ws(websocket) + self.connection_manager.disconnect(websocket) + await self.connection_manager.broadcast({ + "event": "client_disconnected", + "ts": time.time(), + "data": { + "disconnected_resource": disconnected.get('client').resource, + } + }) + + + + async def cah_handshake(self, websocket: WebSocket, data): + """Handshake""" + self.connection_manager.connect(websocket) + data = data.get('data') + if not data: + await websocket.send_json({ + "err": "WTF", + }) + return await websocket.close() + + csid = str(data.get('csid')) + resource = data.get('resource') + platform = data.get('platform') + + if not csid in self.constants.VALID_CSIDS: + await websocket.send_json({ + "err": "Unauthorized", + }) + return await websocket.close() + + client = CAHClient( + resource=resource, + platform=platform, + csid=csid, + connected_at=time.time(), # fix + ) + await self.connection_manager.handshake_complete(websocket, csid, client) + + await websocket.send_json({ + "event": "handshake_response", + "ts": int(time.time()), + "data": { + "success": True, + "resource": resource, + "platform": platform, + "games": [str(uuid.uuid4()) for x in range(0, 11)], + }, + }) diff --git a/endpoints/counters.py b/endpoints/counters.py index fafafc0..bd8c0a2 100644 --- a/endpoints/counters.py +++ b/endpoints/counters.py @@ -2,7 +2,6 @@ #!/usr/bin/env python3.12 -import importlib from fastapi import FastAPI from pydantic import BaseModel diff --git a/endpoints/ws_test.py b/endpoints/ws_test.py new file mode 100644 index 0000000..7f99dce --- /dev/null +++ b/endpoints/ws_test.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python + +import asyncio +import json +import random +import uuid +from websockets.client import connect + +async def hello(): + async with connect("wss://api.codey.lol/cah/") as websocket: + message = await websocket.recv() + x = str(uuid.uuid4()) + print(f"Init Message: {message}") + await websocket.send(json.dumps({ + "event": "handshake", + "data": { + "resource": f"Test Client UUID {x}", + "platform": "Discord", + "csid": "2bd60a53-023d-49a5-a668-ce8fa8f6ec7f", + } + })) + while True: + message = await websocket.recv() + print(message) + await asyncio.sleep(random.uniform(20, 35)) + await websocket.close() + # for x in range(0, 200): + # await websocket.send(f"Hello world! {x}") + # message = await websocket.recv() + # print(f"Received: {message}") + # await asyncio.sleep(0.2) + +asyncio.get_event_loop().create_task(hello()) +asyncio.get_event_loop().run_forever() \ No newline at end of file