#!/usr/bin/env python3
"""
Upsert SQLite dump data into PostgreSQL (Async Version).

This script handles incremental updates from SQLite dumps to PostgreSQL:
- New rows are inserted
- Existing rows are updated if changed
- Uses primary keys and unique constraints to identify duplicates
- Can automatically fetch new dumps from lrclib.net
- Generates detailed change reports
- Async with asyncpg for high performance with large datasets (17GB+)
"""

import asyncio
import aiosqlite
import asyncpg
import os
import gzip
import re
import json
import aiohttp
import aiofiles
import subprocess
from typing import List, Dict, Any, Optional
from datetime import datetime, timezone
from pathlib import Path
from dataclasses import dataclass, field, asdict
from dotenv import load_dotenv

try:
    from playwright.async_api import async_playwright
    PLAYWRIGHT_AVAILABLE = True
except ImportError:
    PLAYWRIGHT_AVAILABLE = False

load_dotenv()

# Discord webhook for notifications
DISCORD_WEBHOOK_URL = "https://discord.com/api/webhooks/1332864106750939168/6y6MgFhHLX0-BnL2M3hXnt2vJsue7Q2Duf_HjZenHNlNj7sxQr4lqxrPVJnJWf7KVAm2"

# Configuration - using environment variables
PG_CONFIG = {
    "host": os.getenv("POSTGRES_HOST", "localhost"),
    "port": int(os.getenv("POSTGRES_PORT", "5432")),
    "database": os.getenv("POSTGRES_DB", "lrclib"),
    "user": os.getenv("POSTGRES_USER", "api"),
    "password": os.getenv("POSTGRES_PASSWORD", ""),
}

CHUNK_SIZE = 10000  # Process rows in chunks (larger for async)
BATCH_SIZE = 1000  # Rows per INSERT statement

# LRCLib dump URL
LRCLIB_DUMPS_URL = "https://lrclib.net/db-dumps"
STATE_FILE = Path(__file__).parent / ".lrclib_upsert_state.json"


@dataclass
class TableReport:
    """Report for a single table's upsert operation."""

    table_name: str
    sqlite_rows: int = 0
    pg_rows_before: int = 0
    pg_rows_after: int = 0
    rows_affected: int = 0
    errors: int = 0
    duration_seconds: float = 0.0
    status: str = "pending"
    error_message: str = ""

    @property
    def rows_inserted(self) -> int:
        return max(0, self.pg_rows_after - self.pg_rows_before)

    @property
    def rows_updated(self) -> int:
        return max(0, self.rows_affected - self.rows_inserted)


@dataclass
class UpsertReport:
    """Complete report for an upsert operation."""

    sqlite_source: str
    dump_date: Optional[str] = None
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    tables: List[TableReport] = field(default_factory=list)

    @property
    def total_sqlite_rows(self) -> int:
        return sum(t.sqlite_rows for t in self.tables)

    @property
    def total_rows_affected(self) -> int:
        return sum(t.rows_affected for t in self.tables)

    @property
    def total_rows_inserted(self) -> int:
        return sum(t.rows_inserted for t in self.tables)

    @property
    def total_rows_updated(self) -> int:
        return sum(t.rows_updated for t in self.tables)

    @property
    def total_errors(self) -> int:
        return sum(t.errors for t in self.tables)

    @property
    def duration_seconds(self) -> float:
        if self.start_time and self.end_time:
            return (self.end_time - self.start_time).total_seconds()
        return 0.0

    @property
    def successful_tables(self) -> List[str]:
        return [t.table_name for t in self.tables if t.status == "success"]

    @property
    def failed_tables(self) -> List[str]:
        return [t.table_name for t in self.tables if t.status == "failed"]

    def generate_summary(self) -> str:
        """Generate a human-readable summary report."""
        lines = [
            "=" * 60,
            "LRCLIB UPSERT REPORT",
            "=" * 60,
            f"Source: {self.sqlite_source}",
        ]
        if self.dump_date:
            lines.append(f"Dump Date: {self.dump_date}")
        if
self.start_time: lines.append(f"Started: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}") if self.end_time: lines.append(f"Finished: {self.end_time.strftime('%Y-%m-%d %H:%M:%S')}") lines.append(f"Duration: {self.duration_seconds:.1f} seconds") lines.append("") lines.append("-" * 60) lines.append("TABLE DETAILS") lines.append("-" * 60) for table in self.tables: status_icon = ( "✓" if table.status == "success" else "✗" if table.status == "failed" else "○" ) lines.append(f"\n{status_icon} {table.table_name}") lines.append(f" SQLite rows: {table.sqlite_rows:>12,}") lines.append(f" PG before: {table.pg_rows_before:>12,}") lines.append(f" PG after: {table.pg_rows_after:>12,}") lines.append(f" Rows inserted: {table.rows_inserted:>12,}") lines.append(f" Rows updated: {table.rows_updated:>12,}") lines.append(f" Duration: {table.duration_seconds:>12.1f}s") if table.errors > 0: lines.append(f" Errors: {table.errors:>12}") if table.error_message: lines.append(f" Error: {table.error_message}") lines.append("") lines.append("-" * 60) lines.append("SUMMARY") lines.append("-" * 60) lines.append(f"Tables processed: {len(self.tables)}") lines.append(f" Successful: {len(self.successful_tables)}") lines.append(f" Failed: {len(self.failed_tables)}") lines.append(f"Total SQLite rows: {self.total_sqlite_rows:,}") lines.append(f"Total inserted: {self.total_rows_inserted:,}") lines.append(f"Total updated: {self.total_rows_updated:,}") lines.append(f"Total errors: {self.total_errors}") lines.append("=" * 60) return "\n".join(lines) def to_json(self) -> str: """Export report as JSON.""" data = { "sqlite_source": self.sqlite_source, "dump_date": self.dump_date, "start_time": self.start_time.isoformat() if self.start_time else None, "end_time": self.end_time.isoformat() if self.end_time else None, "duration_seconds": self.duration_seconds, "summary": { "total_sqlite_rows": self.total_sqlite_rows, "total_rows_inserted": self.total_rows_inserted, "total_rows_updated": self.total_rows_updated, "total_errors": self.total_errors, "successful_tables": self.successful_tables, "failed_tables": self.failed_tables, }, "tables": [asdict(t) for t in self.tables], } return json.dumps(data, indent=2) def load_state() -> Dict[str, Any]: """Load the last upsert state from file.""" if STATE_FILE.exists(): try: return json.loads(STATE_FILE.read_text()) except (json.JSONDecodeError, IOError): pass return {} def save_state(state: Dict[str, Any]) -> None: """Save the upsert state to file.""" STATE_FILE.write_text(json.dumps(state, indent=2)) # Discord notification colors class DiscordColor: INFO = 0x3498DB # Blue SUCCESS = 0x2ECC71 # Green WARNING = 0xF39C12 # Orange ERROR = 0xE74C3C # Red NEUTRAL = 0x95A5A6 # Gray async def discord_notify( title: str, description: str, color: int = DiscordColor.INFO, fields: Optional[List[Dict[str, Any]]] = None, footer: Optional[str] = None, ) -> bool: """ Send a Discord webhook notification. Returns True on success, False on failure. 
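
    Example (illustrative only; the argument values here are made up, but the
    keyword arguments and the `DiscordColor` constants are the ones defined in
    this module):

        await discord_notify(
            title="LRCLib DB Update - Example",
            description="Something noteworthy happened.",
            color=DiscordColor.INFO,
            fields=[{"name": "Stage", "value": "Download", "inline": True}],
            footer="lrclib upsert",
        )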
""" embed = { "title": title, "description": description[:2000] if description else "", "color": color, "timestamp": datetime.now(timezone.utc).isoformat(), } if fields: embed["fields"] = fields[:25] # Discord limit if footer: embed["footer"] = {"text": footer[:2048]} payload = {"embeds": [embed]} try: async with aiohttp.ClientSession() as session: async with session.post( DISCORD_WEBHOOK_URL, json=payload, timeout=aiohttp.ClientTimeout(total=15), ) as resp: if resp.status >= 400: text = await resp.text() print(f"Discord webhook failed ({resp.status}): {text}") return False return True except Exception as e: print(f"Discord notification error: {e}") return False async def notify_start() -> None: """Send notification that the weekly job is starting.""" await discord_notify( title="🔄 LRCLib DB Update - Starting", description="Weekly database sync job has started.", color=DiscordColor.INFO, fields=[ { "name": "Time", "value": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "inline": True, }, ], ) async def notify_no_update(latest_dump: str, last_upsert: str) -> None: """Send notification that no new dump was found.""" await discord_notify( title="â„šī¸ LRCLib DB Update - No Update Needed", description="No newer database dump available.", color=DiscordColor.NEUTRAL, fields=[ {"name": "Latest Available", "value": latest_dump, "inline": True}, {"name": "Last Synced", "value": last_upsert, "inline": True}, ], ) async def notify_new_dump_found(dump_filename: str, dump_date: str) -> None: """Send notification that a new dump was found and download is starting.""" await discord_notify( title="đŸ“Ĩ LRCLib DB Update - New Dump Found", description="Downloading and processing new database dump.", color=DiscordColor.INFO, fields=[ {"name": "Filename", "value": dump_filename, "inline": False}, {"name": "Dump Date", "value": dump_date, "inline": True}, ], ) async def notify_success(report: "UpsertReport") -> None: """Send notification for successful upsert.""" duration_min = report.duration_seconds / 60 # Build table summary table_summary = [] for t in report.tables[:10]: # Limit to first 10 tables status = "✓" if t.status == "success" else "✗" if t.status == "failed" else "○" table_summary.append(f"{status} {t.table_name}: +{t.rows_inserted:,}") if len(report.tables) > 10: table_summary.append(f"... 
and {len(report.tables) - 10} more tables") await discord_notify( title="✅ LRCLib DB Update - Success", description="Database sync completed successfully.", color=DiscordColor.SUCCESS, fields=[ {"name": "Dump Date", "value": report.dump_date or "N/A", "inline": True}, {"name": "Duration", "value": f"{duration_min:.1f} min", "inline": True}, {"name": "Tables", "value": str(len(report.tables)), "inline": True}, { "name": "Rows Inserted", "value": f"{report.total_rows_inserted:,}", "inline": True, }, { "name": "Rows Updated", "value": f"{report.total_rows_updated:,}", "inline": True, }, {"name": "Errors", "value": str(report.total_errors), "inline": True}, { "name": "Table Summary", "value": "\n".join(table_summary) or "No tables", "inline": False, }, ], footer=f"Source: {report.sqlite_source}", ) async def notify_failure( error_message: str, stage: str, details: Optional[str] = None ) -> None: """Send notification for failed upsert.""" description = ( f"Database sync failed during: **{stage}**\n\n```{error_message[:1500]}```" ) fields = [ {"name": "Stage", "value": stage, "inline": True}, { "name": "Time", "value": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "inline": True, }, ] if details: fields.append({"name": "Details", "value": details[:1000], "inline": False}) await discord_notify( title="❌ LRCLib DB Update - Failed", description=description, color=DiscordColor.ERROR, fields=fields, ) def parse_dump_date(filename: str) -> Optional[datetime]: """ Extract date from dump filename. Supports formats: - lrclib-db-dump-20260122T091251Z.sqlite3.gz (ISO format with time) - lrclib-db-dump-2025-01-20.sqlite3.gz (date only) """ # Try ISO format first: 20260122T091251Z match = re.search(r"(\d{4})(\d{2})(\d{2})T(\d{2})(\d{2})(\d{2})Z", filename) if match: return datetime( int(match.group(1)), # year int(match.group(2)), # month int(match.group(3)), # day int(match.group(4)), # hour int(match.group(5)), # minute int(match.group(6)), # second ) # Try simple date format: 2025-01-20 match = re.search(r"(\d{4})-(\d{2})-(\d{2})", filename) if match: return datetime.strptime(match.group(0), "%Y-%m-%d") return None async def fetch_latest_dump_info() -> Optional[Dict[str, Any]]: """ Fetch the latest dump info from lrclib.net/db-dumps. Uses Playwright to render the JS-generated page content. Returns dict with 'url', 'filename', 'date' or None if not found. 
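
    Example of the returned dict (illustrative values; the shape and the
    filename pattern match what is built further down in this function):

        {
            "url": "https://lrclib.net/db-dumps/lrclib-db-dump-20260122T091251Z.sqlite3.gz",
            "filename": "lrclib-db-dump-20260122T091251Z.sqlite3.gz",
            "date": datetime(2026, 1, 22, 9, 12, 51),
        }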
""" print(f"Fetching dump list from {LRCLIB_DUMPS_URL}...") if not PLAYWRIGHT_AVAILABLE: print("Error: playwright is required for fetching dump info.") print("Install with: pip install playwright && playwright install chromium") return None html = None try: async with async_playwright() as p: # Launch headless browser browser = await p.chromium.launch(headless=True) page = await browser.new_page() # Navigate and wait for content to load await page.goto(LRCLIB_DUMPS_URL, wait_until="networkidle") # Wait for anchor tags with .sqlite3.gz links to appear try: await page.wait_for_selector('a[href*=".sqlite3.gz"]', timeout=15000) except Exception: print("Warning: Timed out waiting for download links, trying anyway...") # Get rendered HTML html = await page.content() await browser.close() except Exception as e: print(f"Error rendering page with Playwright: {e}") return None if not html: print("Failed to get page content") return None # Find all .sqlite3.gz links pattern = r'href=["\']([^"\']*\.sqlite3\.gz)["\']' matches = re.findall(pattern, html, re.IGNORECASE) # Deduplicate matches (same link may appear multiple times in HTML) matches = list(dict.fromkeys(matches)) if not matches: print("No .sqlite3.gz files found on the page") # Debug: show what we got print(f"Page content length: {len(html)} chars") return None print(f"Found {len(matches)} dump file(s)") # Parse dates and find the newest dumps = [] for match in matches: # Handle relative URLs if match.startswith("/"): url = f"https://lrclib.net{match}" elif not match.startswith("http"): url = f"https://lrclib.net/db-dumps/{match}" else: url = match filename = url.split("/")[-1] date = parse_dump_date(filename) if date: dumps.append({"url": url, "filename": filename, "date": date}) print(f" Found: {filename} ({date.strftime('%Y-%m-%d %H:%M:%S')})") if not dumps: print("Could not parse dates from dump filenames") return None # Sort by date descending and return the newest from datetime import datetime def parse_date_safe(date_obj): if isinstance(date_obj, datetime): return date_obj return datetime.min # Default to a minimum date if parsing fails dumps.sort(key=lambda x: parse_date_safe(x.get("date", datetime.min)), reverse=True) latest = dumps[0] # Handle type issue with strftime if isinstance(latest["date"], datetime): formatted_date = latest["date"].strftime("%Y-%m-%d %H:%M:%S") else: formatted_date = "Unknown date" print(f"Latest dump: {latest['filename']} ({formatted_date})") return latest async def download_and_extract_dump( url: str, dest_dir: Optional[str] = None, max_retries: int = 5 ) -> tuple[Optional[str], Optional[str]]: """ Download a .sqlite3.gz file and extract it with streaming. Supports resume on connection failures. Returns tuple of (path to extracted .sqlite3 file, error message). On success: (path, None). On failure: (None, error_message). 
""" filename = url.split("/")[-1] if dest_dir: dest_path = Path(dest_dir) dest_path.mkdir(parents=True, exist_ok=True) else: dest_path = Path("/nvme/tmp") gz_path = dest_path / filename sqlite_path = dest_path / filename.replace(".gz", "") # If an extracted sqlite file already exists, skip download and extraction if sqlite_path.exists() and sqlite_path.stat().st_size > 0: print(f"Found existing extracted SQLite file {sqlite_path}; skipping download/extract") return str(sqlite_path), None # Streaming download with retry and resume support print(f"Downloading {url}...") for attempt in range(1, max_retries + 1): try: # Check if partial download exists for resume downloaded = 0 headers = {} if gz_path.exists(): downloaded = gz_path.stat().st_size headers["Range"] = f"bytes={downloaded}-" print(f" Resuming from {downloaded / (1024 * 1024):.1f} MB...") timeout = aiohttp.ClientTimeout( total=None, # No total timeout for large files connect=60, sock_read=300, # 5 min read timeout per chunk ) async with aiohttp.ClientSession() as session: async with session.get( url, timeout=timeout, headers=headers ) as response: # Handle resume response if response.status == 416: # Range not satisfiable - file complete print(" Download already complete.") break elif response.status == 206: # Partial content - resuming total = downloaded + int( response.headers.get("content-length", 0) ) elif response.status == 200: # Server doesn't support resume or fresh download if downloaded > 0: print(" Server doesn't support resume, restarting...") downloaded = 0 total = int(response.headers.get("content-length", 0)) else: response.raise_for_status() total = int(response.headers.get("content-length", 0)) # Open in append mode if resuming, write mode if fresh mode = "ab" if downloaded > 0 and response.status == 206 else "wb" if mode == "wb": downloaded = 0 async with aiofiles.open(gz_path, mode) as f: async for chunk in response.content.iter_chunked( 1024 * 1024 ): # 1MB chunks await f.write(chunk) downloaded += len(chunk) if total: pct = (downloaded / total) * 100 mb_down = downloaded / (1024 * 1024) mb_total = total / (1024 * 1024) print( f" Downloaded: {mb_down:.1f} / {mb_total:.1f} MB ({pct:.1f}%)", end="\r", ) print() # Newline after progress break # Success, exit retry loop except (aiohttp.ClientError, ConnectionResetError, asyncio.TimeoutError) as e: print(f"\n Download error (attempt {attempt}/{max_retries}): {e}") if attempt < max_retries: wait_time = min(30 * attempt, 120) # Exponential backoff, max 2 min print(f" Retrying in {wait_time} seconds...") await asyncio.sleep(wait_time) else: error_msg = f"Download failed after {max_retries} attempts: {e}" print(f"Error: {error_msg}") # Clean up partial file on final failure if gz_path.exists(): gz_path.unlink() return None, error_msg # Verify download completed if not gz_path.exists() or gz_path.stat().st_size == 0: error_msg = "Download failed - file is empty or missing" print(f"Error: {error_msg}") return None, error_msg # Extract (sync, but fast with streaming) print( f"Extracting {gz_path} ({gz_path.stat().st_size / (1024**3):.2f} GB compressed)..." 
) try: # Use streaming decompression for large files async with aiofiles.open(sqlite_path, "wb") as f_out: with gzip.open(gz_path, "rb") as f_in: extracted = 0 while True: chunk = f_in.read(1024 * 1024 * 10) # 10MB chunks if not chunk: break await f_out.write(chunk) extracted += len(chunk) print(f" Extracted: {extracted / (1024**3):.2f} GB", end="\r") print() except Exception as e: error_msg = f"Extraction failed: {e}" print(f"Error: {error_msg}") return None, error_msg print( f"Extracted to: {sqlite_path} ({sqlite_path.stat().st_size / (1024**3):.2f} GB)" ) # Note: do NOT remove the .gz file here. Keep source files until import completes successfully. return str(sqlite_path), None async def extract_gz_file(gz_path: str, output_path: str) -> None: """Extract a .gz file using pigz for multi-threaded decompression.""" try: # Use pigz for faster decompression print(f"Extracting {gz_path} to {output_path} using pigz...") subprocess.run( ["pigz", "-d", "-p", str(os.cpu_count()), "-c", gz_path], stdout=open(output_path, "wb"), check=True, ) print(f"Extraction completed: {output_path}") except FileNotFoundError: raise RuntimeError( "pigz is not installed. Please install it for faster decompression." ) except subprocess.CalledProcessError as e: raise RuntimeError(f"Extraction failed: {e}") async def get_sqlite_tables(db: aiosqlite.Connection) -> List[str]: """Get list of all tables in SQLite database.""" async with db.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';" ) as cursor: rows = await cursor.fetchall() return [row[0] for row in rows] async def get_sqlite_schema( db: aiosqlite.Connection, table_name: str ) -> List[Dict[str, Any]]: """Get column info from SQLite table.""" async with db.execute(f"PRAGMA table_info({table_name});") as cursor: rows = await cursor.fetchall() return [ { "cid": row[0], "name": row[1], "type": row[2], "notnull": row[3], "default": row[4], "pk": row[5], } for row in rows ] async def get_pg_schema(pool: asyncpg.Pool, table_name: str) -> Dict[str, Any]: """Get PostgreSQL table schema including primary key and unique constraints.""" async with pool.acquire() as conn: # Get columns columns = await conn.fetch( """ SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_name = $1 ORDER BY ordinal_position; """, table_name, ) columns = [ {"name": row["column_name"], "type": row["data_type"]} for row in columns ] # Get primary key columns pk_columns = await conn.fetch( """ SELECT a.attname FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) WHERE i.indrelid = $1::regclass AND i.indisprimary; """, table_name, ) pk_columns = [row["attname"] for row in pk_columns] # Get unique constraint columns unique_rows = await conn.fetch( """ SELECT a.attname, c.conname FROM pg_constraint c JOIN pg_attribute a ON a.attrelid = c.conrelid AND a.attnum = ANY(c.conkey) WHERE c.conrelid = $1::regclass AND c.contype = 'u' ORDER BY c.conname, a.attnum; """, table_name, ) unique_constraints: Dict[str, List[str]] = {} for row in unique_rows: constraint_name = row["conname"] if constraint_name not in unique_constraints: unique_constraints[constraint_name] = [] unique_constraints[constraint_name].append(row["attname"]) return { "columns": columns, "pk_columns": pk_columns, "unique_constraints": unique_constraints, } def clean_value(value: Any) -> Any: """Basic cleaning for values (strip NULs). 
Use more specific coercion in per-column handling.""" if value is None: return None if isinstance(value, str): return value.replace("\x00", "") return value def parse_datetime_string(s: str) -> Optional[datetime]: """Parse a datetime string into a timezone-aware datetime if possible.""" if s is None: return None if isinstance(s, datetime): return s if not isinstance(s, str): return None try: # Normalize excessive fractional seconds (more than 6 digits) by truncating to 6 s_norm = re.sub(r"\.(\d{6})\d+", r".\1", s) # Try fromisoformat (handles 'YYYY-MM-DD HH:MM:SS.mmm+00:00' and ISO variants) dt = datetime.fromisoformat(s_norm) # If naive, assume UTC if dt.tzinfo is None: from datetime import timezone dt = dt.replace(tzinfo=timezone.utc) return dt except Exception: # Try common SQLite format for fmt in ( "%Y-%m-%d %H:%M:%S.%f%z", "%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S%z", "%Y-%m-%d %H:%M:%S", ): try: s_norm = re.sub(r"\.(\d{6})\d+", r".\1", s) dt = datetime.strptime(s_norm, fmt) if dt.tzinfo is None: from datetime import timezone dt = dt.replace(tzinfo=timezone.utc) return dt except Exception: continue return None def coerce_value_for_pg(value: Any, pg_type: str) -> Any: """Coerce a SQLite value into an appropriate Python type for asyncpg based on PostgreSQL data type.""" if value is None: return None pg_type = (pg_type or "").lower() # Boolean if "boolean" in pg_type: # SQLite often stores booleans as 0/1 if isinstance(value, (int, float)): return bool(value) if isinstance(value, str): v = value.strip().lower() if v in ("1", "t", "true", "yes"): return True if v in ("0", "f", "false", "no"): return False if isinstance(value, bool): return value # Fallback return bool(value) # Timestamps if "timestamp" in pg_type or "date" in pg_type or "time" in pg_type: if isinstance(value, datetime): return value if isinstance(value, str): dt = parse_datetime_string(value) if dt: return dt # If SQLite stores as numeric (unix epoch) if isinstance(value, (int, float)): try: return datetime.fromtimestamp(value) except Exception: pass return value # JSON if "json" in pg_type: if isinstance(value, str): try: return json.loads(value) except Exception: return value return value # Numeric integer types if pg_type in ("integer", "bigint", "smallint") or "int" in pg_type: if isinstance(value, int): return value if isinstance(value, float): return int(value) if isinstance(value, str) and value.isdigit(): return int(value) try: return int(value) except Exception: return value # Floating point if ( pg_type in ("real", "double precision", "numeric", "decimal") or "numeric" in pg_type ): if isinstance(value, (int, float)): return float(value) if isinstance(value, str): try: return float(value) except Exception: return value # Text-like types: coerce to string if not None if any(k in pg_type for k in ("char", "text", "uuid")): try: return str(value) except Exception: return clean_value(value) # Default: basic cleaning return clean_value(value) # Diagnostics directory for failing chunks DIAGNOSTICS_DIR = Path(__file__).parent / "diagnostics" DIAGNOSTICS_DIR.mkdir(exist_ok=True) async def save_failed_rows( dump_basename: str, table_name: str, offset: int, rows: List[tuple], error: str ) -> str: """Save failing rows as a gzipped JSONL file for inspection. Returns path to diagnostics file. 
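
    Example of a produced filename (illustrative; the table name is hypothetical,
    the pattern matches the one built below):

        lrclib-db-dump-20260122T091251Z.sqlite3.tracks.offset40000.failed.jsonl.gz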
""" fname = f"{dump_basename}.{table_name}.offset{offset}.failed.jsonl.gz" path = DIAGNOSTICS_DIR / fname try: import gzip as _gzip with _gzip.open(path, "wt", encoding="utf-8") as fh: for r in rows: record = {"row": r, "error": error} fh.write(json.dumps(record, default=str) + "\n") return str(path) except Exception as e: print(f"Failed to write diagnostics file: {e}") return "" async def attempt_insert_chunk( conn, upsert_sql: str, cleaned_rows: List[tuple], table_name: str, offset: int, dump_basename: str, batch_size: int = BATCH_SIZE, per_row_limit: int = 100, ) -> tuple[int, int, Optional[str]]: """Try inserting a chunk; on failure, split and try smaller batches then per-row. Returns (successful_rows, errors, diagnostics_path_or_None). """ success = 0 errors = 0 diagnostics_path = None # Fast path: whole chunk try: await conn.executemany(upsert_sql, cleaned_rows) return len(cleaned_rows), 0, None except Exception as e: # Fall through to batch splitting first_exc = str(e) # Try batch splitting for i in range(0, len(cleaned_rows), batch_size): batch = cleaned_rows[i : i + batch_size] try: await conn.executemany(upsert_sql, batch) success += len(batch) continue except Exception: # Try per-row for this batch failing = [] for j, row in enumerate(batch): try: await conn.execute(upsert_sql, *row) success += 1 except Exception as re_exc: errors += 1 failing.append((j + i, row, str(re_exc))) # If too many per-row failures, stop collecting if errors >= per_row_limit: break # If failures were recorded, save diagnostics if failing: # Prepare rows for saving (row index and values) rows_to_save = [r[1] for r in failing] diagnostics_path = await save_failed_rows( dump_basename, table_name, offset + i, rows_to_save, first_exc ) # Continue to next batch return success, errors, diagnostics_path async def _quote_ident(name: str) -> str: """Return a safely quoted SQL identifier.""" return '"' + name.replace('"', '""') + '"' async def fallback_upsert_using_temp( conn, table_name: str, common_columns: List[str], pk_columns: List[str], rows_iterable: List[tuple], dump_basename: str, ) -> tuple[int, int, Optional[str]]: """Fallback upsert that uses a temporary table and performs UPDATE then INSERT. Returns (rows_upserted_count, errors_count, diagnostics_path_or_None). """ errors = 0 diagnostics = None inserted = 0 updated = 0 # Build identifiers tmp_name = f"tmp_upsert_{table_name}_{os.getpid()}" # per-process unique-ish q_table = await _quote_ident(table_name) q_tmp = await _quote_ident(tmp_name) col_names_str = ", ".join([f'"{c}"' for c in common_columns]) # Create temp table like target try: await conn.execute(f"CREATE TEMP TABLE {q_tmp} (LIKE {q_table} INCLUDING ALL)") except Exception as e: return ( 0, 1, await save_failed_rows( dump_basename, table_name, 0, rows_iterable[:100], str(e) ), ) # Insert into temp (in batches) try: # Insert all rows into temp insert_sql = f"INSERT INTO {q_tmp} ({col_names_str}) VALUES ({', '.join([f'${i + 1}' for i in range(len(common_columns))])})" for i in range(0, len(rows_iterable), BATCH_SIZE): batch = rows_iterable[i : i + BATCH_SIZE] try: await conn.executemany(insert_sql, batch) except Exception as e: # Save diagnostics and abort diagnostics = await save_failed_rows( dump_basename, table_name, i, [tuple(r) for r in batch[:100]], str(e), ) errors += len(batch) return 0, errors, diagnostics # Count tmp rows tmp_count = await conn.fetchval(f"SELECT COUNT(*) FROM {q_tmp}") # Count matches in target (existing rows) pk_cond = " AND ".join( ["t." + f'"{k}"' + " = tmp." 
+ f'"{k}"' for k in pk_columns] ) match_count = await conn.fetchval( f"SELECT COUNT(*) FROM {q_tmp} tmp JOIN {q_table} t ON {pk_cond}" ) # Perform update of matching rows (only non-PK columns) update_cols = [c for c in common_columns if c not in pk_columns] if update_cols: set_clauses = ", ".join( ["{q_table}." + f'"{c}"' + " = tmp." + f'"{c}"' for c in update_cols] ) await conn.execute( f"UPDATE {q_table} SET {set_clauses} FROM {q_tmp} tmp WHERE {pk_cond}" ) updated = match_count # Insert non-existing rows insert_sql2 = f"INSERT INTO {q_table} ({col_names_str}) SELECT {col_names_str} FROM {q_tmp} tmp WHERE NOT EXISTS (SELECT 1 FROM {q_table} t WHERE {pk_cond})" await conn.execute(insert_sql2) # Approximate inserted count inserted = tmp_count - match_count except Exception as e: diagnostics = await save_failed_rows( dump_basename, table_name, 0, rows_iterable[:100], str(e) ) errors += 1 finally: # Drop temp table try: await conn.execute(f"DROP TABLE IF EXISTS {q_tmp}") except Exception: pass return inserted + updated, errors, diagnostics async def validate_sample_rows( sqlite_db: aiosqlite.Connection, pg_col_types: List[str], col_names_str: str, table_name: str, dump_basename: str, sample_n: int = 1000, ) -> tuple[int, int, Optional[str]]: """Validate a sample of rows by attempting to coerce values and detecting likely mismatches. Returns (fail_count, total_checked, diagnostics_path_or_None). """ fail_count = 0 total = 0 failing_rows = [] async with sqlite_db.execute( f"SELECT {col_names_str} FROM {table_name} LIMIT {sample_n};" ) as cursor: rows = await cursor.fetchall() for i, row in enumerate(rows): total += 1 # Coerce and check types for val, pg_type in zip(row, pg_col_types): coerced = coerce_value_for_pg(val, pg_type) # Heuristics for likely failure pg_type_l = (pg_type or "").lower() if "timestamp" in pg_type_l or "date" in pg_type_l or "time" in pg_type_l: if coerced is None or not isinstance(coerced, datetime): fail_count += 1 failing_rows.append(row) break elif "boolean" in pg_type_l: if not isinstance(coerced, bool): fail_count += 1 failing_rows.append(row) break elif any(k in pg_type_l for k in ("int", "bigint", "smallint")): if not isinstance(coerced, int): # allow numeric strings that can be int if not (isinstance(coerced, str) and coerced.isdigit()): fail_count += 1 failing_rows.append(row) break # else: assume text/json ok diag = None if failing_rows: rows_to_save = [tuple(r) for r in failing_rows[:100]] diag = await save_failed_rows( dump_basename, f"{table_name}.sample", 0, rows_to_save, "Validation failures", ) return fail_count, total, diag def determine_conflict_columns( pg_schema: Dict[str, Any], sqlite_schema: List[Dict[str, Any]], table_name: str, ) -> List[str]: """Determine which columns to use for conflict detection. Priority: 1. PostgreSQL primary key 2. PostgreSQL unique constraint 3. 
SQLite primary key (only if PostgreSQL has a matching unique/PK constraint) """ # 1) PostgreSQL primary key if pg_schema["pk_columns"]: return pg_schema["pk_columns"] # 2) PostgreSQL unique constraint if pg_schema["unique_constraints"]: first_constraint = list(pg_schema["unique_constraints"].values())[0] return first_constraint # 3) Consider SQLite primary key only if PostgreSQL has a matching constraint sqlite_pk_columns = [col["name"] for col in sqlite_schema if col["pk"] > 0] sqlite_pk_columns = [ col["name"] for col in sorted( [c for c in sqlite_schema if c["pk"] > 0], key=lambda x: x["pk"] ) ] if sqlite_pk_columns: # Check if any PG unique constraint matches these columns (order-insensitive) for cols in pg_schema["unique_constraints"].values(): if set(cols) == set(sqlite_pk_columns): return sqlite_pk_columns # Also accept if PG has a primary key covering same columns (already handled), so if we reach here # PG has no matching constraint — cannot safely perform ON CONFLICT based upsert raise ValueError( f"Table {table_name} has a SQLite primary key ({', '.join(sqlite_pk_columns)}) but no matching unique/PK constraint in PostgreSQL. Cannot perform upsert." ) # Nothing found raise ValueError( f"Table {table_name} has no primary key or unique constraint in either PostgreSQL or SQLite. Cannot perform upsert." ) def build_upsert_sql( table_name: str, columns: List[str], conflict_columns: List[str], update_columns: Optional[List[str]] = None, ) -> str: """Build PostgreSQL upsert SQL with $1, $2, etc. placeholders.""" col_names_str = ", ".join([f'"{col}"' for col in columns]) placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))]) conflict_str = ", ".join([f'"{col}"' for col in conflict_columns]) if update_columns is None: update_columns = [c for c in columns if c not in conflict_columns] if not update_columns: return f"INSERT INTO {table_name} ({col_names_str}) VALUES ({placeholders}) ON CONFLICT ({conflict_str}) DO NOTHING" set_clauses = [f'"{col}" = EXCLUDED."{col}"' for col in update_columns] set_str = ", ".join(set_clauses) return f"INSERT INTO {table_name} ({col_names_str}) VALUES ({placeholders}) ON CONFLICT ({conflict_str}) DO UPDATE SET {set_str}" async def upsert_table( sqlite_db: aiosqlite.Connection, pg_pool: asyncpg.Pool, table_name: str, dry_run: bool = False, verbose: bool = False, dump_basename: Optional[str] = None, ) -> TableReport: """Upsert data from SQLite table to PostgreSQL table.""" import time start_time = time.time() report = TableReport(table_name=table_name) # Get schemas sqlite_schema = await get_sqlite_schema(sqlite_db, table_name) pg_schema = await get_pg_schema(pg_pool, table_name) if not pg_schema["columns"]: report.status = "skipped" report.error_message = "Table does not exist in PostgreSQL" print(f" ⚠ Table {table_name} does not exist in PostgreSQL, skipping...") return report # Match columns sqlite_columns = [col["name"] for col in sqlite_schema] pg_columns = [col["name"] for col in pg_schema["columns"]] common_columns = [col for col in sqlite_columns if col in pg_columns] if not common_columns: report.status = "skipped" report.error_message = "No common columns" print(f" ⚠ No common columns for table {table_name}") return report if verbose: print(f" Common columns: {', '.join(common_columns)}") print(f" PG Primary keys: {pg_schema['pk_columns']}") # Determine conflict columns (tries PG first, then SQLite) use_temp_fallback = False use_exists_upsert = False sqlite_pk_columns = [] pk_column = None try: conflict_columns = 
determine_conflict_columns( pg_schema, sqlite_schema, table_name ) except ValueError as e: sqlite_pk_columns = [col["name"] for col in sqlite_schema if col["pk"] > 0] if sqlite_pk_columns: # Prefer existence-based upsert for single-column PKs; fallback to temp-table for composite PKs if len(sqlite_pk_columns) == 1: print( f" ⚠ Table {table_name} has a SQLite primary key ({sqlite_pk_columns[0]}) but no matching unique/PK constraint in PostgreSQL. Will use existence-based update/insert fallback." ) use_exists_upsert = True pk_column = sqlite_pk_columns[0] conflict_columns = [pk_column] else: print( f" ⚠ Table {table_name} has a SQLite primary key ({', '.join(sqlite_pk_columns)}) but no matching unique/PK constraint in PostgreSQL. Will attempt safer temp-table fallback." ) use_temp_fallback = True conflict_columns = sqlite_pk_columns else: report.status = "skipped" report.error_message = str(e) print(f" ⚠ {e}") return report if verbose: print(f" Using conflict columns: {conflict_columns}") # Build upsert SQL for standard ON CONFLICT approach upsert_sql = build_upsert_sql(table_name, common_columns, conflict_columns) # Get counts async with sqlite_db.execute(f"SELECT COUNT(*) FROM {table_name};") as cursor: row = await cursor.fetchone() report.sqlite_rows = row[0] if row else 0 async with pg_pool.acquire() as conn: report.pg_rows_before = await conn.fetchval( f"SELECT COUNT(*) FROM {table_name};" ) print( f" SQLite rows: {report.sqlite_rows:,}, PG rows before: {report.pg_rows_before:,}" ) if report.sqlite_rows == 0: report.status = "success" report.pg_rows_after = report.pg_rows_before print(" Table is empty, nothing to upsert.") return report # Prepare column query col_names_str = ", ".join([f'"{col}"' for col in common_columns]) # Prepare per-column PG types for validation and coercion pg_col_type_map = {c["name"]: c["type"] for c in pg_schema["columns"]} pg_col_types = [pg_col_type_map.get(col, "text") for col in common_columns] # Local chunk size (adaptive) chunk_size = CHUNK_SIZE # Pre-validate a sample of rows to catch common type mismatches early and adapt sample_n = min(1000, chunk_size, report.sqlite_rows) if sample_n > 0: fail_count, total_checked, sample_diag = await validate_sample_rows( sqlite_db, pg_col_types, col_names_str, table_name, dump_basename or "unknown_dump", sample_n, ) if fail_count: pct = (fail_count / total_checked) * 100 print( f" ⚠ Validation: {fail_count}/{total_checked} sample rows ({pct:.1f}%) appear problematic; reducing chunk size and writing diagnostics" ) if sample_diag: print(f" ⚠ Sample diagnostics: {sample_diag}") # Reduce chunk size to be safer old_chunk = chunk_size chunk_size = max(1000, chunk_size // 4) print(f" ⚠ Adjusted chunk_size: {old_chunk} -> {chunk_size}") # If we determined earlier that we need to use the temp-table fallback, handle that path now if use_temp_fallback: print( " Using temp-table fallback upsert (may be slower but safe on missing PG constraints)" ) # Process in chunks using cursor iteration but push to temp upsert helper offset = 0 async with pg_pool.acquire() as conn: while offset < report.sqlite_rows: async with sqlite_db.execute( f"SELECT {col_names_str} FROM {table_name} LIMIT {chunk_size} OFFSET {offset};" ) as cursor: rows = await cursor.fetchall() if not rows: break # Coerce cleaned_rows = [] for row in rows: coerced = [] for val, pg_type in zip(row, pg_col_types): coerced.append(coerce_value_for_pg(val, pg_type)) cleaned_rows.append(tuple(coerced)) # Attempt fallback upsert try: async with conn.transaction(): 
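                        # What the helper below does, per chunk, is roughly equivalent to the
                        # following SQL shape (illustrative, identifiers abbreviated):
                        #
                        #   CREATE TEMP TABLE tmp_upsert_<table>_<pid> (LIKE <table> INCLUDING ALL);
                        #   INSERT INTO tmp_upsert_... (<cols>) VALUES ($1, ...);        -- batched
                        #   UPDATE <table> SET <col> = tmp.<col>, ...
                        #       FROM tmp_upsert_... tmp WHERE <pk columns match>;
                        #   INSERT INTO <table> (<cols>)
                        #       SELECT <cols> FROM tmp_upsert_... tmp
                        #       WHERE NOT EXISTS (SELECT 1 FROM <table> t WHERE <pk columns match>);
                        #   DROP TABLE IF EXISTS tmp_upsert_...;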
upsert_count, errs, diag = await fallback_upsert_using_temp( conn, table_name, common_columns, sqlite_pk_columns, cleaned_rows, dump_basename or "unknown_dump", ) except Exception as e: upsert_count, errs, diag = ( 0, 1, await save_failed_rows( dump_basename or "unknown_dump", table_name, offset, cleaned_rows[:100], str(e), ), ) report.rows_affected += upsert_count report.errors += errs if diag: print(f"\n ⚠ Diagnostics written: {diag}") if report.errors > 10: print(" ✗ Too many errors, aborting table") report.status = "failed" report.error_message = f"Too many errors ({report.errors})" break offset += chunk_size # Set final counts async with pg_pool.acquire() as conn: report.pg_rows_after = await conn.fetchval( f"SELECT COUNT(*) FROM {table_name};" ) report.duration_seconds = time.time() - start_time if report.errors == 0: report.status = "success" elif report.errors <= 10: report.status = "success" else: report.status = "failed" return report if use_exists_upsert: print( " Using existence-based upsert (INSERT for new rows, UPDATE for existing rows)" ) # Process in chunks using cursor iteration offset = 0 async with pg_pool.acquire() as conn: while offset < report.sqlite_rows: async with sqlite_db.execute( f"SELECT {col_names_str} FROM {table_name} LIMIT {chunk_size} OFFSET {offset};" ) as cursor: rows = await cursor.fetchall() if not rows: break # Coerce cleaned_rows = [] for row in rows: coerced = [] for val, pg_type in zip(row, pg_col_types): coerced.append(coerce_value_for_pg(val, pg_type)) cleaned_rows.append(tuple(coerced)) # Build list of PK values pk_index = common_columns.index(pk_column) pk_values = [r[pk_index] for r in cleaned_rows] # Query existing PKs existing = set() try: rows_found = await conn.fetch( f"SELECT {pk_column} FROM {table_name} WHERE {pk_column} = ANY($1)", pk_values, ) existing = set([r[pk_column] for r in rows_found]) except Exception as e: # On error, save diagnostics and abort diag = await save_failed_rows( dump_basename or "unknown_dump", table_name, offset, cleaned_rows[:100], str(e), ) print(f"\n ⚠ Diagnostics written: {diag}") report.errors += 1 break insert_rows = [r for r in cleaned_rows if r[pk_index] not in existing] update_rows = [r for r in cleaned_rows if r[pk_index] in existing] # Perform inserts if insert_rows: col_names_csv = ", ".join([f'"{c}"' for c in common_columns]) placeholders = ", ".join( [f"${i + 1}" for i in range(len(common_columns))] ) insert_sql = f"INSERT INTO {table_name} ({col_names_csv}) VALUES ({placeholders})" try: await conn.executemany(insert_sql, insert_rows) report.rows_affected += len(insert_rows) except Exception as e: diag = await save_failed_rows( dump_basename or "unknown_dump", table_name, offset, insert_rows[:100], str(e), ) print(f"\n ⚠ Diagnostics written: {diag}") report.errors += 1 # Perform updates (per-row executemany) if update_rows: update_cols = [c for c in common_columns if c != pk_column] set_clause = ", ".join( [f'"{c}" = ${i + 1}' for i, c in enumerate(update_cols)] ) # For executemany we'll pass values ordered as [col1, col2, ..., pk] update_sql = f'UPDATE {table_name} SET {set_clause} WHERE "{pk_column}" = ${len(update_cols) + 1}' update_values = [] for r in update_rows: vals = [r[common_columns.index(c)] for c in update_cols] vals.append(r[pk_index]) update_values.append(tuple(vals)) try: await conn.executemany(update_sql, update_values) report.rows_affected += len(update_rows) except Exception as e: diag = await save_failed_rows( dump_basename or "unknown_dump", table_name, offset, 
                            update_rows[:100],
                            str(e),
                        )
                        print(f"\n ⚠ Diagnostics written: {diag}")
                        report.errors += 1

                if report.errors > 10:
                    print(" ✗ Too many errors, aborting table")
                    report.status = "failed"
                    report.error_message = f"Too many errors ({report.errors})"
                    break

                offset += chunk_size

        # Set final counts
        async with pg_pool.acquire() as conn:
            report.pg_rows_after = await conn.fetchval(
                f"SELECT COUNT(*) FROM {table_name};"
            )

        report.duration_seconds = time.time() - start_time
        if report.errors == 0:
            report.status = "success"
        elif report.errors <= 10:
            report.status = "success"
        else:
            report.status = "failed"
        return report

    # Process in chunks using cursor iteration
    offset = 0
    async with pg_pool.acquire() as conn:
        while offset < report.sqlite_rows:
            try:
                # Fetch chunk from SQLite
                async with sqlite_db.execute(
                    f"SELECT {col_names_str} FROM {table_name} LIMIT {chunk_size} OFFSET {offset};"
                ) as cursor:
                    rows = await cursor.fetchall()

                if not rows:
                    break

                # Coerce data per-column based on PG types
                pg_col_types = []
                pg_col_type_map = {c["name"]: c["type"] for c in pg_schema["columns"]}
                for col in common_columns:
                    pg_col_types.append(pg_col_type_map.get(col, "text"))

                cleaned_rows = []
                for row in rows:
                    coerced = []
                    for val, pg_type in zip(row, pg_col_types):
                        coerced.append(coerce_value_for_pg(val, pg_type))
                    cleaned_rows.append(tuple(coerced))

                if not dry_run:
                    # Try inserting chunk with adaptive recovery
                    dump_basename_local = dump_basename or "unknown_dump"
                    success_count, err_count, diag = await attempt_insert_chunk(
                        conn,
                        upsert_sql,
                        cleaned_rows,
                        table_name,
                        offset,
                        dump_basename_local,
                    )
                    report.rows_affected += success_count
                    report.errors += err_count
                    if diag:
                        print(f"\n ⚠ Diagnostics written: {diag}")
                        # Record in report for visibility
                        report.error_message = (
                            (report.error_message + "; ")
                            if report.error_message
                            else ""
                        ) + f"Diagnostics: {diag}"
                    if err_count > 0:
                        print(
                            f"\n ⚠ Error at offset {offset}: {err_count} problematic rows (see diagnostics)"
                        )
                    if report.errors > 10:
                        print(" ✗ Too many errors, aborting table")
                        report.status = "failed"
                        report.error_message = f"Too many errors ({report.errors})" + (
                            f"; {report.error_message}" if report.error_message else ""
                        )
                        break

                # Advance by the chunk size actually used in the SELECT above
                # (chunk_size may have been reduced adaptively); advancing by
                # CHUNK_SIZE would silently skip rows.
                offset += chunk_size
                progress = min(offset, report.sqlite_rows)
                pct = (progress / report.sqlite_rows) * 100
                print(
                    f" Progress: {progress:,}/{report.sqlite_rows:,} ({pct:.1f}%)",
                    end="\r",
                )

            except Exception as e:
                print(f"\n ⚠ Error at offset {offset}: {e}")
                report.errors += 1
                offset += chunk_size

                if report.errors > 10:
                    print(" ✗ Too many errors, aborting table")
                    report.status = "failed"
                    report.error_message = f"Too many errors ({report.errors})"
                    break

    print()  # Newline after progress

    # Get final count
    async with pg_pool.acquire() as conn:
        report.pg_rows_after = await conn.fetchval(
            f"SELECT COUNT(*) FROM {table_name};"
        )

    report.duration_seconds = time.time() - start_time
    if report.errors == 0:
        report.status = "success"
    elif report.errors <= 10:
        report.status = "success"
    else:
        report.status = "failed"

    return report


async def upsert_database(
    sqlite_path: str,
    tables: Optional[List[str]] = None,
    dry_run: bool = False,
    verbose: bool = False,
    dump_date: Optional[str] = None,
) -> UpsertReport:
    """Main upsert function - process SQLite dump into PostgreSQL."""
    report = UpsertReport(sqlite_source=sqlite_path, dump_date=dump_date)
    report.start_time = datetime.now()

    print(f"{'[DRY RUN] ' if dry_run else ''}SQLite → PostgreSQL Upsert (Async)")
    print(f"SQLite: {sqlite_path}")
    print(f"PostgreSQL: {PG_CONFIG['database']}@{PG_CONFIG['host']}")
    print(f"Chunk size:
{CHUNK_SIZE:,} rows\n") # Connect to databases print("Connecting to databases...") sqlite_db = await aiosqlite.connect(sqlite_path) pg_pool = await asyncpg.create_pool( host=PG_CONFIG["host"], port=PG_CONFIG["port"], database=PG_CONFIG["database"], user=PG_CONFIG["user"], password=PG_CONFIG["password"], min_size=2, max_size=10, ) try: # Get tables available_tables = await get_sqlite_tables(sqlite_db) if tables: process_tables = [t for t in tables if t in available_tables] missing = [t for t in tables if t not in available_tables] if missing: print(f"⚠ Tables not found in SQLite: {', '.join(missing)}") else: process_tables = available_tables print(f"Tables to process: {', '.join(process_tables)}\n") for i, table_name in enumerate(process_tables, 1): print(f"[{i}/{len(process_tables)}] Processing table: {table_name}") try: dump_basename = Path(sqlite_path).name table_report = await upsert_table( sqlite_db, pg_pool, table_name, dry_run, verbose, dump_basename ) report.tables.append(table_report) if table_report.status == "success": print( f" ✓ Completed: +{table_report.rows_inserted:,} inserted, ~{table_report.rows_updated:,} updated" ) elif table_report.status == "skipped": print(" ○ Skipped") else: print(" ✗ Failed") except Exception as e: table_report = TableReport( table_name=table_name, status="failed", error_message=str(e) ) report.tables.append(table_report) print(f" ✗ Failed: {e}") print() finally: await sqlite_db.close() await pg_pool.close() report.end_time = datetime.now() return report async def test_pg_connection() -> bool: """Test PostgreSQL connection before proceeding with download.""" print("Testing PostgreSQL connection...") try: conn = await asyncpg.connect( host=PG_CONFIG["host"], port=PG_CONFIG["port"], database=PG_CONFIG["database"], user=PG_CONFIG["user"], password=PG_CONFIG["password"], ) version = await conn.fetchval("SELECT version();") await conn.close() print(f" ✓ Connected to PostgreSQL: {version[:50]}...") return True except Exception as e: print(f" ✗ PostgreSQL connection failed: {e}") return False async def check_and_fetch_latest( force: bool = False, dest_dir: Optional[str] = None, dry_run: bool = False, verbose: bool = False, notify: bool = False, ) -> Optional[UpsertReport]: """Check for new dumps on lrclib.net and upsert if newer than last run.""" # Send start notification if notify: await notify_start() # Test PostgreSQL connection FIRST before downloading large files if not await test_pg_connection(): error_msg = "Cannot connect to PostgreSQL database" print(f"Aborting: {error_msg}") if notify: await notify_failure( error_msg, "Connection Test", f"Host: {PG_CONFIG['host']}:{PG_CONFIG['port']}", ) return None # Get latest dump info dump_info = await fetch_latest_dump_info() if not dump_info: error_msg = "Could not fetch latest dump info from lrclib.net" print(error_msg) if notify: await notify_failure( error_msg, "Fetch Dump Info", f"URL: {LRCLIB_DUMPS_URL}" ) return None dump_date = dump_info["date"] dump_date_str = dump_date.strftime("%Y-%m-%d %H:%M:%S") # Check against last upsert state = load_state() last_upsert_str = state.get("last_dump_date") # Full datetime string if last_upsert_str: try: # Try parsing with time first last_date = datetime.strptime(last_upsert_str, "%Y-%m-%d %H:%M:%S") except ValueError: # Fall back to date only last_date = datetime.strptime(last_upsert_str, "%Y-%m-%d") print(f"Last upsert dump date: {last_upsert_str}") if dump_date <= last_date and not force: print( f"No new dump available (latest: {dump_date_str}, last upsert: 
{last_upsert_str})" ) print("Use --force to upsert anyway") if notify: await notify_no_update(dump_date_str, last_upsert_str) return None print(f"New dump available: {dump_date_str} > {last_upsert_str}") else: print("No previous upsert recorded") if dry_run: print(f"[DRY RUN] Would download and upsert {dump_info['filename']}") return None # Notify about new dump found if notify: await notify_new_dump_found(dump_info["filename"], dump_date_str) # Download and extract sqlite_path, download_error = await download_and_extract_dump( dump_info["url"], dest_dir ) if not sqlite_path: error_msg = download_error or "Unknown download/extract error" print(f"Failed: {error_msg}") if notify: await notify_failure( error_msg, "Download/Extract", f"URL: {dump_info['url']}" ) return None # Keep files until we confirm import success to allow resume/inspection gz_path = str(Path(sqlite_path).with_suffix(Path(sqlite_path).suffix + ".gz")) print(f"Downloaded and extracted: {sqlite_path}") print(f"Keeping source compressed file for resume/inspection: {gz_path}") # Define SQLITE_DB_PATH dynamically after extraction (persist globally) global SQLITE_DB_PATH SQLITE_DB_PATH = sqlite_path # Set to the extracted SQLite file path print(f"SQLite database path set to: {SQLITE_DB_PATH}") upsert_successful = False try: # Perform upsert report = await upsert_database( sqlite_path=sqlite_path, tables=None, dry_run=False, verbose=verbose, dump_date=dump_date_str, ) # If we had no errors and at least one table was considered, determine if any rows were upserted if report.total_errors == 0 or len(report.successful_tables) > 0: # Determine if any table actually had rows inserted or updated any_upserted = any( (t.status == "success" and (t.rows_inserted > 0 or t.rows_updated > 0)) for t in report.tables ) # Mark overall upsert success (only true if rows were actually upserted) upsert_successful = any_upserted if upsert_successful: # Save state only when an actual upsert occurred to avoid skipping future attempts state["last_dump_date"] = dump_date_str # Date from the dump filename state["last_upsert_time"] = datetime.now().isoformat() # When we ran state["last_dump_url"] = dump_info["url"] state["last_dump_filename"] = dump_info["filename"] save_state(state) print(f"\nState saved: last_dump_date = {dump_date_str}") if notify: if any_upserted: await notify_success(report) else: # All tables skipped or empty (no actual upsert) await discord_notify( title="âš ī¸ LRCLib DB Update - No Data Upserted", description="Upsert completed, but no rows were inserted or updated. All tables may have been skipped due to missing primary keys or were empty. 
The dump files have been retained for inspection/resume.", color=DiscordColor.WARNING, fields=[ { "name": "Tables Processed", "value": str(len(report.tables)), "inline": True, }, { "name": "Successful Tables", "value": str(len(report.successful_tables)), "inline": True, }, { "name": "Rows Inserted", "value": str(report.total_rows_inserted), "inline": True, }, { "name": "Rows Updated", "value": str(report.total_rows_updated), "inline": True, }, ], footer=f"Source: {report.sqlite_source}", ) else: # All tables failed if notify: failed_tables = ", ".join(report.failed_tables[:5]) await notify_failure( "All tables failed to upsert", "Upsert", f"Failed tables: {failed_tables}", ) return report except Exception as e: error_msg = str(e) print(f"Upsert error: {error_msg}") if notify: import traceback tb = traceback.format_exc() await notify_failure(error_msg, "Upsert", tb[:1000]) raise finally: # Only remove files when a successful upsert occurred. Keep them otherwise for resume/debug. if upsert_successful: # Clean up downloaded SQLite file if sqlite_path and os.path.exists(sqlite_path): print(f"Cleaning up SQLite dump (success): {sqlite_path}") os.unlink(sqlite_path) # Also remove the .gz file if sqlite_path: gz_path = sqlite_path + ".gz" if os.path.exists(gz_path): print(f"Cleaning up compressed file (success): {gz_path}") os.unlink(gz_path) else: # Keep files for inspection/resume if sqlite_path and os.path.exists(sqlite_path): print(f"Retaining SQLite dump for inspection/resume: {sqlite_path}") if sqlite_path: gz_path = sqlite_path + ".gz" if os.path.exists(gz_path): print(f"Retaining compressed file for inspection/resume: {gz_path}") async def create_new_schema(pg_conn: asyncpg.Connection, schema_name: str) -> None: """Create a new schema for the migration.""" await pg_conn.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}") async def migrate_to_new_schema( sqlite_conn: aiosqlite.Connection, pg_conn: asyncpg.Connection, schema_name: str, table_name: str, columns: List[str], ) -> None: """Migrate data to a new schema.""" staging_table = f"{schema_name}.{table_name}" # Drop and recreate staging table in the new schema await pg_conn.execute(f"DROP TABLE IF EXISTS {staging_table}") await pg_conn.execute( f"CREATE TABLE {staging_table} (LIKE {table_name} INCLUDING ALL)" ) # Import data into the staging table sqlite_cursor = await sqlite_conn.execute( f"SELECT {', '.join(columns)} FROM {table_name}" ) rows = await sqlite_cursor.fetchall() await sqlite_cursor.close() # Convert rows to a list to ensure compatibility with len() rows = list(rows) print(f"Imported {len(rows)} rows into {staging_table}") placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))]) insert_sql = f"INSERT INTO {staging_table} ({', '.join([f'"{col}"' for col in columns])}) VALUES ({placeholders})" await pg_conn.executemany(insert_sql, rows) print(f"Imported {len(rows)} rows into {staging_table}") async def swap_schemas(pg_conn: asyncpg.Connection, new_schema: str) -> None: """Swap the new schema with the public schema.""" async with pg_conn.transaction(): # Rename public schema to a backup schema await pg_conn.execute("ALTER SCHEMA public RENAME TO backup") # Rename new schema to public await pg_conn.execute(f"ALTER SCHEMA {new_schema} RENAME TO public") # Drop the backup schema await pg_conn.execute("DROP SCHEMA backup CASCADE") print(f"Swapped schema {new_schema} with public") async def migrate_database(sqlite_path: Optional[str] = None): """Main migration function. If `sqlite_path` is provided, use it. 
Otherwise fall back to the global `SQLITE_DB_PATH`. """ path = sqlite_path or SQLITE_DB_PATH if not path: raise ValueError( "No SQLite path provided. Call migrate_database(sqlite_path=...) or set SQLITE_DB_PATH before invoking." ) sqlite_conn = await aiosqlite.connect(path) pg_conn = await asyncpg.connect(**PG_CONFIG) new_schema = "staging" try: # Create a new schema for the migration await create_new_schema(pg_conn, new_schema) sqlite_tables = await sqlite_conn.execute( "SELECT name FROM sqlite_master WHERE type='table'" ) tables = [row[0] for row in await sqlite_tables.fetchall()] await sqlite_tables.close() for table_name in tables: sqlite_cursor = await sqlite_conn.execute( f"PRAGMA table_info({table_name})" ) columns = [row[1] for row in await sqlite_cursor.fetchall()] await sqlite_cursor.close() # Migrate data to the new schema await migrate_to_new_schema( sqlite_conn, pg_conn, new_schema, table_name, columns ) # Swap schemas after successful migration await swap_schemas(pg_conn, new_schema) finally: await sqlite_conn.close() await pg_conn.close() # Ensure SQLITE_DB_PATH is defined globally before usage SQLITE_DB_PATH = None # Initialize as None to avoid NameError if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="LRCLib DB upsert / migration utility") parser.add_argument( "--migrate", action="store_true", help="Migrate a local SQLite file into a new staging schema and swap", ) parser.add_argument( "--sqlite", type=str, help="Path to local SQLite file to migrate (required for --migrate)", ) parser.add_argument( "--check", action="store_true", help="Check lrclib.net for latest dump and upsert if newer", ) parser.add_argument( "--force", action="store_true", help="Force upsert even if dump is not newer" ) parser.add_argument( "--dest", type=str, help="Destination directory for downloaded dumps" ) parser.add_argument( "--dry-run", action="store_true", help="Do not perform writes; just simulate" ) parser.add_argument("--verbose", action="store_true", help="Verbose output") parser.add_argument( "--notify", action="store_true", help="Send Discord notifications" ) args = parser.parse_args() if args.migrate: if not args.sqlite: parser.error("--migrate requires --sqlite PATH") print(f"Starting schema migration from local SQLite: {args.sqlite}") asyncio.run(migrate_database(sqlite_path=args.sqlite)) else: # Default behavior: check lrclib.net and upsert if needed print("Checking for latest dump and performing upsert (if applicable)...") asyncio.run( check_and_fetch_latest( force=args.force, dest_dir=args.dest, dry_run=args.dry_run, verbose=args.verbose, notify=args.notify, ) )
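

# Example invocations (illustrative; "upsert_lrclib.py" stands in for whatever
# filename this script is saved under, and the flags are the ones defined by the
# argparse setup above):
#
#   # Default: check lrclib.net for a newer dump and upsert it, with Discord updates
#   python upsert_lrclib.py --notify
#
#   # Re-run even if the latest dump is not newer, verbose, custom download directory
#   python upsert_lrclib.py --force --dest /nvme/tmp --verbose
#
#   # Report which dump would be downloaded without writing anything
#   python upsert_lrclib.py --dry-run
#
#   # One-off migration of a local SQLite file through the staging schema
#   python upsert_lrclib.py --migrate --sqlite /nvme/tmp/lrclib-db-dump-20260122T091251Z.sqlite3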