diff --git a/migrate_sqlite_to_pg.py b/migrate_sqlite_to_pg.py new file mode 100644 index 0000000..fc785c5 --- /dev/null +++ b/migrate_sqlite_to_pg.py @@ -0,0 +1,711 @@ +#!/usr/bin/env python3 +""" +SQLite -> PostgreSQL migrator. + +This script integrates the download/check/notification helpers from +`update_lrclib_db.py` and migrates a given SQLite dump into a new +**staging** PostgreSQL database named with the dump date. After a +successful import the staging DB is swapped to replace the production +DB (with a backup rename of the previous DB). + +Usage examples: + - Default (auto-fetch + notify): ./migrate_sqlite_to_pg.py + - Migrate a local file: ./migrate_sqlite_to_pg.py --sqlite /path/to/db.sqlite3 + - Disable notifications: ./migrate_sqlite_to_pg.py --no-notify + - Force re-import: ./migrate_sqlite_to_pg.py --force +""" +from __future__ import annotations + +import argparse +import asyncio +import io +import os +import sqlite3 +import sys +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional + +import psycopg2 # type: ignore[import] +from dotenv import load_dotenv + +# Import helpers from update_lrclib_db.py (async functions) +from update_lrclib_db import ( + fetch_latest_dump_info, + download_and_extract_dump, + parse_dump_date, + notify_start, + notify_new_dump_found, + notify_success, + notify_failure, + load_state, + save_state, + UpsertReport, + TableReport, +) + +load_dotenv() + +# ---------------------- Config ---------------------- + +PG_HOST = os.getenv("POSTGRES_HOST", "localhost") +PG_PORT = int(os.getenv("POSTGRES_PORT", "5432")) +PG_USER = os.getenv("POSTGRES_USER", "api") +PG_PASSWORD = os.getenv("POSTGRES_PASSWORD", "") +PG_DATABASE = os.getenv("POSTGRES_DB", "lrclib") +CHUNK_SIZE = int(os.getenv("MIGRATE_CHUNK_SIZE", "100000")) # 100k rows per batch +DEFAULT_DEST_DIR = Path(os.getenv("LRCLIB_DUMP_DIR", "/nvme/tmp")) + + +# ---------------------- SQLite -> PostgreSQL helpers ---------------------- + + +def sqlite_to_pg_type(sqlite_type: str) -> str: + """Convert SQLite type to PostgreSQL type.""" + if not sqlite_type: + return "TEXT" + sqlite_type = sqlite_type.upper() + mapping = { + "INTEGER": "BIGINT", + "TEXT": "TEXT", + "REAL": "DOUBLE PRECISION", + "FLOAT": "DOUBLE PRECISION", + "BLOB": "BYTEA", + "NUMERIC": "NUMERIC", + "BOOLEAN": "BOOLEAN", + "DATE": "DATE", + "DATETIME": "TIMESTAMP", + "TIMESTAMP": "TIMESTAMPTZ", + } + for key, pg_type in mapping.items(): + if key in sqlite_type: + return pg_type + return "TEXT" + + +def get_sqlite_tables(conn: sqlite3.Connection) -> list[str]: + """Get list of user tables from SQLite database.""" + cur = conn.cursor() + cur.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';" + ) + return [row[0] for row in cur.fetchall()] + + +def get_table_schema(conn: sqlite3.Connection, table: str) -> list[tuple[str, str]]: + """Get column names and types for a SQLite table.""" + cur = conn.cursor() + cur.execute(f"PRAGMA table_info({table});") + return [(row[1], row[2]) for row in cur.fetchall()] + + +def clean_row(row: tuple, columns: list[tuple[str, str]]) -> tuple: + """Clean row data for PostgreSQL insertion.""" + cleaned: list[Any] = [] + for i, value in enumerate(row): + col_name, col_type = columns[i] + pg_type = sqlite_to_pg_type(col_type) + + if value is None: + cleaned.append(None) + elif pg_type == "BOOLEAN" and isinstance(value, int): + cleaned.append(bool(value)) + elif isinstance(value, str): + # Remove NULL bytes which PostgreSQL text 
fields reject + cleaned.append(value.replace("\x00", "")) + else: + cleaned.append(value) + + return tuple(cleaned) + + +def escape_copy_value(value, pg_type: str) -> str: + """Escape a value for PostgreSQL COPY format (tab-separated).\n + This is much faster than INSERT for bulk loading. + """ + if value is None: + return "\\N" # NULL marker for COPY + + if pg_type == "BOOLEAN" and isinstance(value, int): + return "t" if value else "f" + + if isinstance(value, bool): + return "t" if value else "f" + + if isinstance(value, (int, float)): + return str(value) + + if isinstance(value, bytes): + # BYTEA: encode as hex + return "\\\\x" + value.hex() + + # String: escape special characters for COPY + s = str(value) + s = s.replace("\x00", "") # Remove NULL bytes + s = s.replace("\\", "\\\\") # Escape backslashes + s = s.replace("\t", "\\t") # Escape tabs + s = s.replace("\n", "\\n") # Escape newlines + s = s.replace("\r", "\\r") # Escape carriage returns + return s + + +def create_table( + pg_conn, table: str, columns: list[tuple[str, str]], unlogged: bool = True +) -> None: + """Create a table in PostgreSQL based on SQLite schema. + + Uses UNLOGGED tables by default for faster bulk import (no WAL writes). + """ + cur = pg_conn.cursor() + + # Check if table exists + cur.execute( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = %s);", + (table,), + ) + if cur.fetchone()[0]: + print(f" Table {table} already exists, skipping create") + return + + col_defs = [f'"{name}" {sqlite_to_pg_type(coltype)}' for name, coltype in columns] + unlogged_kw = "UNLOGGED" if unlogged else "" + sql = f'CREATE {unlogged_kw} TABLE "{table}" ({", ".join(col_defs)});' + print(f" Creating {'UNLOGGED ' if unlogged else ''}table: {table}") + cur.execute(sql) + pg_conn.commit() + + +def migrate_table( + sqlite_conn: sqlite3.Connection, + pg_conn, + table: str, + columns: list[tuple[str, str]], +) -> TableReport: + """Migrate data using PostgreSQL COPY for maximum speed. + + COPY is 5-10x faster than INSERT for bulk loading. 
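+
+    Each batch of CHUNK_SIZE rows is read from SQLite with LIMIT/OFFSET,
+    serialized in memory to COPY text format (tab-separated values, NULL
+    encoded as \\N by escape_copy_value), streamed with copy_expert, and
+    committed. A failed batch is rolled back and skipped; the table is
+    marked as failed after more than 10 batch errors.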
+ """ + import time + + start = time.time() + report = TableReport(table_name=table) + + sqlite_cur = sqlite_conn.cursor() + pg_cur = pg_conn.cursor() + + col_names = [c[0] for c in columns] + col_names_quoted = ", ".join([f'"{c}"' for c in col_names]) + pg_types = [sqlite_to_pg_type(c[1]) for c in columns] + + # Get row count + sqlite_cur.execute(f'SELECT COUNT(*) FROM "{table}";') + report.sqlite_rows = sqlite_cur.fetchone()[0] + report.pg_rows_before = 0 # Fresh table + + if report.sqlite_rows == 0: + print(f" Skipping {table}: empty table") + report.pg_rows_after = 0 + report.status = "success" + report.duration_seconds = time.time() - start + return report + + print(f" Importing {table}: {report.sqlite_rows:,} rows using COPY") + + offset = 0 + total_copied = 0 + errors = 0 + + while offset < report.sqlite_rows: + try: + sqlite_cur.execute( + f'SELECT {col_names_quoted} FROM "{table}" LIMIT {CHUNK_SIZE} OFFSET {offset};' + ) + rows = sqlite_cur.fetchall() + if not rows: + break + + # Build COPY data in memory + buffer = io.StringIO() + for row in rows: + line_parts = [ + escape_copy_value(val, pg_types[i]) for i, val in enumerate(row) + ] + buffer.write("\t".join(line_parts) + "\n") + + buffer.seek(0) + + # Use COPY FROM for fast bulk insert + pg_cur.copy_expert( + f'COPY "{table}" ({col_names_quoted}) FROM STDIN WITH (FORMAT text)', + buffer, + ) + pg_conn.commit() + + total_copied += len(rows) + offset += CHUNK_SIZE + pct = min(100.0, (offset / report.sqlite_rows) * 100) + elapsed = time.time() - start + rate = total_copied / elapsed if elapsed > 0 else 0 + print( + f" Progress: {min(offset, report.sqlite_rows):,}/{report.sqlite_rows:,} ({pct:.1f}%) - {rate:,.0f} rows/sec", + end="\r", + ) + + except Exception as e: + print(f"\n ⚠ Error at offset {offset} for table {table}: {e}") + errors += 1 + report.errors += 1 + offset += CHUNK_SIZE + pg_conn.rollback() + + if errors > 10: + report.status = "failed" + report.error_message = f"Too many errors ({errors})" + break + + print() # newline after progress + + # Final count + pg_cur.execute(f'SELECT COUNT(*) FROM "{table}";') + report.pg_rows_after = pg_cur.fetchone()[0] + report.rows_affected = total_copied + report.duration_seconds = time.time() - start + + if report.status != "failed": + report.status = "success" + + rate = ( + report.rows_affected / report.duration_seconds + if report.duration_seconds > 0 + else 0 + ) + print( + f" ✓ Table {table}: {report.rows_affected:,} rows in {report.duration_seconds:.1f}s ({rate:,.0f} rows/sec)" + ) + return report + + +# ---------------------- PostgreSQL DB management ---------------------- + + +def pg_connect(dbname: Optional[str] = None): + """Connect to PostgreSQL database.""" + return psycopg2.connect( + host=PG_HOST, + port=PG_PORT, + database=dbname or "postgres", + user=PG_USER, + password=PG_PASSWORD, + ) + + +def create_database(db_name: str) -> None: + """Create a new PostgreSQL database.""" + print(f"Creating database: {db_name}") + conn = pg_connect("postgres") + conn.autocommit = True + cur = conn.cursor() + + cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (db_name,)) + if cur.fetchone(): + print(f" Database {db_name} already exists") + cur.close() + conn.close() + return + + # Try creating database; handle collation mismatch gracefully + try: + cur.execute(f'CREATE DATABASE "{db_name}" OWNER {PG_USER};') + print(f" Created database {db_name}") + except Exception as e: + err = str(e) + if "collation version" in err or "template1" in err: + print( + " ⚠ Detected 
collation mismatch on template1; retrying with TEMPLATE template0" + ) + cur.execute( + f'CREATE DATABASE "{db_name}" OWNER {PG_USER} TEMPLATE template0;' + ) + print(f" Created database {db_name} using TEMPLATE template0") + else: + raise + + cur.close() + conn.close() + + +def terminate_connections(db_name: str, max_wait: int = 10) -> bool: + """Terminate all connections to a database. + + Returns True if all connections were terminated, False if some remain. + Won't fail on permission errors (e.g., can't terminate superuser connections). + """ + import time + + conn = pg_connect("postgres") + conn.autocommit = True + cur = conn.cursor() + + for attempt in range(max_wait): + # Check how many connections exist + cur.execute( + "SELECT COUNT(*) FROM pg_stat_activity WHERE datname = %s AND pid <> pg_backend_pid();", + (db_name,), + ) + row = cur.fetchone() + count = int(row[0]) if row else 0 + + if count == 0: + cur.close() + conn.close() + return True + + print(f" Terminating {count} connection(s) to {db_name}...") + + # Try to terminate - ignore errors for connections we can't kill + try: + cur.execute( + """ + SELECT pg_terminate_backend(pid) + FROM pg_stat_activity + WHERE datname = %s + AND pid <> pg_backend_pid() + AND usename = current_user; -- Only terminate our own connections + """, + (db_name,), + ) + except Exception as e: + print(f" Warning: {e}") + + # Brief wait for connections to close + time.sleep(1) + + # Final check + cur.execute( + "SELECT COUNT(*) FROM pg_stat_activity WHERE datname = %s AND pid <> pg_backend_pid();", + (db_name,), + ) + row = cur.fetchone() + remaining = int(row[0]) if row else 0 + cur.close() + conn.close() + + if remaining > 0: + print(f" Warning: {remaining} connection(s) still active (may be superuser sessions)") + return False + return True + + +def database_exists(db_name: str) -> bool: + """Check if a database exists.""" + conn = pg_connect("postgres") + conn.autocommit = True + cur = conn.cursor() + cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (db_name,)) + exists = cur.fetchone() is not None + cur.close() + conn.close() + return exists + + +def rename_database(old_name: str, new_name: str) -> None: + """Rename a PostgreSQL database.""" + conn = pg_connect("postgres") + conn.autocommit = True + cur = conn.cursor() + print(f" Renaming {old_name} -> {new_name}") + cur.execute(f'ALTER DATABASE "{old_name}" RENAME TO "{new_name}";') + cur.close() + conn.close() + + +def drop_database(db_name: str) -> bool: + """Drop a PostgreSQL database. + + Returns True if dropped, False if failed (e.g., active connections). + """ + # First try to terminate connections + terminate_connections(db_name) + + conn = pg_connect("postgres") + conn.autocommit = True + cur = conn.cursor() + print(f" Dropping database {db_name}") + try: + cur.execute(f'DROP DATABASE IF EXISTS "{db_name}";') + cur.close() + conn.close() + return True + except Exception as e: + print(f" Warning: Could not drop {db_name}: {e}") + cur.close() + conn.close() + return False + conn.close() + + +# ---------------------- Main migration flow ---------------------- + + +def run_migration( + sqlite_path: str, + dump_dt: Optional[datetime], + notify: bool = True, + dry_run: bool = False, +) -> UpsertReport: + """ + Create a timestamped staging DB, import data, then swap to production. + + Returns an UpsertReport with migration details. 
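+
+    The staging database is named "{PG_DATABASE}_{timestamp}". On success the
+    current production database is renamed to "{PG_DATABASE}_backup_{timestamp}"
+    and the staging database is renamed to take its place; on failure the
+    staging database is dropped and production is left untouched.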
+ """ + dump_dt = dump_dt or datetime.now(timezone.utc) + ts = dump_dt.strftime("%Y%m%dT%H%M%SZ") + staging_db = f"{PG_DATABASE}_{ts}" + backup_db = f"{PG_DATABASE}_backup_{ts}" + + # Initialize report + report = UpsertReport(sqlite_source=sqlite_path, dump_date=ts) + report.start_time = datetime.now(timezone.utc) + + print(f"\n{'=' * 60}") + print("SQLite -> PostgreSQL Migration") + print(f"{'=' * 60}") + print(f"Source: {sqlite_path}") + print(f"Staging DB: {staging_db}") + print(f"Production: {PG_DATABASE}") + print(f"Chunk size: {CHUNK_SIZE:,}") + print(f"{'=' * 60}\n") + + if dry_run: + print("DRY RUN - no changes will be made") + return report + + # Create staging database + create_database(staging_db) + + sqlite_conn = None + pg_conn = None + + try: + sqlite_conn = sqlite3.connect(sqlite_path) + pg_conn = pg_connect(staging_db) + + tables = get_sqlite_tables(sqlite_conn) + print(f"Found {len(tables)} tables: {', '.join(tables)}\n") + + for i, table in enumerate(tables, 1): + print(f"[{i}/{len(tables)}] Processing: {table}") + columns = get_table_schema(sqlite_conn, table) + create_table(pg_conn, table, columns) + table_report = migrate_table(sqlite_conn, pg_conn, table, columns) + report.tables.append(table_report) + print() + + # Close connections before renaming + pg_conn.close() + pg_conn = None + sqlite_conn.close() + sqlite_conn = None + + # Perform the database swap + print(f"\n{'=' * 60}") + print("Migration complete — performing database swap") + print(f"{'=' * 60}\n") + + terminate_connections(PG_DATABASE) + terminate_connections(staging_db) + + if database_exists(PG_DATABASE): + print(f" Backing up production DB to {backup_db}") + rename_database(PG_DATABASE, backup_db) + else: + print(f" Production DB {PG_DATABASE} does not exist; nothing to back up") + + rename_database(staging_db, PG_DATABASE) + print("\n✓ Database swap complete!") + + # Update state + state = load_state() + state["last_dump_date"] = dump_dt.strftime("%Y-%m-%d %H:%M:%S") + state["last_upsert_time"] = datetime.now(timezone.utc).isoformat() + save_state(state) + + report.end_time = datetime.now(timezone.utc) + + # Send success notification + if notify: + asyncio.run(notify_success(report)) + + return report + + except Exception as e: + print(f"\n✗ Migration failed: {e}") + + # Cleanup staging DB on failure + try: + if pg_conn: + pg_conn.close() + terminate_connections(staging_db) + drop_database(staging_db) + except Exception as cleanup_err: + print(f" Cleanup error: {cleanup_err}") + + if notify: + asyncio.run(notify_failure(str(e), "Migration")) + + raise + + finally: + if sqlite_conn: + try: + sqlite_conn.close() + except Exception: + pass + if pg_conn: + try: + pg_conn.close() + except Exception: + pass + + +# ---------------------- CLI ---------------------- + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Migrate SQLite dump to PostgreSQL with atomic swap", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s # Auto-fetch latest dump and migrate + %(prog)s --sqlite /path/to/db # Migrate local SQLite file + %(prog)s --force # Force migration even if dump is not newer + %(prog)s --no-notify # Disable Discord notifications + %(prog)s --dry-run # Simulate without making changes + """, + ) + parser.add_argument( + "--sqlite", + type=str, + metavar="PATH", + help="Path to local SQLite file (skips auto-fetch if provided)", + ) + parser.add_argument( + "--dest", + type=str, + metavar="DIR", + help="Destination directory for downloaded dumps", 
+ ) + parser.add_argument( + "--no-notify", + action="store_true", + help="Disable Discord notifications (enabled by default)", + ) + parser.add_argument( + "--force", + action="store_true", + help="Force migration even if dump date is not newer", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Simulate without making changes", + ) + + args = parser.parse_args() + notify_enabled = not args.no_notify + + sqlite_path: Optional[str] = None + dump_dt: Optional[datetime] = None + + try: + if args.sqlite: + # Use provided SQLite file + sqlite_path = args.sqlite + assert sqlite_path is not None + if not Path(sqlite_path).exists(): + print(f"Error: SQLite file not found: {sqlite_path}") + sys.exit(1) + + # Try to parse date from filename + parsed = parse_dump_date(Path(sqlite_path).name) + dump_dt = parsed or datetime.now(timezone.utc) + + else: + # Auto-fetch latest dump + if notify_enabled: + asyncio.run(notify_start()) + + print("Checking for latest dump on lrclib.net...") + latest = asyncio.run(fetch_latest_dump_info()) + + if not latest: + print("Error: Could not fetch latest dump info") + if notify_enabled: + asyncio.run(notify_failure("Could not fetch dump info", "Fetch")) + sys.exit(1) + + dump_date = latest["date"] + dump_date_str = dump_date.strftime("%Y-%m-%d %H:%M:%S") + print(f"Latest dump: {latest['filename']} ({dump_date_str})") + + # Check if we need to update + state = load_state() + last_str = state.get("last_dump_date") + + if last_str and not args.force: + try: + last_dt = datetime.strptime(last_str, "%Y-%m-%d %H:%M:%S") + if dump_date <= last_dt: + print(f"No new dump available (last: {last_str})") + print("Use --force to migrate anyway") + sys.exit(0) + except ValueError: + pass # Invalid date format, proceed with migration + + print(f"New dump available: {dump_date_str}") + + if notify_enabled: + asyncio.run( + notify_new_dump_found(latest["filename"], dump_date_str) + ) + + # Download + print(f"\nDownloading {latest['filename']}...") + db_path, err = asyncio.run( + download_and_extract_dump(latest["url"], args.dest) + ) + + if not db_path: + print(f"Error: Download failed: {err}") + if notify_enabled: + asyncio.run(notify_failure(str(err), "Download")) + sys.exit(1) + + sqlite_path = db_path + dump_dt = dump_date + + # Run migration + assert sqlite_path is not None + run_migration( + sqlite_path=sqlite_path, + dump_dt=dump_dt, + notify=notify_enabled, + dry_run=args.dry_run, + ) + + print("\n✓ Migration completed successfully!") + + except KeyboardInterrupt: + print("\n\nInterrupted by user") + sys.exit(130) + + except Exception as e: + print(f"\nFatal error: {e}") + if notify_enabled: + asyncio.run(notify_failure(str(e), "Migration")) + sys.exit(1) + + +if __name__ == "__main__": + main()