#!/usr/bin/env python3
"""
Upsert SQLite dump data into PostgreSQL (Async Version).
This script handles incremental updates from SQLite dumps to PostgreSQL:
- New rows are inserted
- Existing rows are updated if changed
- Uses primary keys and unique constraints to identify duplicates
- Can automatically fetch new dumps from lrclib.net
- Generates detailed change reports
- Async with asyncpg for high performance with large datasets (17GB+)
"""
import asyncio
import aiosqlite
import asyncpg
import os
import gzip
import re
import json
import aiohttp
import aiofiles
import subprocess
from typing import List, Dict, Any, Optional
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass, field, asdict
from datetime import timezone
from dotenv import load_dotenv
try:
    from playwright.async_api import async_playwright
    PLAYWRIGHT_AVAILABLE = True
except ImportError:
    PLAYWRIGHT_AVAILABLE = False
load_dotenv()
# Discord webhook for notifications
DISCORD_WEBHOOK_URL = "https://discord.com/api/webhooks/1332864106750939168/6y6MgFhHLX0-BnL2M3hXnt2vJsue7Q2Duf_HjZenHNlNj7sxQr4lqxrPVJnJWf7KVAm2"
# Configuration - using environment variables
PG_CONFIG = {
"host": os.getenv("POSTGRES_HOST", "localhost"),
"port": int(os.getenv("POSTGRES_PORT", "5432")),
"database": os.getenv("POSTGRES_DB", "lrclib"),
"user": os.getenv("POSTGRES_USER", "api"),
"password": os.getenv("POSTGRES_PASSWORD", ""),
}
CHUNK_SIZE = 10000 # Process rows in chunks (larger for async)
BATCH_SIZE = 1000 # Rows per INSERT statement
# LRCLib dump URL
LRCLIB_DUMPS_URL = "https://lrclib.net/db-dumps"
STATE_FILE = Path(__file__).parent / ".lrclib_upsert_state.json"
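# Example .env consumed by load_dotenv() above. Only the variable names are taken from the
# os.getenv() calls in PG_CONFIG; the values below are placeholders (the defaults, plus a
# made-up password):
#
#   POSTGRES_HOST=localhost
#   POSTGRES_PORT=5432
#   POSTGRES_DB=lrclib
#   POSTGRES_USER=api
#   POSTGRES_PASSWORD=change-me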
@dataclass
class TableReport:
"""Report for a single table's upsert operation."""
table_name: str
sqlite_rows: int = 0
pg_rows_before: int = 0
pg_rows_after: int = 0
rows_affected: int = 0
errors: int = 0
duration_seconds: float = 0.0
status: str = "pending"
error_message: str = ""
@property
def rows_inserted(self) -> int:
return max(0, self.pg_rows_after - self.pg_rows_before)
@property
def rows_updated(self) -> int:
return max(0, self.rows_affected - self.rows_inserted)
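# Note on the derived counts above: rows_inserted is inferred from the table count delta and
# rows_updated from rows_affected minus that delta, so both are estimates. For example (made-up
# numbers): pg_rows_before=100, pg_rows_after=120, rows_affected=150 gives rows_inserted=20
# and rows_updated=130.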
@dataclass
class UpsertReport:
"""Complete report for an upsert operation."""
sqlite_source: str
dump_date: Optional[str] = None
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
tables: List[TableReport] = field(default_factory=list)
@property
def total_sqlite_rows(self) -> int:
return sum(t.sqlite_rows for t in self.tables)
@property
def total_rows_affected(self) -> int:
return sum(t.rows_affected for t in self.tables)
@property
def total_rows_inserted(self) -> int:
return sum(t.rows_inserted for t in self.tables)
@property
def total_rows_updated(self) -> int:
return sum(t.rows_updated for t in self.tables)
@property
def total_errors(self) -> int:
return sum(t.errors for t in self.tables)
@property
def duration_seconds(self) -> float:
if self.start_time and self.end_time:
return (self.end_time - self.start_time).total_seconds()
return 0.0
@property
def successful_tables(self) -> List[str]:
return [t.table_name for t in self.tables if t.status == "success"]
@property
def failed_tables(self) -> List[str]:
return [t.table_name for t in self.tables if t.status == "failed"]
def generate_summary(self) -> str:
"""Generate a human-readable summary report."""
lines = [
"=" * 60,
"LRCLIB UPSERT REPORT",
"=" * 60,
f"Source: {self.sqlite_source}",
]
if self.dump_date:
lines.append(f"Dump Date: {self.dump_date}")
if self.start_time:
lines.append(f"Started: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
if self.end_time:
lines.append(f"Finished: {self.end_time.strftime('%Y-%m-%d %H:%M:%S')}")
lines.append(f"Duration: {self.duration_seconds:.1f} seconds")
lines.append("")
lines.append("-" * 60)
lines.append("TABLE DETAILS")
lines.append("-" * 60)
for table in self.tables:
            status_icon = (
                "✓"
                if table.status == "success"
                else "✗" if table.status == "failed" else "○"
            )
lines.append(f"\n{status_icon} {table.table_name}")
lines.append(f" SQLite rows: {table.sqlite_rows:>12,}")
lines.append(f" PG before: {table.pg_rows_before:>12,}")
lines.append(f" PG after: {table.pg_rows_after:>12,}")
lines.append(f" Rows inserted: {table.rows_inserted:>12,}")
lines.append(f" Rows updated: {table.rows_updated:>12,}")
lines.append(f" Duration: {table.duration_seconds:>12.1f}s")
if table.errors > 0:
lines.append(f" Errors: {table.errors:>12}")
if table.error_message:
lines.append(f" Error: {table.error_message}")
lines.append("")
lines.append("-" * 60)
lines.append("SUMMARY")
lines.append("-" * 60)
lines.append(f"Tables processed: {len(self.tables)}")
lines.append(f" Successful: {len(self.successful_tables)}")
lines.append(f" Failed: {len(self.failed_tables)}")
lines.append(f"Total SQLite rows: {self.total_sqlite_rows:,}")
lines.append(f"Total inserted: {self.total_rows_inserted:,}")
lines.append(f"Total updated: {self.total_rows_updated:,}")
lines.append(f"Total errors: {self.total_errors}")
lines.append("=" * 60)
return "\n".join(lines)
def to_json(self) -> str:
"""Export report as JSON."""
data = {
"sqlite_source": self.sqlite_source,
"dump_date": self.dump_date,
"start_time": self.start_time.isoformat() if self.start_time else None,
"end_time": self.end_time.isoformat() if self.end_time else None,
"duration_seconds": self.duration_seconds,
"summary": {
"total_sqlite_rows": self.total_sqlite_rows,
"total_rows_inserted": self.total_rows_inserted,
"total_rows_updated": self.total_rows_updated,
"total_errors": self.total_errors,
"successful_tables": self.successful_tables,
"failed_tables": self.failed_tables,
},
"tables": [asdict(t) for t in self.tables],
}
return json.dumps(data, indent=2)
def load_state() -> Dict[str, Any]:
"""Load the last upsert state from file."""
if STATE_FILE.exists():
try:
return json.loads(STATE_FILE.read_text())
except (json.JSONDecodeError, IOError):
pass
return {}
def save_state(state: Dict[str, Any]) -> None:
"""Save the upsert state to file."""
STATE_FILE.write_text(json.dumps(state, indent=2))
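# The state file is plain JSON. A representative (hypothetical) .lrclib_upsert_state.json,
# using the keys written in check_and_fetch_latest():
#
#   {
#     "last_dump_date": "2025-01-20 09:12:51",
#     "last_upsert_time": "2025-01-21T03:00:00.000000",
#     "last_dump_url": "https://lrclib.net/db-dumps/lrclib-db-dump-20250120T091251Z.sqlite3.gz",
#     "last_dump_filename": "lrclib-db-dump-20250120T091251Z.sqlite3.gz"
#   }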
# Discord notification colors
class DiscordColor:
INFO = 0x3498DB # Blue
SUCCESS = 0x2ECC71 # Green
WARNING = 0xF39C12 # Orange
ERROR = 0xE74C3C # Red
NEUTRAL = 0x95A5A6 # Gray
async def discord_notify(
title: str,
description: str,
color: int = DiscordColor.INFO,
fields: Optional[List[Dict[str, Any]]] = None,
footer: Optional[str] = None,
) -> bool:
"""
Send a Discord webhook notification.
Returns True on success, False on failure.
"""
embed = {
"title": title,
"description": description[:2000] if description else "",
"color": color,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
if fields:
embed["fields"] = fields[:25] # Discord limit
if footer:
embed["footer"] = {"text": footer[:2048]}
payload = {"embeds": [embed]}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
DISCORD_WEBHOOK_URL,
json=payload,
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
if resp.status >= 400:
text = await resp.text()
print(f"Discord webhook failed ({resp.status}): {text}")
return False
return True
except Exception as e:
print(f"Discord notification error: {e}")
return False
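# Minimal usage sketch for discord_notify(); the fields follow the embed-field shape
# ({"name", "value", "inline"}) already used by the notify_* helpers below:
#
#   await discord_notify(
#       title="Example",
#       description="Something happened",
#       color=DiscordColor.WARNING,
#       fields=[{"name": "Detail", "value": "42", "inline": True}],
#   )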
async def notify_start() -> None:
"""Send notification that the weekly job is starting."""
await discord_notify(
title="🔄 LRCLib DB Update - Starting",
description="Weekly database sync job has started.",
color=DiscordColor.INFO,
fields=[
{
"name": "Time",
"value": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"inline": True,
},
],
)
async def notify_no_update(latest_dump: str, last_upsert: str) -> None:
"""Send notification that no new dump was found."""
await discord_notify(
title=" LRCLib DB Update - No Update Needed",
description="No newer database dump available.",
color=DiscordColor.NEUTRAL,
fields=[
{"name": "Latest Available", "value": latest_dump, "inline": True},
{"name": "Last Synced", "value": last_upsert, "inline": True},
],
)
async def notify_new_dump_found(dump_filename: str, dump_date: str) -> None:
"""Send notification that a new dump was found and download is starting."""
await discord_notify(
title="📥 LRCLib DB Update - New Dump Found",
description="Downloading and processing new database dump.",
color=DiscordColor.INFO,
fields=[
{"name": "Filename", "value": dump_filename, "inline": False},
{"name": "Dump Date", "value": dump_date, "inline": True},
],
)
async def notify_success(report: "UpsertReport") -> None:
"""Send notification for successful upsert."""
duration_min = report.duration_seconds / 60
# Build table summary
table_summary = []
for t in report.tables[:10]: # Limit to first 10 tables
status = "" if t.status == "success" else "" if t.status == "failed" else ""
table_summary.append(f"{status} {t.table_name}: +{t.rows_inserted:,}")
if len(report.tables) > 10:
table_summary.append(f"... and {len(report.tables) - 10} more tables")
await discord_notify(
title="✅ LRCLib DB Update - Success",
description="Database sync completed successfully.",
color=DiscordColor.SUCCESS,
fields=[
{"name": "Dump Date", "value": report.dump_date or "N/A", "inline": True},
{"name": "Duration", "value": f"{duration_min:.1f} min", "inline": True},
{"name": "Tables", "value": str(len(report.tables)), "inline": True},
{
"name": "Rows Inserted",
"value": f"{report.total_rows_inserted:,}",
"inline": True,
},
{
"name": "Rows Updated",
"value": f"{report.total_rows_updated:,}",
"inline": True,
},
{"name": "Errors", "value": str(report.total_errors), "inline": True},
{
"name": "Table Summary",
"value": "\n".join(table_summary) or "No tables",
"inline": False,
},
],
footer=f"Source: {report.sqlite_source}",
)
async def notify_failure(
error_message: str, stage: str, details: Optional[str] = None
) -> None:
"""Send notification for failed upsert."""
description = (
f"Database sync failed during: **{stage}**\n\n```{error_message[:1500]}```"
)
fields = [
{"name": "Stage", "value": stage, "inline": True},
{
"name": "Time",
"value": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"inline": True,
},
]
if details:
fields.append({"name": "Details", "value": details[:1000], "inline": False})
await discord_notify(
title="❌ LRCLib DB Update - Failed",
description=description,
color=DiscordColor.ERROR,
fields=fields,
)
def parse_dump_date(filename: str) -> Optional[datetime]:
"""
Extract date from dump filename.
Supports formats:
- lrclib-db-dump-20260122T091251Z.sqlite3.gz (ISO format with time)
- lrclib-db-dump-2025-01-20.sqlite3.gz (date only)
"""
# Try ISO format first: 20260122T091251Z
match = re.search(r"(\d{4})(\d{2})(\d{2})T(\d{2})(\d{2})(\d{2})Z", filename)
if match:
return datetime(
int(match.group(1)), # year
int(match.group(2)), # month
int(match.group(3)), # day
int(match.group(4)), # hour
int(match.group(5)), # minute
int(match.group(6)), # second
)
# Try simple date format: 2025-01-20
match = re.search(r"(\d{4})-(\d{2})-(\d{2})", filename)
if match:
return datetime.strptime(match.group(0), "%Y-%m-%d")
return None
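# Worked examples for parse_dump_date(), matching the two filename formats in the docstring:
#
#   parse_dump_date("lrclib-db-dump-20260122T091251Z.sqlite3.gz")
#       -> datetime(2026, 1, 22, 9, 12, 51)
#   parse_dump_date("lrclib-db-dump-2025-01-20.sqlite3.gz")
#       -> datetime(2025, 1, 20, 0, 0)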
async def fetch_latest_dump_info() -> Optional[Dict[str, Any]]:
"""
Fetch the latest dump info from lrclib.net/db-dumps.
Uses Playwright to render the JS-generated page content.
Returns dict with 'url', 'filename', 'date' or None if not found.
"""
print(f"Fetching dump list from {LRCLIB_DUMPS_URL}...")
if not PLAYWRIGHT_AVAILABLE:
print("Error: playwright is required for fetching dump info.")
print("Install with: pip install playwright && playwright install chromium")
return None
html = None
try:
async with async_playwright() as p:
# Launch headless browser
browser = await p.chromium.launch(headless=True)
page = await browser.new_page()
# Navigate and wait for content to load
await page.goto(LRCLIB_DUMPS_URL, wait_until="networkidle")
# Wait for anchor tags with .sqlite3.gz links to appear
try:
await page.wait_for_selector('a[href*=".sqlite3.gz"]', timeout=15000)
except Exception:
print("Warning: Timed out waiting for download links, trying anyway...")
# Get rendered HTML
html = await page.content()
await browser.close()
except Exception as e:
print(f"Error rendering page with Playwright: {e}")
return None
if not html:
print("Failed to get page content")
return None
# Find all .sqlite3.gz links
pattern = r'href=["\']([^"\']*\.sqlite3\.gz)["\']'
matches = re.findall(pattern, html, re.IGNORECASE)
# Deduplicate matches (same link may appear multiple times in HTML)
matches = list(dict.fromkeys(matches))
if not matches:
print("No .sqlite3.gz files found on the page")
# Debug: show what we got
print(f"Page content length: {len(html)} chars")
return None
print(f"Found {len(matches)} dump file(s)")
# Parse dates and find the newest
dumps = []
for match in matches:
# Handle relative URLs
if match.startswith("/"):
url = f"https://lrclib.net{match}"
elif not match.startswith("http"):
url = f"https://lrclib.net/db-dumps/{match}"
else:
url = match
filename = url.split("/")[-1]
date = parse_dump_date(filename)
if date:
dumps.append({"url": url, "filename": filename, "date": date})
print(f" Found: {filename} ({date.strftime('%Y-%m-%d %H:%M:%S')})")
if not dumps:
print("Could not parse dates from dump filenames")
return None
    # Sort by date descending and return the newest; every entry has a datetime "date"
    # because filenames without a parseable date were skipped above.
    dumps.sort(key=lambda x: x["date"], reverse=True)
    latest = dumps[0]
    print(
        f"Latest dump: {latest['filename']} ({latest['date'].strftime('%Y-%m-%d %H:%M:%S')})"
    )
return latest
async def download_and_extract_dump(
url: str, dest_dir: Optional[str] = None, max_retries: int = 5
) -> tuple[Optional[str], Optional[str]]:
"""
Download a .sqlite3.gz file and extract it with streaming.
Supports resume on connection failures.
Returns tuple of (path to extracted .sqlite3 file, error message).
On success: (path, None). On failure: (None, error_message).
"""
filename = url.split("/")[-1]
if dest_dir:
dest_path = Path(dest_dir)
dest_path.mkdir(parents=True, exist_ok=True)
else:
dest_path = Path("/nvme/tmp")
gz_path = dest_path / filename
sqlite_path = dest_path / filename.replace(".gz", "")
# If an extracted sqlite file already exists, skip download and extraction
if sqlite_path.exists() and sqlite_path.stat().st_size > 0:
print(f"Found existing extracted SQLite file {sqlite_path}; skipping download/extract")
return str(sqlite_path), None
# Streaming download with retry and resume support
print(f"Downloading {url}...")
for attempt in range(1, max_retries + 1):
try:
# Check if partial download exists for resume
downloaded = 0
headers = {}
if gz_path.exists():
downloaded = gz_path.stat().st_size
headers["Range"] = f"bytes={downloaded}-"
print(f" Resuming from {downloaded / (1024 * 1024):.1f} MB...")
timeout = aiohttp.ClientTimeout(
total=None, # No total timeout for large files
connect=60,
sock_read=300, # 5 min read timeout per chunk
)
async with aiohttp.ClientSession() as session:
async with session.get(
url, timeout=timeout, headers=headers
) as response:
# Handle resume response
if response.status == 416: # Range not satisfiable - file complete
print(" Download already complete.")
break
elif response.status == 206: # Partial content - resuming
total = downloaded + int(
response.headers.get("content-length", 0)
)
elif response.status == 200:
# Server doesn't support resume or fresh download
if downloaded > 0:
print(" Server doesn't support resume, restarting...")
downloaded = 0
total = int(response.headers.get("content-length", 0))
else:
response.raise_for_status()
total = int(response.headers.get("content-length", 0))
# Open in append mode if resuming, write mode if fresh
mode = "ab" if downloaded > 0 and response.status == 206 else "wb"
if mode == "wb":
downloaded = 0
async with aiofiles.open(gz_path, mode) as f:
async for chunk in response.content.iter_chunked(
1024 * 1024
): # 1MB chunks
await f.write(chunk)
downloaded += len(chunk)
if total:
pct = (downloaded / total) * 100
mb_down = downloaded / (1024 * 1024)
mb_total = total / (1024 * 1024)
print(
f" Downloaded: {mb_down:.1f} / {mb_total:.1f} MB ({pct:.1f}%)",
end="\r",
)
print() # Newline after progress
break # Success, exit retry loop
except (aiohttp.ClientError, ConnectionResetError, asyncio.TimeoutError) as e:
print(f"\n Download error (attempt {attempt}/{max_retries}): {e}")
if attempt < max_retries:
wait_time = min(30 * attempt, 120) # Exponential backoff, max 2 min
print(f" Retrying in {wait_time} seconds...")
await asyncio.sleep(wait_time)
else:
error_msg = f"Download failed after {max_retries} attempts: {e}"
print(f"Error: {error_msg}")
# Clean up partial file on final failure
if gz_path.exists():
gz_path.unlink()
return None, error_msg
# Verify download completed
if not gz_path.exists() or gz_path.stat().st_size == 0:
error_msg = "Download failed - file is empty or missing"
print(f"Error: {error_msg}")
return None, error_msg
# Extract (sync, but fast with streaming)
print(
f"Extracting {gz_path} ({gz_path.stat().st_size / (1024**3):.2f} GB compressed)..."
)
try:
# Use streaming decompression for large files
async with aiofiles.open(sqlite_path, "wb") as f_out:
with gzip.open(gz_path, "rb") as f_in:
extracted = 0
while True:
chunk = f_in.read(1024 * 1024 * 10) # 10MB chunks
if not chunk:
break
await f_out.write(chunk)
extracted += len(chunk)
print(f" Extracted: {extracted / (1024**3):.2f} GB", end="\r")
print()
except Exception as e:
error_msg = f"Extraction failed: {e}"
print(f"Error: {error_msg}")
return None, error_msg
print(
f"Extracted to: {sqlite_path} ({sqlite_path.stat().st_size / (1024**3):.2f} GB)"
)
# Note: do NOT remove the .gz file here. Keep source files until import completes successfully.
return str(sqlite_path), None
async def extract_gz_file(gz_path: str, output_path: str) -> None:
"""Extract a .gz file using pigz for multi-threaded decompression."""
try:
# Use pigz for faster decompression
print(f"Extracting {gz_path} to {output_path} using pigz...")
        with open(output_path, "wb") as out_f:
            subprocess.run(
                ["pigz", "-d", "-p", str(os.cpu_count()), "-c", gz_path],
                stdout=out_f,
                check=True,
            )
print(f"Extraction completed: {output_path}")
except FileNotFoundError:
raise RuntimeError(
"pigz is not installed. Please install it for faster decompression."
)
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Extraction failed: {e}")
async def get_sqlite_tables(db: aiosqlite.Connection) -> List[str]:
"""Get list of all tables in SQLite database."""
async with db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
) as cursor:
rows = await cursor.fetchall()
return [row[0] for row in rows]
async def get_sqlite_schema(
db: aiosqlite.Connection, table_name: str
) -> List[Dict[str, Any]]:
"""Get column info from SQLite table."""
async with db.execute(f"PRAGMA table_info({table_name});") as cursor:
rows = await cursor.fetchall()
return [
{
"cid": row[0],
"name": row[1],
"type": row[2],
"notnull": row[3],
"default": row[4],
"pk": row[5],
}
for row in rows
]
async def get_pg_schema(pool: asyncpg.Pool, table_name: str) -> Dict[str, Any]:
"""Get PostgreSQL table schema including primary key and unique constraints."""
async with pool.acquire() as conn:
# Get columns
columns = await conn.fetch(
"""
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_name = $1
ORDER BY ordinal_position;
""",
table_name,
)
columns = [
{"name": row["column_name"], "type": row["data_type"]} for row in columns
]
# Get primary key columns
pk_columns = await conn.fetch(
"""
SELECT a.attname
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = $1::regclass
AND i.indisprimary;
""",
table_name,
)
pk_columns = [row["attname"] for row in pk_columns]
# Get unique constraint columns
unique_rows = await conn.fetch(
"""
SELECT a.attname, c.conname
FROM pg_constraint c
JOIN pg_attribute a ON a.attrelid = c.conrelid AND a.attnum = ANY(c.conkey)
WHERE c.conrelid = $1::regclass
AND c.contype = 'u'
ORDER BY c.conname, a.attnum;
""",
table_name,
)
unique_constraints: Dict[str, List[str]] = {}
for row in unique_rows:
constraint_name = row["conname"]
if constraint_name not in unique_constraints:
unique_constraints[constraint_name] = []
unique_constraints[constraint_name].append(row["attname"])
return {
"columns": columns,
"pk_columns": pk_columns,
"unique_constraints": unique_constraints,
}
def clean_value(value: Any) -> Any:
"""Basic cleaning for values (strip NULs). Use more specific coercion in per-column handling."""
if value is None:
return None
if isinstance(value, str):
return value.replace("\x00", "")
return value
def parse_datetime_string(s: str) -> Optional[datetime]:
"""Parse a datetime string into a timezone-aware datetime if possible."""
if s is None:
return None
if isinstance(s, datetime):
return s
if not isinstance(s, str):
return None
try:
# Normalize excessive fractional seconds (more than 6 digits) by truncating to 6
s_norm = re.sub(r"\.(\d{6})\d+", r".\1", s)
# Try fromisoformat (handles 'YYYY-MM-DD HH:MM:SS.mmm+00:00' and ISO variants)
dt = datetime.fromisoformat(s_norm)
        # If naive, assume UTC (timezone is imported at module level)
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt
except Exception:
# Try common SQLite format
for fmt in (
"%Y-%m-%d %H:%M:%S.%f%z",
"%Y-%m-%d %H:%M:%S.%f",
"%Y-%m-%d %H:%M:%S%z",
"%Y-%m-%d %H:%M:%S",
):
try:
s_norm = re.sub(r"\.(\d{6})\d+", r".\1", s)
dt = datetime.strptime(s_norm, fmt)
                if dt.tzinfo is None:
                    dt = dt.replace(tzinfo=timezone.utc)
                return dt
except Exception:
continue
return None
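# Examples of what parse_datetime_string() accepts; naive inputs come back tagged as UTC and
# fractional seconds beyond microsecond precision are truncated:
#
#   parse_datetime_string("2025-01-20 09:12:51")
#       -> datetime(2025, 1, 20, 9, 12, 51, tzinfo=timezone.utc)
#   parse_datetime_string("2025-01-20T09:12:51.1234567+00:00")
#       -> datetime(2025, 1, 20, 9, 12, 51, 123456, tzinfo=timezone.utc)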
def coerce_value_for_pg(value: Any, pg_type: str) -> Any:
"""Coerce a SQLite value into an appropriate Python type for asyncpg based on PostgreSQL data type."""
if value is None:
return None
pg_type = (pg_type or "").lower()
# Boolean
if "boolean" in pg_type:
# SQLite often stores booleans as 0/1
if isinstance(value, (int, float)):
return bool(value)
if isinstance(value, str):
v = value.strip().lower()
if v in ("1", "t", "true", "yes"):
return True
if v in ("0", "f", "false", "no"):
return False
if isinstance(value, bool):
return value
# Fallback
return bool(value)
# Timestamps
if "timestamp" in pg_type or "date" in pg_type or "time" in pg_type:
if isinstance(value, datetime):
return value
if isinstance(value, str):
dt = parse_datetime_string(value)
if dt:
return dt
# If SQLite stores as numeric (unix epoch)
if isinstance(value, (int, float)):
try:
                return datetime.fromtimestamp(value, tz=timezone.utc)
except Exception:
pass
return value
# JSON
if "json" in pg_type:
if isinstance(value, str):
try:
return json.loads(value)
except Exception:
return value
return value
# Numeric integer types
if pg_type in ("integer", "bigint", "smallint") or "int" in pg_type:
if isinstance(value, int):
return value
if isinstance(value, float):
return int(value)
if isinstance(value, str) and value.isdigit():
return int(value)
try:
return int(value)
except Exception:
return value
# Floating point
if (
pg_type in ("real", "double precision", "numeric", "decimal")
or "numeric" in pg_type
):
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, str):
try:
return float(value)
except Exception:
return value
# Text-like types: coerce to string if not None
if any(k in pg_type for k in ("char", "text", "uuid")):
try:
return str(value)
except Exception:
return clean_value(value)
# Default: basic cleaning
return clean_value(value)
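# A few illustrative coercions (pg_type strings as reported by information_schema.columns):
#
#   coerce_value_for_pg(1, "boolean")                  -> True
#   coerce_value_for_pg("42", "bigint")                -> 42
#   coerce_value_for_pg(123, "text")                   -> "123"
#   coerce_value_for_pg("2025-01-20 09:12:51", "timestamp with time zone")
#       -> datetime(2025, 1, 20, 9, 12, 51, tzinfo=timezone.utc)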
# Diagnostics directory for failing chunks
DIAGNOSTICS_DIR = Path(__file__).parent / "diagnostics"
DIAGNOSTICS_DIR.mkdir(exist_ok=True)
async def save_failed_rows(
dump_basename: str, table_name: str, offset: int, rows: List[tuple], error: str
) -> str:
"""Save failing rows as a gzipped JSONL file for inspection.
Returns path to diagnostics file.
"""
fname = f"{dump_basename}.{table_name}.offset{offset}.failed.jsonl.gz"
path = DIAGNOSTICS_DIR / fname
    try:
        # gzip is already imported at module level
        with gzip.open(path, "wt", encoding="utf-8") as fh:
for r in rows:
record = {"row": r, "error": error}
fh.write(json.dumps(record, default=str) + "\n")
return str(path)
except Exception as e:
print(f"Failed to write diagnostics file: {e}")
return ""
async def attempt_insert_chunk(
conn,
upsert_sql: str,
cleaned_rows: List[tuple],
table_name: str,
offset: int,
dump_basename: str,
batch_size: int = BATCH_SIZE,
per_row_limit: int = 100,
) -> tuple[int, int, Optional[str]]:
"""Try inserting a chunk; on failure, split and try smaller batches then per-row.
Returns (successful_rows, errors, diagnostics_path_or_None).
"""
success = 0
errors = 0
diagnostics_path = None
# Fast path: whole chunk
try:
await conn.executemany(upsert_sql, cleaned_rows)
return len(cleaned_rows), 0, None
except Exception as e:
# Fall through to batch splitting
first_exc = str(e)
# Try batch splitting
for i in range(0, len(cleaned_rows), batch_size):
batch = cleaned_rows[i : i + batch_size]
try:
await conn.executemany(upsert_sql, batch)
success += len(batch)
continue
except Exception:
# Try per-row for this batch
failing = []
for j, row in enumerate(batch):
try:
await conn.execute(upsert_sql, *row)
success += 1
except Exception as re_exc:
errors += 1
failing.append((j + i, row, str(re_exc)))
# If too many per-row failures, stop collecting
if errors >= per_row_limit:
break
# If failures were recorded, save diagnostics
if failing:
# Prepare rows for saving (row index and values)
rows_to_save = [r[1] for r in failing]
diagnostics_path = await save_failed_rows(
dump_basename, table_name, offset + i, rows_to_save, first_exc
)
# Continue to next batch
return success, errors, diagnostics_path
async def _quote_ident(name: str) -> str:
"""Return a safely quoted SQL identifier."""
return '"' + name.replace('"', '""') + '"'
async def fallback_upsert_using_temp(
conn,
table_name: str,
common_columns: List[str],
pk_columns: List[str],
rows_iterable: List[tuple],
dump_basename: str,
) -> tuple[int, int, Optional[str]]:
"""Fallback upsert that uses a temporary table and performs UPDATE then INSERT.
Returns (rows_upserted_count, errors_count, diagnostics_path_or_None).
"""
errors = 0
diagnostics = None
inserted = 0
updated = 0
# Build identifiers
tmp_name = f"tmp_upsert_{table_name}_{os.getpid()}" # per-process unique-ish
q_table = await _quote_ident(table_name)
q_tmp = await _quote_ident(tmp_name)
col_names_str = ", ".join([f'"{c}"' for c in common_columns])
# Create temp table like target
try:
await conn.execute(f"CREATE TEMP TABLE {q_tmp} (LIKE {q_table} INCLUDING ALL)")
except Exception as e:
return (
0,
1,
await save_failed_rows(
dump_basename, table_name, 0, rows_iterable[:100], str(e)
),
)
# Insert into temp (in batches)
try:
# Insert all rows into temp
insert_sql = f"INSERT INTO {q_tmp} ({col_names_str}) VALUES ({', '.join([f'${i + 1}' for i in range(len(common_columns))])})"
for i in range(0, len(rows_iterable), BATCH_SIZE):
batch = rows_iterable[i : i + BATCH_SIZE]
try:
await conn.executemany(insert_sql, batch)
except Exception as e:
# Save diagnostics and abort
diagnostics = await save_failed_rows(
dump_basename,
table_name,
i,
[tuple(r) for r in batch[:100]],
str(e),
)
errors += len(batch)
return 0, errors, diagnostics
# Count tmp rows
tmp_count = await conn.fetchval(f"SELECT COUNT(*) FROM {q_tmp}")
# Count matches in target (existing rows)
pk_cond = " AND ".join(
["t." + f'"{k}"' + " = tmp." + f'"{k}"' for k in pk_columns]
)
match_count = await conn.fetchval(
f"SELECT COUNT(*) FROM {q_tmp} tmp JOIN {q_table} t ON {pk_cond}"
)
# Perform update of matching rows (only non-PK columns)
update_cols = [c for c in common_columns if c not in pk_columns]
if update_cols:
set_clauses = ", ".join(
["{q_table}." + f'"{c}"' + " = tmp." + f'"{c}"' for c in update_cols]
)
await conn.execute(
f"UPDATE {q_table} SET {set_clauses} FROM {q_tmp} tmp WHERE {pk_cond}"
)
updated = match_count
# Insert non-existing rows
insert_sql2 = f"INSERT INTO {q_table} ({col_names_str}) SELECT {col_names_str} FROM {q_tmp} tmp WHERE NOT EXISTS (SELECT 1 FROM {q_table} t WHERE {pk_cond})"
await conn.execute(insert_sql2)
# Approximate inserted count
inserted = tmp_count - match_count
except Exception as e:
diagnostics = await save_failed_rows(
dump_basename, table_name, 0, rows_iterable[:100], str(e)
)
errors += 1
finally:
# Drop temp table
try:
await conn.execute(f"DROP TABLE IF EXISTS {q_tmp}")
except Exception:
pass
return inserted + updated, errors, diagnostics
async def validate_sample_rows(
sqlite_db: aiosqlite.Connection,
pg_col_types: List[str],
col_names_str: str,
table_name: str,
dump_basename: str,
sample_n: int = 1000,
) -> tuple[int, int, Optional[str]]:
"""Validate a sample of rows by attempting to coerce values and detecting likely mismatches.
Returns (fail_count, total_checked, diagnostics_path_or_None).
"""
fail_count = 0
total = 0
failing_rows = []
async with sqlite_db.execute(
f"SELECT {col_names_str} FROM {table_name} LIMIT {sample_n};"
) as cursor:
rows = await cursor.fetchall()
for i, row in enumerate(rows):
total += 1
# Coerce and check types
for val, pg_type in zip(row, pg_col_types):
coerced = coerce_value_for_pg(val, pg_type)
# Heuristics for likely failure
pg_type_l = (pg_type or "").lower()
if "timestamp" in pg_type_l or "date" in pg_type_l or "time" in pg_type_l:
if coerced is None or not isinstance(coerced, datetime):
fail_count += 1
failing_rows.append(row)
break
elif "boolean" in pg_type_l:
if not isinstance(coerced, bool):
fail_count += 1
failing_rows.append(row)
break
elif any(k in pg_type_l for k in ("int", "bigint", "smallint")):
if not isinstance(coerced, int):
# allow numeric strings that can be int
if not (isinstance(coerced, str) and coerced.isdigit()):
fail_count += 1
failing_rows.append(row)
break
# else: assume text/json ok
diag = None
if failing_rows:
rows_to_save = [tuple(r) for r in failing_rows[:100]]
diag = await save_failed_rows(
dump_basename,
f"{table_name}.sample",
0,
rows_to_save,
"Validation failures",
)
return fail_count, total, diag
def determine_conflict_columns(
pg_schema: Dict[str, Any],
sqlite_schema: List[Dict[str, Any]],
table_name: str,
) -> List[str]:
"""Determine which columns to use for conflict detection.
Priority:
1. PostgreSQL primary key
2. PostgreSQL unique constraint
3. SQLite primary key (only if PostgreSQL has a matching unique/PK constraint)
"""
# 1) PostgreSQL primary key
if pg_schema["pk_columns"]:
return pg_schema["pk_columns"]
# 2) PostgreSQL unique constraint
if pg_schema["unique_constraints"]:
first_constraint = list(pg_schema["unique_constraints"].values())[0]
return first_constraint
# 3) Consider SQLite primary key only if PostgreSQL has a matching constraint
sqlite_pk_columns = [col["name"] for col in sqlite_schema if col["pk"] > 0]
sqlite_pk_columns = [
col["name"]
for col in sorted(
[c for c in sqlite_schema if c["pk"] > 0], key=lambda x: x["pk"]
)
]
if sqlite_pk_columns:
# Check if any PG unique constraint matches these columns (order-insensitive)
for cols in pg_schema["unique_constraints"].values():
if set(cols) == set(sqlite_pk_columns):
return sqlite_pk_columns
# Also accept if PG has a primary key covering same columns (already handled), so if we reach here
# PG has no matching constraint — cannot safely perform ON CONFLICT based upsert
raise ValueError(
f"Table {table_name} has a SQLite primary key ({', '.join(sqlite_pk_columns)}) but no matching unique/PK constraint in PostgreSQL. Cannot perform upsert."
)
# Nothing found
raise ValueError(
f"Table {table_name} has no primary key or unique constraint in either PostgreSQL or SQLite. Cannot perform upsert."
)
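# Resolution examples for determine_conflict_columns() (hypothetical schemas): a PostgreSQL
# primary key on ("id") wins outright; with no PG primary key but a unique constraint on
# ("name", "artist_name"), that constraint's columns are used; a SQLite-only primary key is
# honoured only when PostgreSQL has a unique constraint over exactly the same column set,
# otherwise a ValueError is raised.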
def build_upsert_sql(
table_name: str,
columns: List[str],
conflict_columns: List[str],
update_columns: Optional[List[str]] = None,
) -> str:
"""Build PostgreSQL upsert SQL with $1, $2, etc. placeholders."""
col_names_str = ", ".join([f'"{col}"' for col in columns])
placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
conflict_str = ", ".join([f'"{col}"' for col in conflict_columns])
if update_columns is None:
update_columns = [c for c in columns if c not in conflict_columns]
if not update_columns:
return f"INSERT INTO {table_name} ({col_names_str}) VALUES ({placeholders}) ON CONFLICT ({conflict_str}) DO NOTHING"
set_clauses = [f'"{col}" = EXCLUDED."{col}"' for col in update_columns]
set_str = ", ".join(set_clauses)
return f"INSERT INTO {table_name} ({col_names_str}) VALUES ({placeholders}) ON CONFLICT ({conflict_str}) DO UPDATE SET {set_str}"
async def upsert_table(
sqlite_db: aiosqlite.Connection,
pg_pool: asyncpg.Pool,
table_name: str,
dry_run: bool = False,
verbose: bool = False,
dump_basename: Optional[str] = None,
) -> TableReport:
"""Upsert data from SQLite table to PostgreSQL table."""
import time
start_time = time.time()
report = TableReport(table_name=table_name)
# Get schemas
sqlite_schema = await get_sqlite_schema(sqlite_db, table_name)
pg_schema = await get_pg_schema(pg_pool, table_name)
if not pg_schema["columns"]:
report.status = "skipped"
report.error_message = "Table does not exist in PostgreSQL"
print(f" ⚠ Table {table_name} does not exist in PostgreSQL, skipping...")
return report
# Match columns
sqlite_columns = [col["name"] for col in sqlite_schema]
pg_columns = [col["name"] for col in pg_schema["columns"]]
common_columns = [col for col in sqlite_columns if col in pg_columns]
if not common_columns:
report.status = "skipped"
report.error_message = "No common columns"
print(f" ⚠ No common columns for table {table_name}")
return report
if verbose:
print(f" Common columns: {', '.join(common_columns)}")
print(f" PG Primary keys: {pg_schema['pk_columns']}")
# Determine conflict columns (tries PG first, then SQLite)
use_temp_fallback = False
use_exists_upsert = False
sqlite_pk_columns = []
pk_column = None
try:
conflict_columns = determine_conflict_columns(
pg_schema, sqlite_schema, table_name
)
except ValueError as e:
sqlite_pk_columns = [col["name"] for col in sqlite_schema if col["pk"] > 0]
if sqlite_pk_columns:
# Prefer existence-based upsert for single-column PKs; fallback to temp-table for composite PKs
if len(sqlite_pk_columns) == 1:
print(
f" ⚠ Table {table_name} has a SQLite primary key ({sqlite_pk_columns[0]}) but no matching unique/PK constraint in PostgreSQL. Will use existence-based update/insert fallback."
)
use_exists_upsert = True
pk_column = sqlite_pk_columns[0]
conflict_columns = [pk_column]
else:
print(
f" ⚠ Table {table_name} has a SQLite primary key ({', '.join(sqlite_pk_columns)}) but no matching unique/PK constraint in PostgreSQL. Will attempt safer temp-table fallback."
)
use_temp_fallback = True
conflict_columns = sqlite_pk_columns
else:
report.status = "skipped"
report.error_message = str(e)
print(f"{e}")
return report
if verbose:
print(f" Using conflict columns: {conflict_columns}")
# Build upsert SQL for standard ON CONFLICT approach
upsert_sql = build_upsert_sql(table_name, common_columns, conflict_columns)
# Get counts
async with sqlite_db.execute(f"SELECT COUNT(*) FROM {table_name};") as cursor:
row = await cursor.fetchone()
report.sqlite_rows = row[0] if row else 0
async with pg_pool.acquire() as conn:
report.pg_rows_before = await conn.fetchval(
f"SELECT COUNT(*) FROM {table_name};"
)
print(
f" SQLite rows: {report.sqlite_rows:,}, PG rows before: {report.pg_rows_before:,}"
)
if report.sqlite_rows == 0:
report.status = "success"
report.pg_rows_after = report.pg_rows_before
print(" Table is empty, nothing to upsert.")
return report
# Prepare column query
col_names_str = ", ".join([f'"{col}"' for col in common_columns])
# Prepare per-column PG types for validation and coercion
pg_col_type_map = {c["name"]: c["type"] for c in pg_schema["columns"]}
pg_col_types = [pg_col_type_map.get(col, "text") for col in common_columns]
# Local chunk size (adaptive)
chunk_size = CHUNK_SIZE
# Pre-validate a sample of rows to catch common type mismatches early and adapt
sample_n = min(1000, chunk_size, report.sqlite_rows)
if sample_n > 0:
fail_count, total_checked, sample_diag = await validate_sample_rows(
sqlite_db,
pg_col_types,
col_names_str,
table_name,
dump_basename or "unknown_dump",
sample_n,
)
if fail_count:
pct = (fail_count / total_checked) * 100
print(
f" ⚠ Validation: {fail_count}/{total_checked} sample rows ({pct:.1f}%) appear problematic; reducing chunk size and writing diagnostics"
)
if sample_diag:
print(f" ⚠ Sample diagnostics: {sample_diag}")
# Reduce chunk size to be safer
old_chunk = chunk_size
chunk_size = max(1000, chunk_size // 4)
print(f" ⚠ Adjusted chunk_size: {old_chunk} -> {chunk_size}")
# If we determined earlier that we need to use the temp-table fallback, handle that path now
if use_temp_fallback:
print(
" Using temp-table fallback upsert (may be slower but safe on missing PG constraints)"
)
# Process in chunks using cursor iteration but push to temp upsert helper
offset = 0
async with pg_pool.acquire() as conn:
while offset < report.sqlite_rows:
async with sqlite_db.execute(
f"SELECT {col_names_str} FROM {table_name} LIMIT {chunk_size} OFFSET {offset};"
) as cursor:
rows = await cursor.fetchall()
if not rows:
break
# Coerce
cleaned_rows = []
for row in rows:
coerced = []
for val, pg_type in zip(row, pg_col_types):
coerced.append(coerce_value_for_pg(val, pg_type))
cleaned_rows.append(tuple(coerced))
# Attempt fallback upsert
try:
async with conn.transaction():
upsert_count, errs, diag = await fallback_upsert_using_temp(
conn,
table_name,
common_columns,
sqlite_pk_columns,
cleaned_rows,
dump_basename or "unknown_dump",
)
except Exception as e:
upsert_count, errs, diag = (
0,
1,
await save_failed_rows(
dump_basename or "unknown_dump",
table_name,
offset,
cleaned_rows[:100],
str(e),
),
)
report.rows_affected += upsert_count
report.errors += errs
if diag:
print(f"\n ⚠ Diagnostics written: {diag}")
if report.errors > 10:
print(" ✗ Too many errors, aborting table")
report.status = "failed"
report.error_message = f"Too many errors ({report.errors})"
break
offset += chunk_size
# Set final counts
async with pg_pool.acquire() as conn:
report.pg_rows_after = await conn.fetchval(
f"SELECT COUNT(*) FROM {table_name};"
)
report.duration_seconds = time.time() - start_time
if report.errors == 0:
report.status = "success"
elif report.errors <= 10:
report.status = "success"
else:
report.status = "failed"
return report
if use_exists_upsert:
print(
" Using existence-based upsert (INSERT for new rows, UPDATE for existing rows)"
)
# Process in chunks using cursor iteration
offset = 0
async with pg_pool.acquire() as conn:
while offset < report.sqlite_rows:
async with sqlite_db.execute(
f"SELECT {col_names_str} FROM {table_name} LIMIT {chunk_size} OFFSET {offset};"
) as cursor:
rows = await cursor.fetchall()
if not rows:
break
# Coerce
cleaned_rows = []
for row in rows:
coerced = []
for val, pg_type in zip(row, pg_col_types):
coerced.append(coerce_value_for_pg(val, pg_type))
cleaned_rows.append(tuple(coerced))
# Build list of PK values
pk_index = common_columns.index(pk_column)
pk_values = [r[pk_index] for r in cleaned_rows]
# Query existing PKs
existing = set()
try:
rows_found = await conn.fetch(
f"SELECT {pk_column} FROM {table_name} WHERE {pk_column} = ANY($1)",
pk_values,
)
existing = set([r[pk_column] for r in rows_found])
except Exception as e:
# On error, save diagnostics and abort
diag = await save_failed_rows(
dump_basename or "unknown_dump",
table_name,
offset,
cleaned_rows[:100],
str(e),
)
print(f"\n ⚠ Diagnostics written: {diag}")
report.errors += 1
break
insert_rows = [r for r in cleaned_rows if r[pk_index] not in existing]
update_rows = [r for r in cleaned_rows if r[pk_index] in existing]
# Perform inserts
if insert_rows:
col_names_csv = ", ".join([f'"{c}"' for c in common_columns])
placeholders = ", ".join(
[f"${i + 1}" for i in range(len(common_columns))]
)
insert_sql = f"INSERT INTO {table_name} ({col_names_csv}) VALUES ({placeholders})"
try:
await conn.executemany(insert_sql, insert_rows)
report.rows_affected += len(insert_rows)
except Exception as e:
diag = await save_failed_rows(
dump_basename or "unknown_dump",
table_name,
offset,
insert_rows[:100],
str(e),
)
print(f"\n ⚠ Diagnostics written: {diag}")
report.errors += 1
# Perform updates (per-row executemany)
if update_rows:
update_cols = [c for c in common_columns if c != pk_column]
set_clause = ", ".join(
[f'"{c}" = ${i + 1}' for i, c in enumerate(update_cols)]
)
# For executemany we'll pass values ordered as [col1, col2, ..., pk]
update_sql = f'UPDATE {table_name} SET {set_clause} WHERE "{pk_column}" = ${len(update_cols) + 1}'
update_values = []
for r in update_rows:
vals = [r[common_columns.index(c)] for c in update_cols]
vals.append(r[pk_index])
update_values.append(tuple(vals))
try:
await conn.executemany(update_sql, update_values)
report.rows_affected += len(update_rows)
except Exception as e:
diag = await save_failed_rows(
dump_basename or "unknown_dump",
table_name,
offset,
update_rows[:100],
str(e),
)
print(f"\n ⚠ Diagnostics written: {diag}")
report.errors += 1
if report.errors > 10:
print(" ✗ Too many errors, aborting table")
report.status = "failed"
report.error_message = f"Too many errors ({report.errors})"
break
offset += chunk_size
# Set final counts
async with pg_pool.acquire() as conn:
report.pg_rows_after = await conn.fetchval(
f"SELECT COUNT(*) FROM {table_name};"
)
report.duration_seconds = time.time() - start_time
if report.errors == 0:
report.status = "success"
elif report.errors <= 10:
report.status = "success"
else:
report.status = "failed"
return report
# Process in chunks using cursor iteration
offset = 0
async with pg_pool.acquire() as conn:
while offset < report.sqlite_rows:
try:
# Fetch chunk from SQLite
async with sqlite_db.execute(
f"SELECT {col_names_str} FROM {table_name} LIMIT {chunk_size} OFFSET {offset};"
) as cursor:
rows = await cursor.fetchall()
if not rows:
break
                # Coerce data per-column based on PG types (pg_col_types was computed above)
cleaned_rows = []
for row in rows:
coerced = []
for val, pg_type in zip(row, pg_col_types):
coerced.append(coerce_value_for_pg(val, pg_type))
cleaned_rows.append(tuple(coerced))
if not dry_run:
# Try inserting chunk with adaptive recovery
dump_basename_local = dump_basename or "unknown_dump"
success_count, err_count, diag = await attempt_insert_chunk(
conn,
upsert_sql,
cleaned_rows,
table_name,
offset,
dump_basename_local,
)
report.rows_affected += success_count
report.errors += err_count
if diag:
print(f"\n ⚠ Diagnostics written: {diag}")
# Record in report for visibility
report.error_message = (
(report.error_message + "; ")
if report.error_message
else ""
) + f"Diagnostics: {diag}"
if err_count > 0:
print(
f"\n ⚠ Error at offset {offset}: {err_count} problematic rows (see diagnostics)"
)
if report.errors > 10:
print(" ✗ Too many errors, aborting table")
report.status = "failed"
report.error_message = f"Too many errors ({report.errors})" + (
f"; {report.error_message}" if report.error_message else ""
)
break
                offset += chunk_size
progress = min(offset, report.sqlite_rows)
pct = (progress / report.sqlite_rows) * 100
print(
f" Progress: {progress:,}/{report.sqlite_rows:,} ({pct:.1f}%)",
end="\r",
)
except Exception as e:
print(f"\n ⚠ Error at offset {offset}: {e}")
report.errors += 1
                offset += chunk_size
if report.errors > 10:
print(" ✗ Too many errors, aborting table")
report.status = "failed"
report.error_message = f"Too many errors ({report.errors})"
break
print() # Newline after progress
# Get final count
async with pg_pool.acquire() as conn:
report.pg_rows_after = await conn.fetchval(
f"SELECT COUNT(*) FROM {table_name};"
)
report.duration_seconds = time.time() - start_time
if report.errors == 0:
report.status = "success"
elif report.errors <= 10:
report.status = "success"
else:
report.status = "failed"
return report
async def upsert_database(
sqlite_path: str,
tables: Optional[List[str]] = None,
dry_run: bool = False,
verbose: bool = False,
dump_date: Optional[str] = None,
) -> UpsertReport:
"""Main upsert function - process SQLite dump into PostgreSQL."""
report = UpsertReport(sqlite_source=sqlite_path, dump_date=dump_date)
report.start_time = datetime.now()
print(f"{'[DRY RUN] ' if dry_run else ''}SQLite → PostgreSQL Upsert (Async)")
print(f"SQLite: {sqlite_path}")
print(f"PostgreSQL: {PG_CONFIG['database']}@{PG_CONFIG['host']}")
print(f"Chunk size: {CHUNK_SIZE:,} rows\n")
# Connect to databases
print("Connecting to databases...")
sqlite_db = await aiosqlite.connect(sqlite_path)
pg_pool = await asyncpg.create_pool(
host=PG_CONFIG["host"],
port=PG_CONFIG["port"],
database=PG_CONFIG["database"],
user=PG_CONFIG["user"],
password=PG_CONFIG["password"],
min_size=2,
max_size=10,
)
try:
# Get tables
available_tables = await get_sqlite_tables(sqlite_db)
if tables:
process_tables = [t for t in tables if t in available_tables]
missing = [t for t in tables if t not in available_tables]
if missing:
print(f"⚠ Tables not found in SQLite: {', '.join(missing)}")
else:
process_tables = available_tables
print(f"Tables to process: {', '.join(process_tables)}\n")
for i, table_name in enumerate(process_tables, 1):
print(f"[{i}/{len(process_tables)}] Processing table: {table_name}")
try:
dump_basename = Path(sqlite_path).name
table_report = await upsert_table(
sqlite_db, pg_pool, table_name, dry_run, verbose, dump_basename
)
report.tables.append(table_report)
if table_report.status == "success":
print(
f" ✓ Completed: +{table_report.rows_inserted:,} inserted, ~{table_report.rows_updated:,} updated"
)
elif table_report.status == "skipped":
print(" ○ Skipped")
else:
print(" ✗ Failed")
except Exception as e:
table_report = TableReport(
table_name=table_name, status="failed", error_message=str(e)
)
report.tables.append(table_report)
print(f" ✗ Failed: {e}")
print()
finally:
await sqlite_db.close()
await pg_pool.close()
report.end_time = datetime.now()
return report
async def test_pg_connection() -> bool:
"""Test PostgreSQL connection before proceeding with download."""
print("Testing PostgreSQL connection...")
try:
conn = await asyncpg.connect(
host=PG_CONFIG["host"],
port=PG_CONFIG["port"],
database=PG_CONFIG["database"],
user=PG_CONFIG["user"],
password=PG_CONFIG["password"],
)
version = await conn.fetchval("SELECT version();")
await conn.close()
print(f" ✓ Connected to PostgreSQL: {version[:50]}...")
return True
except Exception as e:
print(f" ✗ PostgreSQL connection failed: {e}")
return False
async def check_and_fetch_latest(
force: bool = False,
dest_dir: Optional[str] = None,
dry_run: bool = False,
verbose: bool = False,
notify: bool = False,
) -> Optional[UpsertReport]:
"""Check for new dumps on lrclib.net and upsert if newer than last run."""
# Send start notification
if notify:
await notify_start()
# Test PostgreSQL connection FIRST before downloading large files
if not await test_pg_connection():
error_msg = "Cannot connect to PostgreSQL database"
print(f"Aborting: {error_msg}")
if notify:
await notify_failure(
error_msg,
"Connection Test",
f"Host: {PG_CONFIG['host']}:{PG_CONFIG['port']}",
)
return None
# Get latest dump info
dump_info = await fetch_latest_dump_info()
if not dump_info:
error_msg = "Could not fetch latest dump info from lrclib.net"
print(error_msg)
if notify:
await notify_failure(
error_msg, "Fetch Dump Info", f"URL: {LRCLIB_DUMPS_URL}"
)
return None
dump_date = dump_info["date"]
dump_date_str = dump_date.strftime("%Y-%m-%d %H:%M:%S")
# Check against last upsert
state = load_state()
last_upsert_str = state.get("last_dump_date") # Full datetime string
if last_upsert_str:
try:
# Try parsing with time first
last_date = datetime.strptime(last_upsert_str, "%Y-%m-%d %H:%M:%S")
except ValueError:
# Fall back to date only
last_date = datetime.strptime(last_upsert_str, "%Y-%m-%d")
print(f"Last upsert dump date: {last_upsert_str}")
if dump_date <= last_date and not force:
print(
f"No new dump available (latest: {dump_date_str}, last upsert: {last_upsert_str})"
)
print("Use --force to upsert anyway")
if notify:
await notify_no_update(dump_date_str, last_upsert_str)
return None
print(f"New dump available: {dump_date_str} > {last_upsert_str}")
else:
print("No previous upsert recorded")
if dry_run:
print(f"[DRY RUN] Would download and upsert {dump_info['filename']}")
return None
# Notify about new dump found
if notify:
await notify_new_dump_found(dump_info["filename"], dump_date_str)
# Download and extract
sqlite_path, download_error = await download_and_extract_dump(
dump_info["url"], dest_dir
)
if not sqlite_path:
error_msg = download_error or "Unknown download/extract error"
print(f"Failed: {error_msg}")
if notify:
await notify_failure(
error_msg, "Download/Extract", f"URL: {dump_info['url']}"
)
return None
# Keep files until we confirm import success to allow resume/inspection
gz_path = str(Path(sqlite_path).with_suffix(Path(sqlite_path).suffix + ".gz"))
print(f"Downloaded and extracted: {sqlite_path}")
print(f"Keeping source compressed file for resume/inspection: {gz_path}")
# Define SQLITE_DB_PATH dynamically after extraction (persist globally)
global SQLITE_DB_PATH
SQLITE_DB_PATH = sqlite_path # Set to the extracted SQLite file path
print(f"SQLite database path set to: {SQLITE_DB_PATH}")
upsert_successful = False
try:
# Perform upsert
report = await upsert_database(
sqlite_path=sqlite_path,
tables=None,
dry_run=False,
verbose=verbose,
dump_date=dump_date_str,
)
        # If there were no errors, or at least one table succeeded, check whether any rows were actually upserted
if report.total_errors == 0 or len(report.successful_tables) > 0:
# Determine if any table actually had rows inserted or updated
any_upserted = any(
(t.status == "success" and (t.rows_inserted > 0 or t.rows_updated > 0))
for t in report.tables
)
# Mark overall upsert success (only true if rows were actually upserted)
upsert_successful = any_upserted
if upsert_successful:
# Save state only when an actual upsert occurred to avoid skipping future attempts
state["last_dump_date"] = dump_date_str # Date from the dump filename
state["last_upsert_time"] = datetime.now().isoformat() # When we ran
state["last_dump_url"] = dump_info["url"]
state["last_dump_filename"] = dump_info["filename"]
save_state(state)
print(f"\nState saved: last_dump_date = {dump_date_str}")
if notify:
if any_upserted:
await notify_success(report)
else:
# All tables skipped or empty (no actual upsert)
await discord_notify(
title="⚠️ LRCLib DB Update - No Data Upserted",
description="Upsert completed, but no rows were inserted or updated. All tables may have been skipped due to missing primary keys or were empty. The dump files have been retained for inspection/resume.",
color=DiscordColor.WARNING,
fields=[
{
"name": "Tables Processed",
"value": str(len(report.tables)),
"inline": True,
},
{
"name": "Successful Tables",
"value": str(len(report.successful_tables)),
"inline": True,
},
{
"name": "Rows Inserted",
"value": str(report.total_rows_inserted),
"inline": True,
},
{
"name": "Rows Updated",
"value": str(report.total_rows_updated),
"inline": True,
},
],
footer=f"Source: {report.sqlite_source}",
)
else:
# All tables failed
if notify:
failed_tables = ", ".join(report.failed_tables[:5])
await notify_failure(
"All tables failed to upsert",
"Upsert",
f"Failed tables: {failed_tables}",
)
return report
except Exception as e:
error_msg = str(e)
print(f"Upsert error: {error_msg}")
if notify:
import traceback
tb = traceback.format_exc()
await notify_failure(error_msg, "Upsert", tb[:1000])
raise
finally:
# Only remove files when a successful upsert occurred. Keep them otherwise for resume/debug.
if upsert_successful:
# Clean up downloaded SQLite file
if sqlite_path and os.path.exists(sqlite_path):
print(f"Cleaning up SQLite dump (success): {sqlite_path}")
os.unlink(sqlite_path)
# Also remove the .gz file
if sqlite_path:
gz_path = sqlite_path + ".gz"
if os.path.exists(gz_path):
print(f"Cleaning up compressed file (success): {gz_path}")
os.unlink(gz_path)
else:
# Keep files for inspection/resume
if sqlite_path and os.path.exists(sqlite_path):
print(f"Retaining SQLite dump for inspection/resume: {sqlite_path}")
if sqlite_path:
gz_path = sqlite_path + ".gz"
if os.path.exists(gz_path):
print(f"Retaining compressed file for inspection/resume: {gz_path}")
async def create_new_schema(pg_conn: asyncpg.Connection, schema_name: str) -> None:
"""Create a new schema for the migration."""
await pg_conn.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")
async def migrate_to_new_schema(
sqlite_conn: aiosqlite.Connection,
pg_conn: asyncpg.Connection,
schema_name: str,
table_name: str,
columns: List[str],
) -> None:
"""Migrate data to a new schema."""
staging_table = f"{schema_name}.{table_name}"
# Drop and recreate staging table in the new schema
await pg_conn.execute(f"DROP TABLE IF EXISTS {staging_table}")
await pg_conn.execute(
f"CREATE TABLE {staging_table} (LIKE {table_name} INCLUDING ALL)"
)
# Import data into the staging table
sqlite_cursor = await sqlite_conn.execute(
f"SELECT {', '.join(columns)} FROM {table_name}"
)
rows = await sqlite_cursor.fetchall()
await sqlite_cursor.close()
# Convert rows to a list to ensure compatibility with len()
rows = list(rows)
print(f"Imported {len(rows)} rows into {staging_table}")
placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
insert_sql = f"INSERT INTO {staging_table} ({', '.join([f'"{col}"' for col in columns])}) VALUES ({placeholders})"
await pg_conn.executemany(insert_sql, rows)
print(f"Imported {len(rows)} rows into {staging_table}")
async def swap_schemas(pg_conn: asyncpg.Connection, new_schema: str) -> None:
"""Swap the new schema with the public schema."""
async with pg_conn.transaction():
# Rename public schema to a backup schema
await pg_conn.execute("ALTER SCHEMA public RENAME TO backup")
# Rename new schema to public
await pg_conn.execute(f"ALTER SCHEMA {new_schema} RENAME TO public")
# Drop the backup schema
await pg_conn.execute("DROP SCHEMA backup CASCADE")
print(f"Swapped schema {new_schema} with public")
async def migrate_database(sqlite_path: Optional[str] = None):
"""Main migration function.
If `sqlite_path` is provided, use it. Otherwise fall back to the global `SQLITE_DB_PATH`.
"""
path = sqlite_path or SQLITE_DB_PATH
if not path:
raise ValueError(
"No SQLite path provided. Call migrate_database(sqlite_path=...) or set SQLITE_DB_PATH before invoking."
)
sqlite_conn = await aiosqlite.connect(path)
pg_conn = await asyncpg.connect(**PG_CONFIG)
new_schema = "staging"
try:
# Create a new schema for the migration
await create_new_schema(pg_conn, new_schema)
        sqlite_tables = await sqlite_conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        )
tables = [row[0] for row in await sqlite_tables.fetchall()]
await sqlite_tables.close()
for table_name in tables:
sqlite_cursor = await sqlite_conn.execute(
f"PRAGMA table_info({table_name})"
)
columns = [row[1] for row in await sqlite_cursor.fetchall()]
await sqlite_cursor.close()
# Migrate data to the new schema
await migrate_to_new_schema(
sqlite_conn, pg_conn, new_schema, table_name, columns
)
# Swap schemas after successful migration
await swap_schemas(pg_conn, new_schema)
finally:
await sqlite_conn.close()
await pg_conn.close()
# Ensure SQLITE_DB_PATH is defined globally before usage
SQLITE_DB_PATH = None # Initialize as None to avoid NameError
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="LRCLib DB upsert / migration utility")
parser.add_argument(
"--migrate",
action="store_true",
help="Migrate a local SQLite file into a new staging schema and swap",
)
parser.add_argument(
"--sqlite",
type=str,
help="Path to local SQLite file to migrate (required for --migrate)",
)
parser.add_argument(
"--check",
action="store_true",
help="Check lrclib.net for latest dump and upsert if newer",
)
parser.add_argument(
"--force", action="store_true", help="Force upsert even if dump is not newer"
)
parser.add_argument(
"--dest", type=str, help="Destination directory for downloaded dumps"
)
parser.add_argument(
"--dry-run", action="store_true", help="Do not perform writes; just simulate"
)
parser.add_argument("--verbose", action="store_true", help="Verbose output")
parser.add_argument(
"--notify", action="store_true", help="Send Discord notifications"
)
args = parser.parse_args()
if args.migrate:
if not args.sqlite:
parser.error("--migrate requires --sqlite PATH")
print(f"Starting schema migration from local SQLite: {args.sqlite}")
asyncio.run(migrate_database(sqlite_path=args.sqlite))
else:
# Default behavior: check lrclib.net and upsert if needed
print("Checking for latest dump and performing upsert (if applicable)...")
asyncio.run(
check_and_fetch_latest(
force=args.force,
dest_dir=args.dest,
dry_run=args.dry_run,
verbose=args.verbose,
notify=args.notify,
)
)