formatting
This commit is contained in:
@@ -41,6 +41,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class ConnectionStatus(Enum):
|
class ConnectionStatus(Enum):
|
||||||
"""Connection status enum for better tracking."""
|
"""Connection status enum for better tracking."""
|
||||||
|
|
||||||
DISCONNECTED = "disconnected"
|
DISCONNECTED = "disconnected"
|
||||||
CONNECTING = "connecting"
|
CONNECTING = "connecting"
|
||||||
CONNECTED = "connected"
|
CONNECTED = "connected"
|
||||||
@@ -70,7 +71,7 @@ class Lighting:
|
|||||||
|
|
||||||
Manages authentication and device control for Cync smart lights.
|
Manages authentication and device control for Cync smart lights.
|
||||||
Uses pycync library which maintains a TCP connection for device commands.
|
Uses pycync library which maintains a TCP connection for device commands.
|
||||||
|
|
||||||
2FA Handling:
|
2FA Handling:
|
||||||
- When 2FA is required, status changes to AWAITING_2FA
|
- When 2FA is required, status changes to AWAITING_2FA
|
||||||
- Set the 2FA code via Redis: SET cync:2fa_code "123456"
|
- Set the 2FA code via Redis: SET cync:2fa_code "123456"
|
||||||
@@ -222,7 +223,9 @@ class Lighting:
|
|||||||
"last_error": self._state.last_error,
|
"last_error": self._state.last_error,
|
||||||
"updated_at": time.time(),
|
"updated_at": time.time(),
|
||||||
}
|
}
|
||||||
self.redis_client.set(self.REDIS_STATUS_KEY, json.dumps(status_data), ex=300)
|
self.redis_client.set(
|
||||||
|
self.REDIS_STATUS_KEY, json.dumps(status_data), ex=300
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Failed to update status in Redis: {e}")
|
logger.debug(f"Failed to update status in Redis: {e}")
|
||||||
|
|
||||||
@@ -442,7 +445,7 @@ class Lighting:
|
|||||||
async def _handle_2fa(self) -> None:
|
async def _handle_2fa(self) -> None:
|
||||||
"""
|
"""
|
||||||
Handle 2FA authentication by polling Redis for the code.
|
Handle 2FA authentication by polling Redis for the code.
|
||||||
|
|
||||||
This is non-blocking - it sets the status to AWAITING_2FA and starts
|
This is non-blocking - it sets the status to AWAITING_2FA and starts
|
||||||
a background task to poll for the code. The code can be provided via:
|
a background task to poll for the code. The code can be provided via:
|
||||||
1. Environment variable CYNC_2FA_CODE (checked first)
|
1. Environment variable CYNC_2FA_CODE (checked first)
|
||||||
@@ -458,9 +461,11 @@ class Lighting:
|
|||||||
|
|
||||||
# Set status and start polling Redis
|
# Set status and start polling Redis
|
||||||
self._state.status = ConnectionStatus.AWAITING_2FA
|
self._state.status = ConnectionStatus.AWAITING_2FA
|
||||||
self._state.last_error = "2FA code required - check email and submit via API or Redis"
|
self._state.last_error = (
|
||||||
|
"2FA code required - check email and submit via API or Redis"
|
||||||
|
)
|
||||||
self._update_status_in_redis()
|
self._update_status_in_redis()
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Cync 2FA required. Submit code via POST /lighting/2fa or "
|
"Cync 2FA required. Submit code via POST /lighting/2fa or "
|
||||||
f"set Redis key '{self.REDIS_2FA_KEY}'"
|
f"set Redis key '{self.REDIS_2FA_KEY}'"
|
||||||
@@ -469,28 +474,28 @@ class Lighting:
|
|||||||
# Start background polling task if not already running
|
# Start background polling task if not already running
|
||||||
if self._2fa_task is None or self._2fa_task.done():
|
if self._2fa_task is None or self._2fa_task.done():
|
||||||
self._2fa_task = asyncio.create_task(self._poll_for_2fa_code())
|
self._2fa_task = asyncio.create_task(self._poll_for_2fa_code())
|
||||||
|
|
||||||
# Raise to signal caller that we're waiting for 2FA
|
# Raise to signal caller that we're waiting for 2FA
|
||||||
raise TwoFactorRequiredError("Awaiting 2FA code via Redis or API")
|
raise TwoFactorRequiredError("Awaiting 2FA code via Redis or API")
|
||||||
|
|
||||||
async def _poll_for_2fa_code(self) -> None:
|
async def _poll_for_2fa_code(self) -> None:
|
||||||
"""Background task to poll Redis for 2FA code."""
|
"""Background task to poll Redis for 2FA code."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
while time.time() - start_time < self.TWO_FA_TIMEOUT:
|
while time.time() - start_time < self.TWO_FA_TIMEOUT:
|
||||||
try:
|
try:
|
||||||
# Check Redis for 2FA code
|
# Check Redis for 2FA code
|
||||||
code = self.redis_client.get(self.REDIS_2FA_KEY)
|
code = self.redis_client.get(self.REDIS_2FA_KEY)
|
||||||
|
|
||||||
if code:
|
if code:
|
||||||
code_str = code.decode() if isinstance(code, bytes) else str(code)
|
code_str = code.decode() if isinstance(code, bytes) else str(code)
|
||||||
code_str = code_str.strip()
|
code_str = code_str.strip()
|
||||||
|
|
||||||
if code_str:
|
if code_str:
|
||||||
logger.info("Found 2FA code in Redis, attempting login...")
|
logger.info("Found 2FA code in Redis, attempting login...")
|
||||||
# Clear the code from Redis immediately
|
# Clear the code from Redis immediately
|
||||||
self.redis_client.delete(self.REDIS_2FA_KEY)
|
self.redis_client.delete(self.REDIS_2FA_KEY)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._complete_2fa_login(code_str)
|
await self._complete_2fa_login(code_str)
|
||||||
logger.info("2FA login successful via Redis polling")
|
logger.info("2FA login successful via Redis polling")
|
||||||
@@ -500,16 +505,16 @@ class Lighting:
|
|||||||
self._state.last_error = f"2FA login failed: {e}"
|
self._state.last_error = f"2FA login failed: {e}"
|
||||||
self._update_status_in_redis()
|
self._update_status_in_redis()
|
||||||
# Continue polling in case user wants to retry
|
# Continue polling in case user wants to retry
|
||||||
|
|
||||||
await asyncio.sleep(self.TWO_FA_POLL_INTERVAL)
|
await asyncio.sleep(self.TWO_FA_POLL_INTERVAL)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("2FA polling task cancelled")
|
logger.info("2FA polling task cancelled")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error polling for 2FA code: {e}")
|
logger.error(f"Error polling for 2FA code: {e}")
|
||||||
await asyncio.sleep(self.TWO_FA_POLL_INTERVAL)
|
await asyncio.sleep(self.TWO_FA_POLL_INTERVAL)
|
||||||
|
|
||||||
# Timeout reached
|
# Timeout reached
|
||||||
logger.error(f"2FA code timeout after {self.TWO_FA_TIMEOUT}s")
|
logger.error(f"2FA code timeout after {self.TWO_FA_TIMEOUT}s")
|
||||||
self._state.status = ConnectionStatus.ERROR
|
self._state.status = ConnectionStatus.ERROR
|
||||||
@@ -520,21 +525,21 @@ class Lighting:
|
|||||||
"""Complete the 2FA login process with the provided code."""
|
"""Complete the 2FA login process with the provided code."""
|
||||||
if not code:
|
if not code:
|
||||||
raise ValueError("Empty 2FA code provided")
|
raise ValueError("Empty 2FA code provided")
|
||||||
|
|
||||||
logger.info("Completing 2FA login...")
|
logger.info("Completing 2FA login...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert self._state.auth is not None, "Auth not initialized"
|
assert self._state.auth is not None, "Auth not initialized"
|
||||||
self._state.user = await self._state.auth.login(two_factor_code=code)
|
self._state.user = await self._state.auth.login(two_factor_code=code)
|
||||||
self._save_cached_token(self._state.user)
|
self._save_cached_token(self._state.user)
|
||||||
|
|
||||||
# Now complete the connection
|
# Now complete the connection
|
||||||
self._state.status = ConnectionStatus.CONNECTING
|
self._state.status = ConnectionStatus.CONNECTING
|
||||||
self._update_status_in_redis()
|
self._update_status_in_redis()
|
||||||
|
|
||||||
# Reconnect with the new token
|
# Reconnect with the new token
|
||||||
await self._connect(force=True)
|
await self._connect(force=True)
|
||||||
|
|
||||||
logger.info("Cync 2FA login successful")
|
logger.info("Cync 2FA login successful")
|
||||||
except TwoFactorRequiredError:
|
except TwoFactorRequiredError:
|
||||||
# Code was invalid, still needs 2FA
|
# Code was invalid, still needs 2FA
|
||||||
@@ -606,7 +611,7 @@ class Lighting:
|
|||||||
async def _health_monitor(self) -> None:
|
async def _health_monitor(self) -> None:
|
||||||
"""
|
"""
|
||||||
Background task to monitor connection health and reconnect aggressively.
|
Background task to monitor connection health and reconnect aggressively.
|
||||||
|
|
||||||
Checks every HEALTH_CHECK_INTERVAL seconds and reconnects if:
|
Checks every HEALTH_CHECK_INTERVAL seconds and reconnects if:
|
||||||
- Token is expiring soon
|
- Token is expiring soon
|
||||||
- TCP connection appears dead
|
- TCP connection appears dead
|
||||||
@@ -615,7 +620,7 @@ class Lighting:
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(self.HEALTH_CHECK_INTERVAL)
|
await asyncio.sleep(self.HEALTH_CHECK_INTERVAL)
|
||||||
|
|
||||||
# Skip health checks if awaiting 2FA
|
# Skip health checks if awaiting 2FA
|
||||||
if self._state.status == ConnectionStatus.AWAITING_2FA:
|
if self._state.status == ConnectionStatus.AWAITING_2FA:
|
||||||
continue
|
continue
|
||||||
@@ -643,7 +648,7 @@ class Lighting:
|
|||||||
logger.warning(f"Health monitor triggering reconnection: {reason}")
|
logger.warning(f"Health monitor triggering reconnection: {reason}")
|
||||||
self._state.status = ConnectionStatus.CONNECTING
|
self._state.status = ConnectionStatus.CONNECTING
|
||||||
self._update_status_in_redis()
|
self._update_status_in_redis()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._connect(force=True)
|
await self._connect(force=True)
|
||||||
logger.info("Health monitor reconnection successful")
|
logger.info("Health monitor reconnection successful")
|
||||||
@@ -894,7 +899,7 @@ class Lighting:
|
|||||||
async def get_connection_status(self) -> JSONResponse:
|
async def get_connection_status(self) -> JSONResponse:
|
||||||
"""
|
"""
|
||||||
Get the current Cync connection status.
|
Get the current Cync connection status.
|
||||||
|
|
||||||
Returns status, error info, and timing information.
|
Returns status, error info, and timing information.
|
||||||
No authentication required - useful for monitoring.
|
No authentication required - useful for monitoring.
|
||||||
"""
|
"""
|
||||||
@@ -902,56 +907,63 @@ class Lighting:
|
|||||||
# Try to get from Redis first (more up-to-date)
|
# Try to get from Redis first (more up-to-date)
|
||||||
cached = self.redis_client.get(self.REDIS_STATUS_KEY)
|
cached = self.redis_client.get(self.REDIS_STATUS_KEY)
|
||||||
if cached:
|
if cached:
|
||||||
data = json.loads(cached.decode() if isinstance(cached, bytes) else str(cached))
|
data = json.loads(
|
||||||
|
cached.decode() if isinstance(cached, bytes) else str(cached)
|
||||||
|
)
|
||||||
return JSONResponse(content=data)
|
return JSONResponse(content=data)
|
||||||
|
|
||||||
# Fall back to current state
|
# Fall back to current state
|
||||||
return JSONResponse(content={
|
return JSONResponse(
|
||||||
"status": self._state.status.value,
|
content={
|
||||||
"connected_at": self._state.connected_at,
|
"status": self._state.status.value,
|
||||||
"last_command_at": self._state.last_command_at,
|
"connected_at": self._state.connected_at,
|
||||||
"last_successful_command": self._state.last_successful_command,
|
"last_command_at": self._state.last_command_at,
|
||||||
"consecutive_failures": self._state.consecutive_failures,
|
"last_successful_command": self._state.last_successful_command,
|
||||||
"last_error": self._state.last_error,
|
"consecutive_failures": self._state.consecutive_failures,
|
||||||
"updated_at": time.time(),
|
"last_error": self._state.last_error,
|
||||||
})
|
"updated_at": time.time(),
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting connection status: {e}")
|
logger.error(f"Error getting connection status: {e}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500, content={"error": str(e), "status": "unknown"}
|
||||||
content={"error": str(e), "status": "unknown"}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def submit_2fa_code(self, request: Request) -> JSONResponse:
|
async def submit_2fa_code(self, request: Request) -> JSONResponse:
|
||||||
"""
|
"""
|
||||||
Submit a 2FA code for Cync authentication.
|
Submit a 2FA code for Cync authentication.
|
||||||
|
|
||||||
The code will be stored in Redis and picked up by the polling task.
|
The code will be stored in Redis and picked up by the polling task.
|
||||||
No authentication required since 2FA is needed to set up the connection.
|
No authentication required since 2FA is needed to set up the connection.
|
||||||
|
|
||||||
Request body: {"code": "123456"}
|
Request body: {"code": "123456"}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
code = body.get("code", "").strip()
|
code = body.get("code", "").strip()
|
||||||
|
|
||||||
if not code:
|
if not code:
|
||||||
raise HTTPException(status_code=400, detail="Missing 'code' in request body")
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Missing 'code' in request body"
|
||||||
|
)
|
||||||
|
|
||||||
if not code.isdigit() or len(code) != 6:
|
if not code.isdigit() or len(code) != 6:
|
||||||
raise HTTPException(status_code=400, detail="Code must be 6 digits")
|
raise HTTPException(status_code=400, detail="Code must be 6 digits")
|
||||||
|
|
||||||
# Store in Redis for the polling task to pick up
|
# Store in Redis for the polling task to pick up
|
||||||
self.redis_client.set(self.REDIS_2FA_KEY, code, ex=self.TWO_FA_TIMEOUT)
|
self.redis_client.set(self.REDIS_2FA_KEY, code, ex=self.TWO_FA_TIMEOUT)
|
||||||
|
|
||||||
logger.info("2FA code submitted via API")
|
logger.info("2FA code submitted via API")
|
||||||
|
|
||||||
return JSONResponse(content={
|
return JSONResponse(
|
||||||
"message": "2FA code submitted successfully",
|
content={
|
||||||
"status": self._state.status.value,
|
"message": "2FA code submitted successfully",
|
||||||
"note": "The code will be used on the next authentication attempt"
|
"status": self._state.status.value,
|
||||||
})
|
"note": "The code will be used on the next authentication attempt",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -961,31 +973,35 @@ class Lighting:
|
|||||||
async def force_reconnect(self, user=Depends(get_current_user)) -> JSONResponse:
|
async def force_reconnect(self, user=Depends(get_current_user)) -> JSONResponse:
|
||||||
"""
|
"""
|
||||||
Force a reconnection to the Cync service.
|
Force a reconnection to the Cync service.
|
||||||
|
|
||||||
Requires admin or lighting role.
|
Requires admin or lighting role.
|
||||||
"""
|
"""
|
||||||
if "lighting" not in user.get("roles", []) and "admin" not in user.get("roles", []):
|
if "lighting" not in user.get("roles", []) and "admin" not in user.get(
|
||||||
|
"roles", []
|
||||||
|
):
|
||||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("Force reconnect requested via API")
|
logger.info("Force reconnect requested via API")
|
||||||
self._state.status = ConnectionStatus.CONNECTING
|
self._state.status = ConnectionStatus.CONNECTING
|
||||||
self._update_status_in_redis()
|
self._update_status_in_redis()
|
||||||
|
|
||||||
await self._connect(force=True)
|
await self._connect(force=True)
|
||||||
|
|
||||||
return JSONResponse(content={
|
return JSONResponse(
|
||||||
"message": "Reconnection successful",
|
content={
|
||||||
"status": self._state.status.value,
|
"message": "Reconnection successful",
|
||||||
})
|
"status": self._state.status.value,
|
||||||
|
}
|
||||||
|
)
|
||||||
except TwoFactorRequiredError:
|
except TwoFactorRequiredError:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=202,
|
status_code=202,
|
||||||
content={
|
content={
|
||||||
"message": "Reconnection requires 2FA",
|
"message": "Reconnection requires 2FA",
|
||||||
"status": ConnectionStatus.AWAITING_2FA.value,
|
"status": ConnectionStatus.AWAITING_2FA.value,
|
||||||
"action": "Submit 2FA code via POST /lighting/2fa"
|
"action": "Submit 2FA code via POST /lighting/2fa",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Force reconnect failed: {e}")
|
logger.error(f"Force reconnect failed: {e}")
|
||||||
|
|||||||
@@ -19,28 +19,28 @@ def normalize_for_search(s: str) -> str:
|
|||||||
Removes common variations that cause exact match failures.
|
Removes common variations that cause exact match failures.
|
||||||
"""
|
"""
|
||||||
s = s.lower().strip()
|
s = s.lower().strip()
|
||||||
|
|
||||||
# Remove parenthetical content: (Remastered), (feat. X), (2020 Remix), etc.
|
# Remove parenthetical content: (Remastered), (feat. X), (2020 Remix), etc.
|
||||||
s = re.sub(r'\s*\([^)]*\)\s*', ' ', s)
|
s = re.sub(r"\s*\([^)]*\)\s*", " ", s)
|
||||||
|
|
||||||
# Remove bracketed content: [Explicit], [Deluxe Edition], etc.
|
# Remove bracketed content: [Explicit], [Deluxe Edition], etc.
|
||||||
s = re.sub(r'\s*\[[^\]]*\]\s*', ' ', s)
|
s = re.sub(r"\s*\[[^\]]*\]\s*", " ", s)
|
||||||
|
|
||||||
# Remove "feat.", "ft.", "featuring" and everything after
|
# Remove "feat.", "ft.", "featuring" and everything after
|
||||||
s = re.sub(r'\s*(feat\.?|ft\.?|featuring)\s+.*$', '', s, flags=re.IGNORECASE)
|
s = re.sub(r"\s*(feat\.?|ft\.?|featuring)\s+.*$", "", s, flags=re.IGNORECASE)
|
||||||
|
|
||||||
# Remove "The " prefix from artist names
|
# Remove "The " prefix from artist names
|
||||||
s = re.sub(r'^the\s+', '', s)
|
s = re.sub(r"^the\s+", "", s)
|
||||||
|
|
||||||
# Normalize & to "and"
|
# Normalize & to "and"
|
||||||
s = re.sub(r'\s*&\s*', ' and ', s)
|
s = re.sub(r"\s*&\s*", " and ", s)
|
||||||
|
|
||||||
# Remove punctuation except spaces
|
# Remove punctuation except spaces
|
||||||
s = re.sub(r"[^\w\s]", '', s)
|
s = re.sub(r"[^\w\s]", "", s)
|
||||||
|
|
||||||
# Collapse multiple spaces
|
# Collapse multiple spaces
|
||||||
s = re.sub(r'\s+', ' ', s).strip()
|
s = re.sub(r"\s+", " ", s).strip()
|
||||||
|
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
@@ -64,12 +64,12 @@ class LRCLib:
|
|||||||
) -> Optional[LyricsResult]:
|
) -> Optional[LyricsResult]:
|
||||||
"""
|
"""
|
||||||
LRCLib Local Database Search with normalization and smart fallback.
|
LRCLib Local Database Search with normalization and smart fallback.
|
||||||
|
|
||||||
Search strategy:
|
Search strategy:
|
||||||
1. Exact match on lowercased input (fastest, ~0.1ms)
|
1. Exact match on lowercased input (fastest, ~0.1ms)
|
||||||
2. Exact match on normalized input (fast, ~0.1ms)
|
2. Exact match on normalized input (fast, ~0.1ms)
|
||||||
3. Artist trigram + song exact within results (medium, ~50-200ms)
|
3. Artist trigram + song exact within results (medium, ~50-200ms)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
artist (str): the artist to search
|
artist (str): the artist to search
|
||||||
song (str): the song to search
|
song (str): the song to search
|
||||||
@@ -110,7 +110,7 @@ class LRCLib:
|
|||||||
if not best_match:
|
if not best_match:
|
||||||
artist_norm = normalize_for_search(artist)
|
artist_norm = normalize_for_search(artist)
|
||||||
song_norm = normalize_for_search(song)
|
song_norm = normalize_for_search(song)
|
||||||
|
|
||||||
if artist_norm != artist_lower or song_norm != song_lower:
|
if artist_norm != artist_lower or song_norm != song_lower:
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(
|
select(
|
||||||
@@ -133,7 +133,7 @@ class LRCLib:
|
|||||||
if not best_match:
|
if not best_match:
|
||||||
artist_norm = normalize_for_search(artist)
|
artist_norm = normalize_for_search(artist)
|
||||||
song_norm = normalize_for_search(song)
|
song_norm = normalize_for_search(song)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(
|
select(
|
||||||
Tracks.artist_name,
|
Tracks.artist_name,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ Usage examples:
|
|||||||
- Disable notifications: ./migrate_sqlite_to_pg.py --no-notify
|
- Disable notifications: ./migrate_sqlite_to_pg.py --no-notify
|
||||||
- Force re-import: ./migrate_sqlite_to_pg.py --force
|
- Force re-import: ./migrate_sqlite_to_pg.py --force
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -120,7 +121,7 @@ def clean_row(row: tuple, columns: list[tuple[str, str]]) -> tuple:
|
|||||||
|
|
||||||
|
|
||||||
def escape_copy_value(value, pg_type: str) -> str:
|
def escape_copy_value(value, pg_type: str) -> str:
|
||||||
"""Escape a value for PostgreSQL COPY format (tab-separated).\n
|
"""Escape a value for PostgreSQL COPY format (tab-separated).\n
|
||||||
This is much faster than INSERT for bulk loading.
|
This is much faster than INSERT for bulk loading.
|
||||||
"""
|
"""
|
||||||
if value is None:
|
if value is None:
|
||||||
@@ -153,7 +154,7 @@ def create_table(
|
|||||||
pg_conn, table: str, columns: list[tuple[str, str]], unlogged: bool = True
|
pg_conn, table: str, columns: list[tuple[str, str]], unlogged: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a table in PostgreSQL based on SQLite schema.
|
"""Create a table in PostgreSQL based on SQLite schema.
|
||||||
|
|
||||||
Uses UNLOGGED tables by default for faster bulk import (no WAL writes).
|
Uses UNLOGGED tables by default for faster bulk import (no WAL writes).
|
||||||
"""
|
"""
|
||||||
cur = pg_conn.cursor()
|
cur = pg_conn.cursor()
|
||||||
@@ -336,16 +337,16 @@ def create_database(db_name: str) -> None:
|
|||||||
|
|
||||||
def terminate_connections(db_name: str, max_wait: int = 10) -> bool:
|
def terminate_connections(db_name: str, max_wait: int = 10) -> bool:
|
||||||
"""Terminate all connections to a database.
|
"""Terminate all connections to a database.
|
||||||
|
|
||||||
Returns True if all connections were terminated, False if some remain.
|
Returns True if all connections were terminated, False if some remain.
|
||||||
Won't fail on permission errors (e.g., can't terminate superuser connections).
|
Won't fail on permission errors (e.g., can't terminate superuser connections).
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
conn = pg_connect("postgres")
|
conn = pg_connect("postgres")
|
||||||
conn.autocommit = True
|
conn.autocommit = True
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
|
|
||||||
for attempt in range(max_wait):
|
for attempt in range(max_wait):
|
||||||
# Check how many connections exist
|
# Check how many connections exist
|
||||||
cur.execute(
|
cur.execute(
|
||||||
@@ -354,14 +355,14 @@ def terminate_connections(db_name: str, max_wait: int = 10) -> bool:
|
|||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
count = int(row[0]) if row else 0
|
count = int(row[0]) if row else 0
|
||||||
|
|
||||||
if count == 0:
|
if count == 0:
|
||||||
cur.close()
|
cur.close()
|
||||||
conn.close()
|
conn.close()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
print(f" Terminating {count} connection(s) to {db_name}...")
|
print(f" Terminating {count} connection(s) to {db_name}...")
|
||||||
|
|
||||||
# Try to terminate - ignore errors for connections we can't kill
|
# Try to terminate - ignore errors for connections we can't kill
|
||||||
try:
|
try:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
@@ -376,10 +377,10 @@ def terminate_connections(db_name: str, max_wait: int = 10) -> bool:
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" Warning: {e}")
|
print(f" Warning: {e}")
|
||||||
|
|
||||||
# Brief wait for connections to close
|
# Brief wait for connections to close
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
# Final check
|
# Final check
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT COUNT(*) FROM pg_stat_activity WHERE datname = %s AND pid <> pg_backend_pid();",
|
"SELECT COUNT(*) FROM pg_stat_activity WHERE datname = %s AND pid <> pg_backend_pid();",
|
||||||
@@ -389,9 +390,11 @@ def terminate_connections(db_name: str, max_wait: int = 10) -> bool:
|
|||||||
remaining = int(row[0]) if row else 0
|
remaining = int(row[0]) if row else 0
|
||||||
cur.close()
|
cur.close()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
if remaining > 0:
|
if remaining > 0:
|
||||||
print(f" Warning: {remaining} connection(s) still active (may be superuser sessions)")
|
print(
|
||||||
|
f" Warning: {remaining} connection(s) still active (may be superuser sessions)"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -421,12 +424,12 @@ def rename_database(old_name: str, new_name: str) -> None:
|
|||||||
|
|
||||||
def drop_database(db_name: str) -> bool:
|
def drop_database(db_name: str) -> bool:
|
||||||
"""Drop a PostgreSQL database.
|
"""Drop a PostgreSQL database.
|
||||||
|
|
||||||
Returns True if dropped, False if failed (e.g., active connections).
|
Returns True if dropped, False if failed (e.g., active connections).
|
||||||
"""
|
"""
|
||||||
# First try to terminate connections
|
# First try to terminate connections
|
||||||
terminate_connections(db_name)
|
terminate_connections(db_name)
|
||||||
|
|
||||||
conn = pg_connect("postgres")
|
conn = pg_connect("postgres")
|
||||||
conn.autocommit = True
|
conn.autocommit = True
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
@@ -666,9 +669,7 @@ Examples:
|
|||||||
print(f"New dump available: {dump_date_str}")
|
print(f"New dump available: {dump_date_str}")
|
||||||
|
|
||||||
if notify_enabled:
|
if notify_enabled:
|
||||||
asyncio.run(
|
asyncio.run(notify_new_dump_found(latest["filename"], dump_date_str))
|
||||||
notify_new_dump_found(latest["filename"], dump_date_str)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Download
|
# Download
|
||||||
print(f"\nDownloading {latest['filename']}...")
|
print(f"\nDownloading {latest['filename']}...")
|
||||||
|
|||||||
@@ -521,7 +521,9 @@ async def download_and_extract_dump(
|
|||||||
|
|
||||||
# If an extracted sqlite file already exists, skip download and extraction
|
# If an extracted sqlite file already exists, skip download and extraction
|
||||||
if sqlite_path.exists() and sqlite_path.stat().st_size > 0:
|
if sqlite_path.exists() and sqlite_path.stat().st_size > 0:
|
||||||
print(f"Found existing extracted SQLite file {sqlite_path}; skipping download/extract")
|
print(
|
||||||
|
f"Found existing extracted SQLite file {sqlite_path}; skipping download/extract"
|
||||||
|
)
|
||||||
return str(sqlite_path), None
|
return str(sqlite_path), None
|
||||||
|
|
||||||
# Streaming download with retry and resume support
|
# Streaming download with retry and resume support
|
||||||
|
|||||||
Reference in New Issue
Block a user