#!/usr/bin/env python3
"""
SQLite -> PostgreSQL migrator.
This script integrates the download/check/notification helpers from
`update_lrclib_db.py` and migrates a given SQLite dump into a new
**staging** PostgreSQL database named with the dump date. After a
successful import the staging DB is swapped to replace the production
DB (with a backup rename of the previous DB).
Usage examples:
- Default (auto-fetch + notify): ./migrate_sqlite_to_pg.py
- Migrate a local file: ./migrate_sqlite_to_pg.py --sqlite /path/to/db.sqlite3
- Disable notifications: ./migrate_sqlite_to_pg.py --no-notify
- Force re-import: ./migrate_sqlite_to_pg.py --force
"""
from __future__ import annotations
import argparse
import asyncio
import io
import os
import sqlite3
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
import psycopg2 # type: ignore[import]
from dotenv import load_dotenv
# Import helpers from update_lrclib_db.py (async functions)
from update_lrclib_db import (
fetch_latest_dump_info,
download_and_extract_dump,
parse_dump_date,
notify_start,
notify_new_dump_found,
notify_success,
notify_failure,
load_state,
save_state,
UpsertReport,
TableReport,
)
load_dotenv()
# ---------------------- Config ----------------------
PG_HOST = os.getenv("POSTGRES_HOST", "localhost")
PG_PORT = int(os.getenv("POSTGRES_PORT", "5432"))
PG_USER = os.getenv("POSTGRES_USER", "api")
PG_PASSWORD = os.getenv("POSTGRES_PASSWORD", "")
PG_DATABASE = os.getenv("POSTGRES_DB", "lrclib")
CHUNK_SIZE = int(os.getenv("MIGRATE_CHUNK_SIZE", "100000")) # 100k rows per batch
DEFAULT_DEST_DIR = Path(os.getenv("LRCLIB_DUMP_DIR", "/nvme/tmp"))
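# Example .env entries read above (values shown are the built-in defaults or
# illustrative placeholders, not requirements):
#   POSTGRES_HOST=localhost
#   POSTGRES_PORT=5432
#   POSTGRES_USER=api
#   POSTGRES_PASSWORD=change-me
#   POSTGRES_DB=lrclib
#   MIGRATE_CHUNK_SIZE=100000
#   LRCLIB_DUMP_DIR=/nvme/tmp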
# ---------------------- SQLite -> PostgreSQL helpers ----------------------
def sqlite_to_pg_type(sqlite_type: str) -> str:
"""Convert SQLite type to PostgreSQL type."""
if not sqlite_type:
return "TEXT"
sqlite_type = sqlite_type.upper()
    # Longer type names are listed before their substrings (e.g. DATETIME before
    # DATE) because the lookup below is a substring match on the declared type.
    mapping = {
        "INTEGER": "BIGINT",
        "TEXT": "TEXT",
        "REAL": "DOUBLE PRECISION",
        "FLOAT": "DOUBLE PRECISION",
        "BLOB": "BYTEA",
        "NUMERIC": "NUMERIC",
        "BOOLEAN": "BOOLEAN",
        "DATETIME": "TIMESTAMP",
        "TIMESTAMP": "TIMESTAMPTZ",
        "DATE": "DATE",
    }
for key, pg_type in mapping.items():
if key in sqlite_type:
return pg_type
return "TEXT"
def get_sqlite_tables(conn: sqlite3.Connection) -> list[str]:
"""Get list of user tables from SQLite database."""
cur = conn.cursor()
cur.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
)
return [row[0] for row in cur.fetchall()]
def get_table_schema(conn: sqlite3.Connection, table: str) -> list[tuple[str, str]]:
"""Get column names and types for a SQLite table."""
cur = conn.cursor()
cur.execute(f"PRAGMA table_info({table});")
return [(row[1], row[2]) for row in cur.fetchall()]
def clean_row(row: tuple, columns: list[tuple[str, str]]) -> tuple:
"""Clean row data for PostgreSQL insertion."""
cleaned: list[Any] = []
for i, value in enumerate(row):
col_name, col_type = columns[i]
pg_type = sqlite_to_pg_type(col_type)
if value is None:
cleaned.append(None)
elif pg_type == "BOOLEAN" and isinstance(value, int):
cleaned.append(bool(value))
elif isinstance(value, str):
# Remove NULL bytes which PostgreSQL text fields reject
cleaned.append(value.replace("\x00", ""))
else:
cleaned.append(value)
return tuple(cleaned)
def escape_copy_value(value, pg_type: str) -> str:
"""Escape a value for PostgreSQL COPY format (tab-separated).\n
This is much faster than INSERT for bulk loading.
"""
if value is None:
return "\\N" # NULL marker for COPY
if pg_type == "BOOLEAN" and isinstance(value, int):
return "t" if value else "f"
if isinstance(value, bool):
return "t" if value else "f"
if isinstance(value, (int, float)):
return str(value)
if isinstance(value, bytes):
# BYTEA: encode as hex
return "\\\\x" + value.hex()
# String: escape special characters for COPY
s = str(value)
s = s.replace("\x00", "") # Remove NULL bytes
s = s.replace("\\", "\\\\") # Escape backslashes
s = s.replace("\t", "\\t") # Escape tabs
s = s.replace("\n", "\\n") # Escape newlines
s = s.replace("\r", "\\r") # Escape carriage returns
return s
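# Illustrative escape_copy_value results: None -> \N (the COPY NULL marker), the
# integer 1 with pg_type "BOOLEAN" -> t, and an embedded tab becomes the two
# characters \t so COPY does not split the field.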
def create_table(
pg_conn, table: str, columns: list[tuple[str, str]], unlogged: bool = True
) -> None:
"""Create a table in PostgreSQL based on SQLite schema.
Uses UNLOGGED tables by default for faster bulk import (no WAL writes).
"""
cur = pg_conn.cursor()
# Check if table exists
cur.execute(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = %s);",
(table,),
)
if cur.fetchone()[0]:
print(f" Table {table} already exists, skipping create")
return
col_defs = [f'"{name}" {sqlite_to_pg_type(coltype)}' for name, coltype in columns]
unlogged_kw = "UNLOGGED" if unlogged else ""
sql = f'CREATE {unlogged_kw} TABLE "{table}" ({", ".join(col_defs)});'
print(f" Creating {'UNLOGGED ' if unlogged else ''}table: {table}")
cur.execute(sql)
pg_conn.commit()
def migrate_table(
sqlite_conn: sqlite3.Connection,
pg_conn,
table: str,
columns: list[tuple[str, str]],
) -> TableReport:
"""Migrate data using PostgreSQL COPY for maximum speed.
COPY is 5-10x faster than INSERT for bulk loading.
"""
import time
start = time.time()
report = TableReport(table_name=table)
sqlite_cur = sqlite_conn.cursor()
pg_cur = pg_conn.cursor()
col_names = [c[0] for c in columns]
col_names_quoted = ", ".join([f'"{c}"' for c in col_names])
pg_types = [sqlite_to_pg_type(c[1]) for c in columns]
# Get row count
sqlite_cur.execute(f'SELECT COUNT(*) FROM "{table}";')
report.sqlite_rows = sqlite_cur.fetchone()[0]
report.pg_rows_before = 0 # Fresh table
if report.sqlite_rows == 0:
print(f" Skipping {table}: empty table")
report.pg_rows_after = 0
report.status = "success"
report.duration_seconds = time.time() - start
return report
print(f" Importing {table}: {report.sqlite_rows:,} rows using COPY")
offset = 0
total_copied = 0
errors = 0
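    # NOTE: LIMIT/OFFSET pagination forces SQLite to re-walk the rows it skips, so
    # later chunks read progressively slower on very large tables. It is kept for
    # simplicity; keyset pagination on rowid would avoid the re-scan if needed.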
while offset < report.sqlite_rows:
try:
sqlite_cur.execute(
f'SELECT {col_names_quoted} FROM "{table}" LIMIT {CHUNK_SIZE} OFFSET {offset};'
)
rows = sqlite_cur.fetchall()
if not rows:
break
# Build COPY data in memory
buffer = io.StringIO()
for row in rows:
line_parts = [
escape_copy_value(val, pg_types[i]) for i, val in enumerate(row)
]
buffer.write("\t".join(line_parts) + "\n")
buffer.seek(0)
# Use COPY FROM for fast bulk insert
pg_cur.copy_expert(
f'COPY "{table}" ({col_names_quoted}) FROM STDIN WITH (FORMAT text)',
buffer,
)
pg_conn.commit()
total_copied += len(rows)
offset += CHUNK_SIZE
pct = min(100.0, (offset / report.sqlite_rows) * 100)
elapsed = time.time() - start
rate = total_copied / elapsed if elapsed > 0 else 0
print(
f" Progress: {min(offset, report.sqlite_rows):,}/{report.sqlite_rows:,} ({pct:.1f}%) - {rate:,.0f} rows/sec",
end="\r",
)
except Exception as e:
print(f"\n ⚠ Error at offset {offset} for table {table}: {e}")
errors += 1
report.errors += 1
offset += CHUNK_SIZE
pg_conn.rollback()
if errors > 10:
report.status = "failed"
report.error_message = f"Too many errors ({errors})"
break
print() # newline after progress
# Final count
pg_cur.execute(f'SELECT COUNT(*) FROM "{table}";')
report.pg_rows_after = pg_cur.fetchone()[0]
report.rows_affected = total_copied
report.duration_seconds = time.time() - start
if report.status != "failed":
report.status = "success"
rate = (
report.rows_affected / report.duration_seconds
if report.duration_seconds > 0
else 0
)
print(
f" ✓ Table {table}: {report.rows_affected:,} rows in {report.duration_seconds:.1f}s ({rate:,.0f} rows/sec)"
)
return report
# ---------------------- PostgreSQL DB management ----------------------
def pg_connect(dbname: Optional[str] = None):
"""Connect to PostgreSQL database."""
return psycopg2.connect(
host=PG_HOST,
port=PG_PORT,
database=dbname or "postgres",
user=PG_USER,
password=PG_PASSWORD,
)
def create_database(db_name: str) -> None:
"""Create a new PostgreSQL database."""
print(f"Creating database: {db_name}")
conn = pg_connect("postgres")
conn.autocommit = True
cur = conn.cursor()
cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (db_name,))
if cur.fetchone():
print(f" Database {db_name} already exists")
cur.close()
conn.close()
return
# Try creating database; handle collation mismatch gracefully
try:
cur.execute(f'CREATE DATABASE "{db_name}" OWNER {PG_USER};')
print(f" Created database {db_name}")
except Exception as e:
err = str(e)
if "collation version" in err or "template1" in err:
print(
" ⚠ Detected collation mismatch on template1; retrying with TEMPLATE template0"
)
cur.execute(
f'CREATE DATABASE "{db_name}" OWNER {PG_USER} TEMPLATE template0;'
)
print(f" Created database {db_name} using TEMPLATE template0")
else:
raise
cur.close()
conn.close()
def terminate_connections(db_name: str, max_wait: int = 10) -> bool:
"""Terminate all connections to a database.
Returns True if all connections were terminated, False if some remain.
Won't fail on permission errors (e.g., can't terminate superuser connections).
"""
import time
conn = pg_connect("postgres")
conn.autocommit = True
cur = conn.cursor()
for attempt in range(max_wait):
# Check how many connections exist
cur.execute(
"SELECT COUNT(*) FROM pg_stat_activity WHERE datname = %s AND pid <> pg_backend_pid();",
(db_name,),
)
row = cur.fetchone()
count = int(row[0]) if row else 0
if count == 0:
cur.close()
conn.close()
return True
print(f" Terminating {count} connection(s) to {db_name}...")
# Try to terminate - ignore errors for connections we can't kill
try:
cur.execute(
"""
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE datname = %s
AND pid <> pg_backend_pid()
AND usename = current_user; -- Only terminate our own connections
""",
(db_name,),
)
except Exception as e:
print(f" Warning: {e}")
# Brief wait for connections to close
time.sleep(1)
# Final check
cur.execute(
"SELECT COUNT(*) FROM pg_stat_activity WHERE datname = %s AND pid <> pg_backend_pid();",
(db_name,),
)
row = cur.fetchone()
remaining = int(row[0]) if row else 0
cur.close()
conn.close()
if remaining > 0:
print(f" Warning: {remaining} connection(s) still active (may be superuser sessions)")
return False
return True
def database_exists(db_name: str) -> bool:
"""Check if a database exists."""
conn = pg_connect("postgres")
conn.autocommit = True
cur = conn.cursor()
cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (db_name,))
exists = cur.fetchone() is not None
cur.close()
conn.close()
return exists
def rename_database(old_name: str, new_name: str) -> None:
"""Rename a PostgreSQL database."""
conn = pg_connect("postgres")
conn.autocommit = True
cur = conn.cursor()
print(f" Renaming {old_name} -> {new_name}")
cur.execute(f'ALTER DATABASE "{old_name}" RENAME TO "{new_name}";')
cur.close()
conn.close()
def drop_database(db_name: str) -> bool:
"""Drop a PostgreSQL database.
Returns True if dropped, False if failed (e.g., active connections).
"""
# First try to terminate connections
terminate_connections(db_name)
conn = pg_connect("postgres")
conn.autocommit = True
cur = conn.cursor()
print(f" Dropping database {db_name}")
try:
cur.execute(f'DROP DATABASE IF EXISTS "{db_name}";')
cur.close()
conn.close()
return True
except Exception as e:
print(f" Warning: Could not drop {db_name}: {e}")
cur.close()
conn.close()
return False
# ---------------------- Main migration flow ----------------------
def run_migration(
sqlite_path: str,
dump_dt: Optional[datetime],
notify: bool = True,
dry_run: bool = False,
) -> UpsertReport:
"""
Create a timestamped staging DB, import data, then swap to production.
Returns an UpsertReport with migration details.
"""
dump_dt = dump_dt or datetime.now(timezone.utc)
ts = dump_dt.strftime("%Y%m%dT%H%M%SZ")
staging_db = f"{PG_DATABASE}_{ts}"
backup_db = f"{PG_DATABASE}_backup_{ts}"
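    # With the default POSTGRES_DB this produces names like "lrclib_20240101T000000Z"
    # (staging) and "lrclib_backup_20240101T000000Z" (previous production copy);
    # the timestamps shown are illustrative.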
# Initialize report
report = UpsertReport(sqlite_source=sqlite_path, dump_date=ts)
report.start_time = datetime.now(timezone.utc)
print(f"\n{'=' * 60}")
print("SQLite -> PostgreSQL Migration")
print(f"{'=' * 60}")
print(f"Source: {sqlite_path}")
print(f"Staging DB: {staging_db}")
print(f"Production: {PG_DATABASE}")
print(f"Chunk size: {CHUNK_SIZE:,}")
print(f"{'=' * 60}\n")
if dry_run:
print("DRY RUN - no changes will be made")
return report
# Create staging database
create_database(staging_db)
sqlite_conn = None
pg_conn = None
try:
sqlite_conn = sqlite3.connect(sqlite_path)
pg_conn = pg_connect(staging_db)
tables = get_sqlite_tables(sqlite_conn)
print(f"Found {len(tables)} tables: {', '.join(tables)}\n")
for i, table in enumerate(tables, 1):
print(f"[{i}/{len(tables)}] Processing: {table}")
columns = get_table_schema(sqlite_conn, table)
create_table(pg_conn, table, columns)
table_report = migrate_table(sqlite_conn, pg_conn, table, columns)
report.tables.append(table_report)
print()
# Close connections before renaming
pg_conn.close()
pg_conn = None
sqlite_conn.close()
sqlite_conn = None
# Perform the database swap
print(f"\n{'=' * 60}")
print("Migration complete — performing database swap")
print(f"{'=' * 60}\n")
terminate_connections(PG_DATABASE)
terminate_connections(staging_db)
if database_exists(PG_DATABASE):
print(f" Backing up production DB to {backup_db}")
rename_database(PG_DATABASE, backup_db)
else:
print(f" Production DB {PG_DATABASE} does not exist; nothing to back up")
rename_database(staging_db, PG_DATABASE)
print("\n✓ Database swap complete!")
# Update state
state = load_state()
state["last_dump_date"] = dump_dt.strftime("%Y-%m-%d %H:%M:%S")
state["last_upsert_time"] = datetime.now(timezone.utc).isoformat()
save_state(state)
report.end_time = datetime.now(timezone.utc)
# Send success notification
if notify:
asyncio.run(notify_success(report))
return report
except Exception as e:
print(f"\n✗ Migration failed: {e}")
# Cleanup staging DB on failure
try:
if pg_conn:
pg_conn.close()
terminate_connections(staging_db)
drop_database(staging_db)
except Exception as cleanup_err:
print(f" Cleanup error: {cleanup_err}")
if notify:
asyncio.run(notify_failure(str(e), "Migration"))
raise
finally:
if sqlite_conn:
try:
sqlite_conn.close()
except Exception:
pass
if pg_conn:
try:
pg_conn.close()
except Exception:
pass
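# Example direct call (path and flags are illustrative):
#   run_migration("/nvme/tmp/lrclib-db-dump.sqlite3", None, notify=False, dry_run=True)
# When dump_dt is None the current UTC time is used for the staging-DB timestamp.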
# ---------------------- CLI ----------------------
def main() -> None:
parser = argparse.ArgumentParser(
description="Migrate SQLite dump to PostgreSQL with atomic swap",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
%(prog)s # Auto-fetch latest dump and migrate
%(prog)s --sqlite /path/to/db # Migrate local SQLite file
%(prog)s --force # Force migration even if dump is not newer
%(prog)s --no-notify # Disable Discord notifications
%(prog)s --dry-run # Simulate without making changes
""",
)
parser.add_argument(
"--sqlite",
type=str,
metavar="PATH",
help="Path to local SQLite file (skips auto-fetch if provided)",
)
    parser.add_argument(
        "--dest",
        type=str,
        metavar="DIR",
        default=str(DEFAULT_DEST_DIR),
        help="Destination directory for downloaded dumps (default: %(default)s)",
    )
parser.add_argument(
"--no-notify",
action="store_true",
help="Disable Discord notifications (enabled by default)",
)
parser.add_argument(
"--force",
action="store_true",
help="Force migration even if dump date is not newer",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Simulate without making changes",
)
args = parser.parse_args()
notify_enabled = not args.no_notify
sqlite_path: Optional[str] = None
dump_dt: Optional[datetime] = None
try:
if args.sqlite:
# Use provided SQLite file
sqlite_path = args.sqlite
assert sqlite_path is not None
if not Path(sqlite_path).exists():
print(f"Error: SQLite file not found: {sqlite_path}")
sys.exit(1)
# Try to parse date from filename
parsed = parse_dump_date(Path(sqlite_path).name)
dump_dt = parsed or datetime.now(timezone.utc)
else:
# Auto-fetch latest dump
if notify_enabled:
asyncio.run(notify_start())
print("Checking for latest dump on lrclib.net...")
latest = asyncio.run(fetch_latest_dump_info())
if not latest:
print("Error: Could not fetch latest dump info")
if notify_enabled:
asyncio.run(notify_failure("Could not fetch dump info", "Fetch"))
sys.exit(1)
dump_date = latest["date"]
dump_date_str = dump_date.strftime("%Y-%m-%d %H:%M:%S")
print(f"Latest dump: {latest['filename']} ({dump_date_str})")
# Check if we need to update
state = load_state()
last_str = state.get("last_dump_date")
if last_str and not args.force:
try:
last_dt = datetime.strptime(last_str, "%Y-%m-%d %H:%M:%S")
if dump_date <= last_dt:
print(f"No new dump available (last: {last_str})")
print("Use --force to migrate anyway")
sys.exit(0)
except ValueError:
pass # Invalid date format, proceed with migration
print(f"New dump available: {dump_date_str}")
if notify_enabled:
asyncio.run(
notify_new_dump_found(latest["filename"], dump_date_str)
)
# Download
print(f"\nDownloading {latest['filename']}...")
db_path, err = asyncio.run(
download_and_extract_dump(latest["url"], args.dest)
)
if not db_path:
print(f"Error: Download failed: {err}")
if notify_enabled:
asyncio.run(notify_failure(str(err), "Download"))
sys.exit(1)
sqlite_path = db_path
dump_dt = dump_date
# Run migration
assert sqlite_path is not None
run_migration(
sqlite_path=sqlite_path,
dump_dt=dump_dt,
notify=notify_enabled,
dry_run=args.dry_run,
)
print("\n✓ Migration completed successfully!")
except KeyboardInterrupt:
print("\n\nInterrupted by user")
sys.exit(130)
except Exception as e:
print(f"\nFatal error: {e}")
if notify_enabled:
asyncio.run(notify_failure(str(e), "Migration"))
sys.exit(1)
if __name__ == "__main__":
main()