radio changes/progress
This commit is contained in:
		@@ -6,12 +6,14 @@ import traceback
 | 
				
			|||||||
import os
 | 
					import os
 | 
				
			||||||
import aiosqlite as sqlite3
 | 
					import aiosqlite as sqlite3
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
import regex
 | 
					import regex
 | 
				
			||||||
import music_tag
 | 
					import music_tag
 | 
				
			||||||
 | 
					from . import radio_util
 | 
				
			||||||
from uuid import uuid4 as uuid
 | 
					from uuid import uuid4 as uuid
 | 
				
			||||||
from pydantic import BaseModel
 | 
					from pydantic import BaseModel
 | 
				
			||||||
from fastapi import FastAPI, Request, Response, HTTPException
 | 
					from fastapi import FastAPI, BackgroundTasks, Request, Response, HTTPException
 | 
				
			||||||
from fastapi.responses import RedirectResponse
 | 
					from fastapi.responses import RedirectResponse
 | 
				
			||||||
from aiohttp import ClientSession, ClientTimeout
 | 
					from aiohttp import ClientSession, ClientTimeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -39,16 +41,45 @@ class ValidRadioSongRequest(BaseModel):
 | 
				
			|||||||
class ValidRadioNextRequest(BaseModel):
 | 
					class ValidRadioNextRequest(BaseModel):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    - **key**: API Key
 | 
					    - **key**: API Key
 | 
				
			||||||
 | 
					    - **skipTo**: UUID to skip to [optional]
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    key: str
 | 
					    key: str
 | 
				
			||||||
 | 
					    skipTo: str|None = None
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					class ValidRadioReshuffleRequest(ValidRadioNextRequest):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    - **key**: API Key
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					class ValidRadioQueueShiftRequest(BaseModel):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    - **key**: API Key
 | 
				
			||||||
 | 
					    - **uuid**: UUID to shift 
 | 
				
			||||||
 | 
					    - **next**: Play next if true, immediately if false, default False
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    key: str
 | 
				
			||||||
 | 
					    uuid: str
 | 
				
			||||||
 | 
					    next: bool = False
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					class ValidRadioQueueRemovalRequest(BaseModel):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    - **key**: API Key
 | 
				
			||||||
 | 
					    - **uuid**: UUID to remove 
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    key: str
 | 
				
			||||||
 | 
					    uuid: str
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Radio(FastAPI):
 | 
					class Radio(FastAPI):
 | 
				
			||||||
    """Radio Endpoints"""
 | 
					    """Radio Endpoints"""
 | 
				
			||||||
    def __init__(self, app: FastAPI, my_util, constants, glob_state):  # pylint: disable=super-init-not-called
 | 
					    def __init__(self, app: FastAPI, my_util, constants, glob_state) -> None:  # pylint: disable=super-init-not-called
 | 
				
			||||||
        self.app = app
 | 
					        self.app = app
 | 
				
			||||||
        self.util = my_util
 | 
					        self.util = my_util
 | 
				
			||||||
        self.constants = constants
 | 
					        self.constants = constants
 | 
				
			||||||
 | 
					        self.radio_util = radio_util.RadioUtil(self.constants)
 | 
				
			||||||
        self.glob_state = glob_state
 | 
					        self.glob_state = glob_state
 | 
				
			||||||
        self.ls_uri = "http://10.10.10.101:29000"
 | 
					        self.ls_uri = "http://10.10.10.101:29000"
 | 
				
			||||||
        self.sqlite_exts: list[str] = ['/home/singer/api/solibs/spellfix1.cpython-311-x86_64-linux-gnu.so']
 | 
					        self.sqlite_exts: list[str] = ['/home/singer/api/solibs/spellfix1.cpython-311-x86_64-linux-gnu.so']
 | 
				
			||||||
@@ -70,6 +101,9 @@ class Radio(FastAPI):
 | 
				
			|||||||
            "radio/request": self.radio_request,
 | 
					            "radio/request": self.radio_request,
 | 
				
			||||||
            "radio/get_queue": self.radio_get_queue,
 | 
					            "radio/get_queue": self.radio_get_queue,
 | 
				
			||||||
            "radio/skip": self.radio_skip,
 | 
					            "radio/skip": self.radio_skip,
 | 
				
			||||||
 | 
					            "radio/queue_shift": self.radio_queue_shift,
 | 
				
			||||||
 | 
					            "radio/reshuffle": self.radio_reshuffle,
 | 
				
			||||||
 | 
					            "radio/queue_remove": self.radio_queue_remove,
 | 
				
			||||||
            # "widget/sqlite": self.homepage_sqlite_widget,
 | 
					            # "widget/sqlite": self.homepage_sqlite_widget,
 | 
				
			||||||
            # "widget/lyrics": self.homepage_lyrics_widget,
 | 
					            # "widget/lyrics": self.homepage_lyrics_widget,
 | 
				
			||||||
            # "widget/radio": self.homepage_radio_widget,            
 | 
					            # "widget/radio": self.homepage_radio_widget,            
 | 
				
			||||||
@@ -88,7 +122,7 @@ class Radio(FastAPI):
 | 
				
			|||||||
        asyncio.get_event_loop().run_until_complete(self.load_playlist())
 | 
					        asyncio.get_event_loop().run_until_complete(self.load_playlist())
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
    async def get_queue_item_by_uuid(self, uuid: str) -> tuple[int, dict] | None:
 | 
					    def get_queue_item_by_uuid(self, uuid: str) -> tuple[int, dict] | None:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Get queue item by UUID
 | 
					        Get queue item by UUID
 | 
				
			||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
@@ -99,7 +133,7 @@ class Radio(FastAPI):
 | 
				
			|||||||
        for x, item in enumerate(self.active_playlist):
 | 
					        for x, item in enumerate(self.active_playlist):
 | 
				
			||||||
            if item.get('uuid') == uuid:
 | 
					            if item.get('uuid') == uuid:
 | 
				
			||||||
                return (x, item)
 | 
					                return (x, item)
 | 
				
			||||||
        return False
 | 
					        return None
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    async def _ls_skip(self) -> bool:
 | 
					    async def _ls_skip(self) -> bool:
 | 
				
			||||||
         async with ClientSession() as session:
 | 
					         async with ClientSession() as session:
 | 
				
			||||||
@@ -111,24 +145,40 @@ class Radio(FastAPI):
 | 
				
			|||||||
    
 | 
					    
 | 
				
			||||||
    async def radio_skip(self, data: ValidRadioNextRequest, request: Request) -> bool:
 | 
					    async def radio_skip(self, data: ValidRadioNextRequest, request: Request) -> bool:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Skip to the next track in the queue
 | 
					        Skip to the next track in the queue, or to uuid specified in skipTo if provided
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            if not self.util.check_key(path=request.url.path, req_type=4, key=data.key):
 | 
					            if not self.util.check_key(path=request.url.path, req_type=4, key=data.key):
 | 
				
			||||||
                raise HTTPException(status_code=403, detail="Unauthorized")
 | 
					                raise HTTPException(status_code=403, detail="Unauthorized")
 | 
				
			||||||
 | 
					            if data.skipTo:
 | 
				
			||||||
 | 
					                (x, _) = self.get_queue_item_by_uuid(data.skipTo)
 | 
				
			||||||
 | 
					                self.active_playlist = self.active_playlist[x:]
 | 
				
			||||||
 | 
					                if not self.active_playlist:
 | 
				
			||||||
 | 
					                    await self.load_playlist()
 | 
				
			||||||
            return await self._ls_skip()
 | 
					            return await self._ls_skip()
 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
            traceback.print_exc()
 | 
					            traceback.print_exc()
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
                
 | 
					                
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    async def radio_reshuffle(self, data: ValidRadioReshuffleRequest, request: Request) -> dict:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Reshuffle the play queue
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if not self.util.check_key(path=request.url.path, req_type=4, key=data.key):
 | 
				
			||||||
 | 
					            raise HTTPException(status_code=403, detail="Unauthorized")
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
					        random.shuffle(self.active_playlist)
 | 
				
			||||||
 | 
					        return {
 | 
				
			||||||
 | 
					            'ok': True
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
            
 | 
					            
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    async def radio_get_queue(self, request: Request, limit: int = 100) -> dict:
 | 
					    async def radio_get_queue(self, request: Request, limit: int = 20_000) -> dict:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Get current play queue, up to limit n [default: 100]
 | 
					        Get current play queue, up to limit n [default: 20k]
 | 
				
			||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
            limit (int): Number of results to return (default 100)
 | 
					            limit (int): Number of results to return (default 20k)
 | 
				
			||||||
        Returns: 
 | 
					        Returns: 
 | 
				
			||||||
            dict
 | 
					            dict
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@@ -146,6 +196,37 @@ class Radio(FastAPI):
 | 
				
			|||||||
            'items': queue_out
 | 
					            'items': queue_out
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
					    async def radio_queue_shift(self, data: ValidRadioQueueShiftRequest, request: Request) -> dict:
 | 
				
			||||||
 | 
					        """Shift position of a UUID within the queue [currently limited to playing next or immediately]"""
 | 
				
			||||||
 | 
					        if not self.util.check_key(path=request.url.path, req_type=4, key=data.key):
 | 
				
			||||||
 | 
					            raise HTTPException(status_code=403, detail="Unauthorized")
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        (x, item) = self.get_queue_item_by_uuid(data.uuid)
 | 
				
			||||||
 | 
					        self.active_playlist.pop(x)        
 | 
				
			||||||
 | 
					        self.active_playlist.insert(0, item)
 | 
				
			||||||
 | 
					        if not data.next:
 | 
				
			||||||
 | 
					            await self._ls_skip()
 | 
				
			||||||
 | 
					        return {
 | 
				
			||||||
 | 
					            'ok': True,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					    async def radio_queue_remove(self, data: ValidRadioQueueRemovalRequest, request: Request) -> dict:
 | 
				
			||||||
 | 
					        """Remove an item from the current play queue"""
 | 
				
			||||||
 | 
					        if not self.util.check_key(path=request.url.path, req_type=4, key=data.key):
 | 
				
			||||||
 | 
					            raise HTTPException(status_code=403, detail="Unauthorized")
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        (x, found_item) = self.get_queue_item_by_uuid(data.uuid)
 | 
				
			||||||
 | 
					        if not found_item:
 | 
				
			||||||
 | 
					            return {
 | 
				
			||||||
 | 
					                'ok': False,
 | 
				
			||||||
 | 
					                'err': 'UUID not found in play queue',
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        self.active_playlist.pop(x)
 | 
				
			||||||
 | 
					        return {
 | 
				
			||||||
 | 
					            'ok': True,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					         
 | 
				
			||||||
    async def search_playlist(self, artistsong: str|None = None, artist: str|None = None, song: str|None = None) -> bool:
 | 
					    async def search_playlist(self, artistsong: str|None = None, artist: str|None = None, song: str|None = None) -> bool:
 | 
				
			||||||
        if artistsong and (artist or song):
 | 
					        if artistsong and (artist or song):
 | 
				
			||||||
            raise RadioException("Cannot search using combination provided")
 | 
					            raise RadioException("Cannot search using combination provided")
 | 
				
			||||||
@@ -250,7 +331,8 @@ class Radio(FastAPI):
 | 
				
			|||||||
        return ret_obj
 | 
					        return ret_obj
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    async def radio_get_next(self, data: ValidRadioNextRequest, request: Request) -> dict:
 | 
					    async def radio_get_next(self, data: ValidRadioNextRequest, request: Request,
 | 
				
			||||||
 | 
					                             background_tasks: BackgroundTasks) -> dict:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Get next track
 | 
					        Get next track
 | 
				
			||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
@@ -285,6 +367,10 @@ class Radio(FastAPI):
 | 
				
			|||||||
            self.now_playing = next
 | 
					            self.now_playing = next
 | 
				
			||||||
            next['start'] = time_started
 | 
					            next['start'] = time_started
 | 
				
			||||||
            next['end'] = time_ends
 | 
					            next['end'] = time_ends
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                background_tasks.add_task(self.radio_util.webhook_song_change, next)
 | 
				
			||||||
 | 
					            except Exception as e:
 | 
				
			||||||
 | 
					                traceback.print_exc()
 | 
				
			||||||
            return next
 | 
					            return next
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return await self.radio_pop_track(request, recursion_type="not list: self.active_playlist")
 | 
					            return await self.radio_pop_track(request, recursion_type="not list: self.active_playlist")
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										119
									
								
								endpoints/radio_util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								endpoints/radio_util.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,119 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/env python3.12
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					Radio Utils
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					import traceback
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import datetime
 | 
				
			||||||
 | 
					from aiohttp import ClientSession, ClientTimeout
 | 
				
			||||||
 | 
					import gpt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RadioUtil:
 | 
				
			||||||
 | 
					    def __init__(self, constants) -> None:
 | 
				
			||||||
 | 
					        self.constants = constants
 | 
				
			||||||
 | 
					        self.gpt = gpt.GPT(self.constants)
 | 
				
			||||||
 | 
					        self.webhooks = {
 | 
				
			||||||
 | 
					                'gpt': {
 | 
				
			||||||
 | 
					                    'hook': self.constants.GPT_WEBHOOK,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					                'sfm': {
 | 
				
			||||||
 | 
					                    'hook': self.constants.SFM_WEBHOOK,
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					    def duration_conv(self, s: int|float) -> str:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Convert duration given in seconds to hours, minutes, and seconds (h:m:s)
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            s (int|float): seconds to convert
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            str
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return str(datetime.timedelta(seconds=s)).split(".", maxsplit=1)[0]
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					    async def get_ai_song_info(self, artist: str, song: str) -> str|None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Get AI Song Info
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            artist (str)
 | 
				
			||||||
 | 
					            song (str)
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            str|None
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        response = await self.gpt.get_completion(prompt=f"I am going to listen to {song} by {artist}.")
 | 
				
			||||||
 | 
					        if not response:
 | 
				
			||||||
 | 
					            logging.critical("No response received from GPT?")
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        return response
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					    async def webhook_song_change(self, track: dict) -> None:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            """
 | 
				
			||||||
 | 
					            Handles Song Change Outbounds (Webhooks)
 | 
				
			||||||
 | 
					            Args: 
 | 
				
			||||||
 | 
					                track (dict)
 | 
				
			||||||
 | 
					            Returns:
 | 
				
			||||||
 | 
					                None
 | 
				
			||||||
 | 
					            """
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            # First, send track info
 | 
				
			||||||
 | 
					            friendly_track_start = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(track['start']))
 | 
				
			||||||
 | 
					            friendly_track_end = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(track['end']))
 | 
				
			||||||
 | 
					            hook_data = {
 | 
				
			||||||
 | 
					                'username': 'serious.FM',
 | 
				
			||||||
 | 
					                    "embeds": [{
 | 
				
			||||||
 | 
					                        "title": "Now Playing",
 | 
				
			||||||
 | 
					                        "description": f'# {track['song']}\nby **{track['artist']}**',
 | 
				
			||||||
 | 
					                        "color": 0x30c56f,
 | 
				
			||||||
 | 
					                        "fields": [
 | 
				
			||||||
 | 
					                            {
 | 
				
			||||||
 | 
					                                "name": "Duration",
 | 
				
			||||||
 | 
					                                "value": self.duration_conv(track['duration']),
 | 
				
			||||||
 | 
					                                "inline": True,
 | 
				
			||||||
 | 
					                            },
 | 
				
			||||||
 | 
					                            {
 | 
				
			||||||
 | 
					                                "name": "Filetype",
 | 
				
			||||||
 | 
					                                "value": track['file_path'].rsplit(".", maxsplit=1)[1],
 | 
				
			||||||
 | 
					                                "inline": True,
 | 
				
			||||||
 | 
					                            },
 | 
				
			||||||
 | 
					                            {
 | 
				
			||||||
 | 
					                                "name": "Higher Res",
 | 
				
			||||||
 | 
					                                "value": "[stream/icecast](https://relay.sfm.codey.lol/aces.ogg) || [web player](https://codey.lol/radio)",
 | 
				
			||||||
 | 
					                                "inline": True,
 | 
				
			||||||
 | 
					                            }
 | 
				
			||||||
 | 
					                        ]
 | 
				
			||||||
 | 
					                    }]
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            sfm_hook = self.webhooks['gpt'].get('hook')
 | 
				
			||||||
 | 
					            async with ClientSession() as session:
 | 
				
			||||||
 | 
					                    async with await session.post(sfm_hook, json=hook_data,
 | 
				
			||||||
 | 
					                                timeout=ClientTimeout(connect=5, sock_read=5), headers={
 | 
				
			||||||
 | 
					                                    'content-type': 'application/json; charset=utf-8',}) as request:
 | 
				
			||||||
 | 
					                        request.raise_for_status()            
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            # Next, AI feedback            
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            ai_response = await self.get_ai_song_info(track['artist'],
 | 
				
			||||||
 | 
					                                                  track['song'])
 | 
				
			||||||
 | 
					            hook_data = {
 | 
				
			||||||
 | 
					                'username': 'GPT',
 | 
				
			||||||
 | 
					                    "embeds": [{
 | 
				
			||||||
 | 
					                        "title": "AI Feedback",
 | 
				
			||||||
 | 
					                        "color": 0x35d0ff,
 | 
				
			||||||
 | 
					                        "description": ai_response.strip(),
 | 
				
			||||||
 | 
					                    }]
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            ai_hook = self.webhooks['gpt'].get('hook')
 | 
				
			||||||
 | 
					            async with ClientSession() as session:
 | 
				
			||||||
 | 
					                    async with await session.post(ai_hook, json=hook_data,
 | 
				
			||||||
 | 
					                                timeout=ClientTimeout(connect=5, sock_read=5), headers={
 | 
				
			||||||
 | 
					                                    'content-type': 'application/json; charset=utf-8',}) as request:
 | 
				
			||||||
 | 
					                        request.raise_for_status()
 | 
				
			||||||
 | 
					        except Exception as e:
 | 
				
			||||||
 | 
					            traceback.print_exc()
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
							
								
								
									
										33
									
								
								gpt/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								gpt/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,33 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/env python3.12
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					from openai import AsyncOpenAI
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GPT:
 | 
				
			||||||
 | 
					    def __init__(self, constants):
 | 
				
			||||||
 | 
					        self.constants = constants
 | 
				
			||||||
 | 
					        self.api_key = self.constants.OPENAI_API_KEY
 | 
				
			||||||
 | 
					        self.client = AsyncOpenAI(
 | 
				
			||||||
 | 
					            api_key=self.api_key,
 | 
				
			||||||
 | 
					            timeout=10.0,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.default_system_prompt = "You are a helpful assistant who will provide tidbits of info on songs the user may listen to."
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def get_completion(self, prompt: str, system_prompt: Optional[str] = None) -> None:
 | 
				
			||||||
 | 
					        if not system_prompt:
 | 
				
			||||||
 | 
					            system_prompt = self.default_system_prompt
 | 
				
			||||||
 | 
					        chat_completion = await self.client.chat.completions.create(
 | 
				
			||||||
 | 
					            messages=[
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "role": "system",
 | 
				
			||||||
 | 
					                    "content": system_prompt,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "role": "user",
 | 
				
			||||||
 | 
					                    "content": prompt,
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					            model="gpt-4o-mini",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return chat_completion.choices[0].message.content
 | 
				
			||||||
		Reference in New Issue
	
	Block a user