#!/usr/bin/env python3
"""
Upsert SQLite dump data into PostgreSQL (Async Version).

This script handles incremental updates from SQLite dumps to PostgreSQL:
- New rows are inserted
- Existing rows are updated if changed
- Uses primary keys and unique constraints to identify duplicates
- Can automatically fetch new dumps from lrclib.net
- Generates detailed change reports
- Async with asyncpg for high performance with large datasets (17GB+)
"""

import asyncio
import aiosqlite
import asyncpg
import os
import gzip
import re
import json
import aiohttp
import aiofiles
import subprocess
from typing import List, Dict, Any, Optional
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass, field, asdict
from datetime import timezone
from dotenv import load_dotenv

try:
    from playwright.async_api import async_playwright

    PLAYWRIGHT_AVAILABLE = True
except ImportError:
    PLAYWRIGHT_AVAILABLE = False

load_dotenv()

# Discord webhook for notifications
DISCORD_WEBHOOK_URL = "https://discord.com/api/webhooks/1332864106750939168/6y6MgFhHLX0-BnL2M3hXnt2vJsue7Q2Duf_HjZenHNlNj7sxQr4lqxrPVJnJWf7KVAm2"

# Configuration - using environment variables
PG_CONFIG = {
    "host": os.getenv("POSTGRES_HOST", "localhost"),
    "port": int(os.getenv("POSTGRES_PORT", "5432")),
    "database": os.getenv("POSTGRES_DB", "lrclib"),
    "user": os.getenv("POSTGRES_USER", "api"),
    "password": os.getenv("POSTGRES_PASSWORD", ""),
}
CHUNK_SIZE = 10000  # Process rows in chunks (larger for async)
BATCH_SIZE = 1000  # Rows per INSERT statement

# LRCLib dump URL
LRCLIB_DUMPS_URL = "https://lrclib.net/db-dumps"
STATE_FILE = Path(__file__).parent / ".lrclib_upsert_state.json"


@dataclass
class TableReport:
    """Report for a single table's upsert operation."""

    table_name: str
    sqlite_rows: int = 0
    pg_rows_before: int = 0
    pg_rows_after: int = 0
    rows_affected: int = 0
    errors: int = 0
    duration_seconds: float = 0.0
    status: str = "pending"
    error_message: str = ""

    @property
    def rows_inserted(self) -> int:
        return max(0, self.pg_rows_after - self.pg_rows_before)

    @property
    def rows_updated(self) -> int:
        return max(0, self.rows_affected - self.rows_inserted)


@dataclass
class UpsertReport:
    """Complete report for an upsert operation."""

    sqlite_source: str
    dump_date: Optional[str] = None
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    tables: List[TableReport] = field(default_factory=list)

    @property
    def total_sqlite_rows(self) -> int:
        return sum(t.sqlite_rows for t in self.tables)

    @property
    def total_rows_affected(self) -> int:
        return sum(t.rows_affected for t in self.tables)

    @property
    def total_rows_inserted(self) -> int:
        return sum(t.rows_inserted for t in self.tables)

    @property
    def total_rows_updated(self) -> int:
        return sum(t.rows_updated for t in self.tables)

    @property
    def total_errors(self) -> int:
        return sum(t.errors for t in self.tables)

    @property
    def duration_seconds(self) -> float:
        if self.start_time and self.end_time:
            return (self.end_time - self.start_time).total_seconds()
        return 0.0

    @property
    def successful_tables(self) -> List[str]:
        return [t.table_name for t in self.tables if t.status == "success"]

    @property
    def failed_tables(self) -> List[str]:
        return [t.table_name for t in self.tables if t.status == "failed"]

    def generate_summary(self) -> str:
        """Generate a human-readable summary report."""
        lines = [
            "=" * 60,
            "LRCLIB UPSERT REPORT",
            "=" * 60,
            f"Source: {self.sqlite_source}",
        ]

        if self.dump_date:
            lines.append(f"Dump Date: {self.dump_date}")

        if self.start_time:
            lines.append(f"Started: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
        if self.end_time:
            lines.append(f"Finished: {self.end_time.strftime('%Y-%m-%d %H:%M:%S')}")

        lines.append(f"Duration: {self.duration_seconds:.1f} seconds")
        lines.append("")
        lines.append("-" * 60)
        lines.append("TABLE DETAILS")
        lines.append("-" * 60)

        for table in self.tables:
            status_icon = (
                "✓"
                if table.status == "success"
                else "✗" if table.status == "failed" else "○"
            )
            lines.append(f"\n{status_icon} {table.table_name}")
            lines.append(f" SQLite rows: {table.sqlite_rows:>12,}")
            lines.append(f" PG before: {table.pg_rows_before:>12,}")
            lines.append(f" PG after: {table.pg_rows_after:>12,}")
            lines.append(f" Rows inserted: {table.rows_inserted:>12,}")
            lines.append(f" Rows updated: {table.rows_updated:>12,}")
            lines.append(f" Duration: {table.duration_seconds:>12.1f}s")
            if table.errors > 0:
                lines.append(f" Errors: {table.errors:>12}")
            if table.error_message:
                lines.append(f" Error: {table.error_message}")

        lines.append("")
        lines.append("-" * 60)
        lines.append("SUMMARY")
        lines.append("-" * 60)
        lines.append(f"Tables processed: {len(self.tables)}")
        lines.append(f" Successful: {len(self.successful_tables)}")
        lines.append(f" Failed: {len(self.failed_tables)}")
        lines.append(f"Total SQLite rows: {self.total_sqlite_rows:,}")
        lines.append(f"Total inserted: {self.total_rows_inserted:,}")
        lines.append(f"Total updated: {self.total_rows_updated:,}")
        lines.append(f"Total errors: {self.total_errors}")
        lines.append("=" * 60)

        return "\n".join(lines)

    def to_json(self) -> str:
        """Export report as JSON."""
        data = {
            "sqlite_source": self.sqlite_source,
            "dump_date": self.dump_date,
            "start_time": self.start_time.isoformat() if self.start_time else None,
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "duration_seconds": self.duration_seconds,
            "summary": {
                "total_sqlite_rows": self.total_sqlite_rows,
                "total_rows_inserted": self.total_rows_inserted,
                "total_rows_updated": self.total_rows_updated,
                "total_errors": self.total_errors,
                "successful_tables": self.successful_tables,
                "failed_tables": self.failed_tables,
            },
            "tables": [asdict(t) for t in self.tables],
        }
        return json.dumps(data, indent=2)


def load_state() -> Dict[str, Any]:
    """Load the last upsert state from file."""
    if STATE_FILE.exists():
        try:
            return json.loads(STATE_FILE.read_text())
        except (json.JSONDecodeError, IOError):
            pass
    return {}


def save_state(state: Dict[str, Any]) -> None:
    """Save the upsert state to file."""
    STATE_FILE.write_text(json.dumps(state, indent=2))


# Discord notification colors
class DiscordColor:
    INFO = 0x3498DB  # Blue
    SUCCESS = 0x2ECC71  # Green
    WARNING = 0xF39C12  # Orange
    ERROR = 0xE74C3C  # Red
    NEUTRAL = 0x95A5A6  # Gray


async def discord_notify(
    title: str,
    description: str,
    color: int = DiscordColor.INFO,
    fields: Optional[List[Dict[str, Any]]] = None,
    footer: Optional[str] = None,
) -> bool:
    """
    Send a Discord webhook notification.
    Returns True on success, False on failure.
    """
    embed = {
        "title": title,
        "description": description[:2000] if description else "",
        "color": color,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

    if fields:
        embed["fields"] = fields[:25]  # Discord limit

    if footer:
        embed["footer"] = {"text": footer[:2048]}

    payload = {"embeds": [embed]}

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                DISCORD_WEBHOOK_URL,
                json=payload,
                timeout=aiohttp.ClientTimeout(total=15),
            ) as resp:
                if resp.status >= 400:
                    text = await resp.text()
                    print(f"Discord webhook failed ({resp.status}): {text}")
                    return False
                return True
    except Exception as e:
        print(f"Discord notification error: {e}")
        return False
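
# Illustrative only: the webhook body built by discord_notify() is a single embed
# wrapped in an "embeds" list; the concrete values below are made-up examples.
#
# {
#     "embeds": [{
#         "title": "✅ LRCLib DB Update - Success",
#         "description": "Database sync completed successfully.",
#         "color": 0x2ECC71,
#         "timestamp": "2025-01-20T03:00:00+00:00",
#         "fields": [{"name": "Duration", "value": "42.0 min", "inline": True}],
#         "footer": {"text": "Source: /nvme/tmp/lrclib-db-dump.sqlite3"},
#     }]
# }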


async def notify_start() -> None:
    """Send notification that the weekly job is starting."""
    await discord_notify(
        title="🔄 LRCLib DB Update - Starting",
        description="Weekly database sync job has started.",
        color=DiscordColor.INFO,
        fields=[
            {
                "name": "Time",
                "value": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "inline": True,
            },
        ],
    )


async def notify_no_update(latest_dump: str, last_upsert: str) -> None:
    """Send notification that no new dump was found."""
    await discord_notify(
        title="ℹ️ LRCLib DB Update - No Update Needed",
        description="No newer database dump available.",
        color=DiscordColor.NEUTRAL,
        fields=[
            {"name": "Latest Available", "value": latest_dump, "inline": True},
            {"name": "Last Synced", "value": last_upsert, "inline": True},
        ],
    )


async def notify_new_dump_found(dump_filename: str, dump_date: str) -> None:
    """Send notification that a new dump was found and download is starting."""
    await discord_notify(
        title="📥 LRCLib DB Update - New Dump Found",
        description="Downloading and processing new database dump.",
        color=DiscordColor.INFO,
        fields=[
            {"name": "Filename", "value": dump_filename, "inline": False},
            {"name": "Dump Date", "value": dump_date, "inline": True},
        ],
    )


async def notify_success(report: "UpsertReport") -> None:
    """Send notification for successful upsert."""
    duration_min = report.duration_seconds / 60

    # Build table summary
    table_summary = []
    for t in report.tables[:10]:  # Limit to first 10 tables
        status = "✓" if t.status == "success" else "✗" if t.status == "failed" else "○"
        table_summary.append(f"{status} {t.table_name}: +{t.rows_inserted:,}")

    if len(report.tables) > 10:
        table_summary.append(f"... and {len(report.tables) - 10} more tables")

    await discord_notify(
        title="✅ LRCLib DB Update - Success",
        description="Database sync completed successfully.",
        color=DiscordColor.SUCCESS,
        fields=[
            {"name": "Dump Date", "value": report.dump_date or "N/A", "inline": True},
            {"name": "Duration", "value": f"{duration_min:.1f} min", "inline": True},
            {"name": "Tables", "value": str(len(report.tables)), "inline": True},
            {
                "name": "Rows Inserted",
                "value": f"{report.total_rows_inserted:,}",
                "inline": True,
            },
            {
                "name": "Rows Updated",
                "value": f"{report.total_rows_updated:,}",
                "inline": True,
            },
            {"name": "Errors", "value": str(report.total_errors), "inline": True},
            {
                "name": "Table Summary",
                "value": "\n".join(table_summary) or "No tables",
                "inline": False,
            },
        ],
        footer=f"Source: {report.sqlite_source}",
    )


async def notify_failure(
    error_message: str, stage: str, details: Optional[str] = None
) -> None:
    """Send notification for failed upsert."""
    description = (
        f"Database sync failed during: **{stage}**\n\n```{error_message[:1500]}```"
    )

    fields = [
        {"name": "Stage", "value": stage, "inline": True},
        {
            "name": "Time",
            "value": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "inline": True,
        },
    ]

    if details:
        fields.append({"name": "Details", "value": details[:1000], "inline": False})

    await discord_notify(
        title="❌ LRCLib DB Update - Failed",
        description=description,
        color=DiscordColor.ERROR,
        fields=fields,
    )


def parse_dump_date(filename: str) -> Optional[datetime]:
    """
    Extract date from dump filename.
    Supports formats:
    - lrclib-db-dump-20260122T091251Z.sqlite3.gz (ISO format with time)
    - lrclib-db-dump-2025-01-20.sqlite3.gz (date only)
    """
    # Try ISO format first: 20260122T091251Z
    match = re.search(r"(\d{4})(\d{2})(\d{2})T(\d{2})(\d{2})(\d{2})Z", filename)
    if match:
        return datetime(
            int(match.group(1)),  # year
            int(match.group(2)),  # month
            int(match.group(3)),  # day
            int(match.group(4)),  # hour
            int(match.group(5)),  # minute
            int(match.group(6)),  # second
        )

    # Try simple date format: 2025-01-20
    match = re.search(r"(\d{4})-(\d{2})-(\d{2})", filename)
    if match:
        return datetime.strptime(match.group(0), "%Y-%m-%d")

    return None
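
# Quick sanity check (illustrative, not executed):
#   parse_dump_date("lrclib-db-dump-20260122T091251Z.sqlite3.gz")
#       -> datetime(2026, 1, 22, 9, 12, 51)
#   parse_dump_date("lrclib-db-dump-2025-01-20.sqlite3.gz")
#       -> datetime(2025, 1, 20, 0, 0)
#   parse_dump_date("not-a-dump.sqlite3.gz") -> None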


async def fetch_latest_dump_info() -> Optional[Dict[str, Any]]:
    """
    Fetch the latest dump info from lrclib.net/db-dumps.
    Uses Playwright to render the JS-generated page content.
    Returns dict with 'url', 'filename', 'date' or None if not found.
    """
    print(f"Fetching dump list from {LRCLIB_DUMPS_URL}...")

    if not PLAYWRIGHT_AVAILABLE:
        print("Error: playwright is required for fetching dump info.")
        print("Install with: pip install playwright && playwright install chromium")
        return None

    html = None
    try:
        async with async_playwright() as p:
            # Launch headless browser
            browser = await p.chromium.launch(headless=True)
            page = await browser.new_page()

            # Navigate and wait for content to load
            await page.goto(LRCLIB_DUMPS_URL, wait_until="networkidle")

            # Wait for anchor tags with .sqlite3.gz links to appear
            try:
                await page.wait_for_selector('a[href*=".sqlite3.gz"]', timeout=15000)
            except Exception:
                print("Warning: Timed out waiting for download links, trying anyway...")

            # Get rendered HTML
            html = await page.content()
            await browser.close()

    except Exception as e:
        print(f"Error rendering page with Playwright: {e}")
        return None

    if not html:
        print("Failed to get page content")
        return None

    # Find all .sqlite3.gz links
    pattern = r'href=["\']([^"\']*\.sqlite3\.gz)["\']'
    matches = re.findall(pattern, html, re.IGNORECASE)

    # Deduplicate matches (same link may appear multiple times in HTML)
    matches = list(dict.fromkeys(matches))

    if not matches:
        print("No .sqlite3.gz files found on the page")
        # Debug: show what we got
        print(f"Page content length: {len(html)} chars")
        return None

    print(f"Found {len(matches)} dump file(s)")

    # Parse dates and find the newest
    dumps = []
    for match in matches:
        # Handle relative URLs
        if match.startswith("/"):
            url = f"https://lrclib.net{match}"
        elif not match.startswith("http"):
            url = f"https://lrclib.net/db-dumps/{match}"
        else:
            url = match

        filename = url.split("/")[-1]
        date = parse_dump_date(filename)

        if date:
            dumps.append({"url": url, "filename": filename, "date": date})
            print(f" Found: {filename} ({date.strftime('%Y-%m-%d %H:%M:%S')})")

    if not dumps:
        print("Could not parse dates from dump filenames")
        return None

    # Sort by date descending and return the newest
    from datetime import datetime

    def parse_date_safe(date_obj):
        if isinstance(date_obj, datetime):
            return date_obj
        return datetime.min  # Default to a minimum date if parsing fails

    dumps.sort(key=lambda x: parse_date_safe(x.get("date", datetime.min)), reverse=True)
    latest = dumps[0]

    # Handle type issue with strftime
    if isinstance(latest["date"], datetime):
        formatted_date = latest["date"].strftime("%Y-%m-%d %H:%M:%S")
    else:
        formatted_date = "Unknown date"

    print(f"Latest dump: {latest['filename']} ({formatted_date})")
    return latest


async def download_and_extract_dump(
    url: str, dest_dir: Optional[str] = None, max_retries: int = 5
) -> tuple[Optional[str], Optional[str]]:
    """
    Download a .sqlite3.gz file and extract it with streaming.
    Supports resume on connection failures.
    Returns tuple of (path to extracted .sqlite3 file, error message).
    On success: (path, None). On failure: (None, error_message).
    """
    filename = url.split("/")[-1]

    if dest_dir:
        dest_path = Path(dest_dir)
        dest_path.mkdir(parents=True, exist_ok=True)
    else:
        dest_path = Path("/nvme/tmp")

    gz_path = dest_path / filename
    sqlite_path = dest_path / filename.replace(".gz", "")

    # If an extracted sqlite file already exists, skip download and extraction
    if sqlite_path.exists() and sqlite_path.stat().st_size > 0:
        print(f"Found existing extracted SQLite file {sqlite_path}; skipping download/extract")
        return str(sqlite_path), None

    # Streaming download with retry and resume support
    print(f"Downloading {url}...")

    for attempt in range(1, max_retries + 1):
        try:
            # Check if partial download exists for resume
            downloaded = 0
            headers = {}
            if gz_path.exists():
                downloaded = gz_path.stat().st_size
                headers["Range"] = f"bytes={downloaded}-"
                print(f" Resuming from {downloaded / (1024 * 1024):.1f} MB...")

            timeout = aiohttp.ClientTimeout(
                total=None,  # No total timeout for large files
                connect=60,
                sock_read=300,  # 5 min read timeout per chunk
            )

            async with aiohttp.ClientSession() as session:
                async with session.get(
                    url, timeout=timeout, headers=headers
                ) as response:
                    # Handle resume response
                    if response.status == 416:  # Range not satisfiable - file complete
                        print(" Download already complete.")
                        break
                    elif response.status == 206:  # Partial content - resuming
                        total = downloaded + int(
                            response.headers.get("content-length", 0)
                        )
                    elif response.status == 200:
                        # Server doesn't support resume or fresh download
                        if downloaded > 0:
                            print(" Server doesn't support resume, restarting...")
                            downloaded = 0
                        total = int(response.headers.get("content-length", 0))
                    else:
                        response.raise_for_status()
                        total = int(response.headers.get("content-length", 0))

                    # Open in append mode if resuming, write mode if fresh
                    mode = "ab" if downloaded > 0 and response.status == 206 else "wb"
                    if mode == "wb":
                        downloaded = 0

                    async with aiofiles.open(gz_path, mode) as f:
                        async for chunk in response.content.iter_chunked(
                            1024 * 1024
                        ):  # 1MB chunks
                            await f.write(chunk)
                            downloaded += len(chunk)
                            if total:
                                pct = (downloaded / total) * 100
                                mb_down = downloaded / (1024 * 1024)
                                mb_total = total / (1024 * 1024)
                                print(
                                    f" Downloaded: {mb_down:.1f} / {mb_total:.1f} MB ({pct:.1f}%)",
                                    end="\r",
                                )

            print()  # Newline after progress
            break  # Success, exit retry loop

        except (aiohttp.ClientError, ConnectionResetError, asyncio.TimeoutError) as e:
            print(f"\n Download error (attempt {attempt}/{max_retries}): {e}")
            if attempt < max_retries:
                wait_time = min(30 * attempt, 120)  # Exponential backoff, max 2 min
                print(f" Retrying in {wait_time} seconds...")
                await asyncio.sleep(wait_time)
            else:
                error_msg = f"Download failed after {max_retries} attempts: {e}"
                print(f"Error: {error_msg}")
                # Clean up partial file on final failure
                if gz_path.exists():
                    gz_path.unlink()
                return None, error_msg

    # Verify download completed
    if not gz_path.exists() or gz_path.stat().st_size == 0:
        error_msg = "Download failed - file is empty or missing"
        print(f"Error: {error_msg}")
        return None, error_msg

    # Extract (sync, but fast with streaming)
    print(
        f"Extracting {gz_path} ({gz_path.stat().st_size / (1024**3):.2f} GB compressed)..."
    )
    try:
        # Use streaming decompression for large files
        async with aiofiles.open(sqlite_path, "wb") as f_out:
            with gzip.open(gz_path, "rb") as f_in:
                extracted = 0
                while True:
                    chunk = f_in.read(1024 * 1024 * 10)  # 10MB chunks
                    if not chunk:
                        break
                    await f_out.write(chunk)
                    extracted += len(chunk)
                    print(f" Extracted: {extracted / (1024**3):.2f} GB", end="\r")
        print()
    except Exception as e:
        error_msg = f"Extraction failed: {e}"
        print(f"Error: {error_msg}")
        return None, error_msg

    print(
        f"Extracted to: {sqlite_path} ({sqlite_path.stat().st_size / (1024**3):.2f} GB)"
    )
    # Note: do NOT remove the .gz file here. Keep source files until import completes successfully.
    return str(sqlite_path), None


async def extract_gz_file(gz_path: str, output_path: str) -> None:
    """Extract a .gz file using pigz for multi-threaded decompression."""
    try:
        # Use pigz for faster decompression
        print(f"Extracting {gz_path} to {output_path} using pigz...")
        with open(output_path, "wb") as out_file:
            subprocess.run(
                ["pigz", "-d", "-p", str(os.cpu_count()), "-c", gz_path],
                stdout=out_file,
                check=True,
            )
        print(f"Extraction completed: {output_path}")
    except FileNotFoundError:
        raise RuntimeError(
            "pigz is not installed. Please install it for faster decompression."
        )
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Extraction failed: {e}")


async def get_sqlite_tables(db: aiosqlite.Connection) -> List[str]:
    """Get list of all tables in SQLite database."""
    async with db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
    ) as cursor:
        rows = await cursor.fetchall()
        return [row[0] for row in rows]


async def get_sqlite_schema(
    db: aiosqlite.Connection, table_name: str
) -> List[Dict[str, Any]]:
    """Get column info from SQLite table."""
    async with db.execute(f"PRAGMA table_info({table_name});") as cursor:
        rows = await cursor.fetchall()
        return [
            {
                "cid": row[0],
                "name": row[1],
                "type": row[2],
                "notnull": row[3],
                "default": row[4],
                "pk": row[5],
            }
            for row in rows
        ]
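
# For reference: each PRAGMA table_info() row is
# (cid, name, type, notnull, dflt_value, pk), which is what the dict above maps.
# "pk" is 0 for non-key columns and the 1-based position for primary-key columns,
# which determine_conflict_columns() relies on when ordering composite keys.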


async def get_pg_schema(pool: asyncpg.Pool, table_name: str) -> Dict[str, Any]:
    """Get PostgreSQL table schema including primary key and unique constraints."""
    async with pool.acquire() as conn:
        # Get columns
        columns = await conn.fetch(
            """
            SELECT column_name, data_type, is_nullable, column_default
            FROM information_schema.columns
            WHERE table_name = $1
            ORDER BY ordinal_position;
            """,
            table_name,
        )
        columns = [
            {"name": row["column_name"], "type": row["data_type"]} for row in columns
        ]

        # Get primary key columns
        pk_columns = await conn.fetch(
            """
            SELECT a.attname
            FROM pg_index i
            JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
            WHERE i.indrelid = $1::regclass
            AND i.indisprimary;
            """,
            table_name,
        )
        pk_columns = [row["attname"] for row in pk_columns]

        # Get unique constraint columns
        unique_rows = await conn.fetch(
            """
            SELECT a.attname, c.conname
            FROM pg_constraint c
            JOIN pg_attribute a ON a.attrelid = c.conrelid AND a.attnum = ANY(c.conkey)
            WHERE c.conrelid = $1::regclass
            AND c.contype = 'u'
            ORDER BY c.conname, a.attnum;
            """,
            table_name,
        )
        unique_constraints: Dict[str, List[str]] = {}
        for row in unique_rows:
            constraint_name = row["conname"]
            if constraint_name not in unique_constraints:
                unique_constraints[constraint_name] = []
            unique_constraints[constraint_name].append(row["attname"])

        return {
            "columns": columns,
            "pk_columns": pk_columns,
            "unique_constraints": unique_constraints,
        }
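
# Illustrative return value for a hypothetical table with a serial "id" primary key
# and a unique constraint named "t_key" on (a, b):
#
# {
#     "columns": [{"name": "id", "type": "bigint"}, {"name": "a", "type": "text"}, ...],
#     "pk_columns": ["id"],
#     "unique_constraints": {"t_key": ["a", "b"]},
# }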


def clean_value(value: Any) -> Any:
    """Basic cleaning for values (strip NULs). Use more specific coercion in per-column handling."""
    if value is None:
        return None
    if isinstance(value, str):
        return value.replace("\x00", "")
    return value


def parse_datetime_string(s: str) -> Optional[datetime]:
    """Parse a datetime string into a timezone-aware datetime if possible."""
    if s is None:
        return None
    if isinstance(s, datetime):
        return s
    if not isinstance(s, str):
        return None
    try:
        # Normalize excessive fractional seconds (more than 6 digits) by truncating to 6
        s_norm = re.sub(r"\.(\d{6})\d+", r".\1", s)
        # Try fromisoformat (handles 'YYYY-MM-DD HH:MM:SS.mmm+00:00' and ISO variants)
        dt = datetime.fromisoformat(s_norm)
        # If naive, assume UTC
        if dt.tzinfo is None:
            from datetime import timezone

            dt = dt.replace(tzinfo=timezone.utc)
        return dt
    except Exception:
        # Try common SQLite format
        for fmt in (
            "%Y-%m-%d %H:%M:%S.%f%z",
            "%Y-%m-%d %H:%M:%S.%f",
            "%Y-%m-%d %H:%M:%S%z",
            "%Y-%m-%d %H:%M:%S",
        ):
            try:
                s_norm = re.sub(r"\.(\d{6})\d+", r".\1", s)
                dt = datetime.strptime(s_norm, fmt)
                if dt.tzinfo is None:
                    from datetime import timezone

                    dt = dt.replace(tzinfo=timezone.utc)
                return dt
            except Exception:
                continue
        return None
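
# Illustrative behaviour (not executed):
#   parse_datetime_string("2025-01-20 03:00:00.123456789")
#       -> 2025-01-20 03:00:00.123456+00:00 (fractional seconds truncated to 6 digits,
#          naive values assumed UTC)
#   parse_datetime_string("2025-01-20T03:00:00+02:00") -> keeps the +02:00 offset
#   parse_datetime_string("garbage") -> None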


def coerce_value_for_pg(value: Any, pg_type: str) -> Any:
    """Coerce a SQLite value into an appropriate Python type for asyncpg based on PostgreSQL data type."""
    if value is None:
        return None

    pg_type = (pg_type or "").lower()

    # Boolean
    if "boolean" in pg_type:
        # SQLite often stores booleans as 0/1
        if isinstance(value, (int, float)):
            return bool(value)
        if isinstance(value, str):
            v = value.strip().lower()
            if v in ("1", "t", "true", "yes"):
                return True
            if v in ("0", "f", "false", "no"):
                return False
        if isinstance(value, bool):
            return value
        # Fallback
        return bool(value)

    # Timestamps
    if "timestamp" in pg_type or "date" in pg_type or "time" in pg_type:
        if isinstance(value, datetime):
            return value
        if isinstance(value, str):
            dt = parse_datetime_string(value)
            if dt:
                return dt
        # If SQLite stores as numeric (unix epoch)
        if isinstance(value, (int, float)):
            try:
                return datetime.fromtimestamp(value)
            except Exception:
                pass
        return value

    # JSON
    if "json" in pg_type:
        if isinstance(value, str):
            try:
                return json.loads(value)
            except Exception:
                return value
        return value

    # Numeric integer types
    if pg_type in ("integer", "bigint", "smallint") or "int" in pg_type:
        if isinstance(value, int):
            return value
        if isinstance(value, float):
            return int(value)
        if isinstance(value, str) and value.isdigit():
            return int(value)
        try:
            return int(value)
        except Exception:
            return value

    # Floating point
    if (
        pg_type in ("real", "double precision", "numeric", "decimal")
        or "numeric" in pg_type
    ):
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, str):
            try:
                return float(value)
            except Exception:
                return value

    # Text-like types: coerce to string if not None
    if any(k in pg_type for k in ("char", "text", "uuid")):
        try:
            return str(value)
        except Exception:
            return clean_value(value)

    # Default: basic cleaning
    return clean_value(value)
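
# A few illustrative coercions (not executed):
#   coerce_value_for_pg(1, "boolean")                       -> True
#   coerce_value_for_pg("2025-01-20 03:00:00", "timestamp") -> tz-aware datetime (UTC assumed)
#   coerce_value_for_pg("42", "bigint")                     -> 42
#   coerce_value_for_pg('{"a": 1}', "jsonb")                -> {"a": 1}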


# Diagnostics directory for failing chunks
DIAGNOSTICS_DIR = Path(__file__).parent / "diagnostics"
DIAGNOSTICS_DIR.mkdir(exist_ok=True)


async def save_failed_rows(
    dump_basename: str, table_name: str, offset: int, rows: List[tuple], error: str
) -> str:
    """Save failing rows as a gzipped JSONL file for inspection.

    Returns path to diagnostics file.
    """
    fname = f"{dump_basename}.{table_name}.offset{offset}.failed.jsonl.gz"
    path = DIAGNOSTICS_DIR / fname
    try:
        import gzip as _gzip

        with _gzip.open(path, "wt", encoding="utf-8") as fh:
            for r in rows:
                record = {"row": r, "error": error}
                fh.write(json.dumps(record, default=str) + "\n")
        return str(path)
    except Exception as e:
        print(f"Failed to write diagnostics file: {e}")
        return ""


async def attempt_insert_chunk(
    conn,
    upsert_sql: str,
    cleaned_rows: List[tuple],
    table_name: str,
    offset: int,
    dump_basename: str,
    batch_size: int = BATCH_SIZE,
    per_row_limit: int = 100,
) -> tuple[int, int, Optional[str]]:
    """Try inserting a chunk; on failure, split and try smaller batches then per-row.

    Returns (successful_rows, errors, diagnostics_path_or_None).
    """
    success = 0
    errors = 0
    diagnostics_path = None

    # Fast path: whole chunk
    try:
        await conn.executemany(upsert_sql, cleaned_rows)
        return len(cleaned_rows), 0, None
    except Exception as e:
        # Fall through to batch splitting
        first_exc = str(e)

    # Try batch splitting
    for i in range(0, len(cleaned_rows), batch_size):
        batch = cleaned_rows[i : i + batch_size]
        try:
            await conn.executemany(upsert_sql, batch)
            success += len(batch)
            continue
        except Exception:
            # Try per-row for this batch
            failing = []
            for j, row in enumerate(batch):
                try:
                    await conn.execute(upsert_sql, *row)
                    success += 1
                except Exception as re_exc:
                    errors += 1
                    failing.append((j + i, row, str(re_exc)))
                    # If too many per-row failures, stop collecting
                    if errors >= per_row_limit:
                        break
            # If failures were recorded, save diagnostics
            if failing:
                # Prepare rows for saving (row index and values)
                rows_to_save = [r[1] for r in failing]
                diagnostics_path = await save_failed_rows(
                    dump_basename, table_name, offset + i, rows_to_save, first_exc
                )
            # Continue to next batch

    return success, errors, diagnostics_path
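
# Failure-handling ladder used above: try the whole chunk with executemany(); if that
# raises, retry in batch_size slices; if a slice still fails, fall back to per-row
# execute() so only the genuinely bad rows are counted as errors and dumped to the
# diagnostics file (capped by per_row_limit).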


async def _quote_ident(name: str) -> str:
    """Return a safely quoted SQL identifier."""
    return '"' + name.replace('"', '""') + '"'


async def fallback_upsert_using_temp(
    conn,
    table_name: str,
    common_columns: List[str],
    pk_columns: List[str],
    rows_iterable: List[tuple],
    dump_basename: str,
) -> tuple[int, int, Optional[str]]:
    """Fallback upsert that uses a temporary table and performs UPDATE then INSERT.

    Returns (rows_upserted_count, errors_count, diagnostics_path_or_None).
    """
    errors = 0
    diagnostics = None
    inserted = 0
    updated = 0

    # Build identifiers
    tmp_name = f"tmp_upsert_{table_name}_{os.getpid()}"  # per-process unique-ish
    q_table = await _quote_ident(table_name)
    q_tmp = await _quote_ident(tmp_name)
    col_names_str = ", ".join([f'"{c}"' for c in common_columns])

    # Create temp table like target
    try:
        await conn.execute(f"CREATE TEMP TABLE {q_tmp} (LIKE {q_table} INCLUDING ALL)")
    except Exception as e:
        return (
            0,
            1,
            await save_failed_rows(
                dump_basename, table_name, 0, rows_iterable[:100], str(e)
            ),
        )

    # Insert into temp (in batches)
    try:
        # Insert all rows into temp
        insert_sql = f"INSERT INTO {q_tmp} ({col_names_str}) VALUES ({', '.join([f'${i + 1}' for i in range(len(common_columns))])})"
        for i in range(0, len(rows_iterable), BATCH_SIZE):
            batch = rows_iterable[i : i + BATCH_SIZE]
            try:
                await conn.executemany(insert_sql, batch)
            except Exception as e:
                # Save diagnostics and abort
                diagnostics = await save_failed_rows(
                    dump_basename,
                    table_name,
                    i,
                    [tuple(r) for r in batch[:100]],
                    str(e),
                )
                errors += len(batch)
                return 0, errors, diagnostics

        # Count tmp rows
        tmp_count = await conn.fetchval(f"SELECT COUNT(*) FROM {q_tmp}")

        # Count matches in target (existing rows)
        pk_cond = " AND ".join(
            ["t." + f'"{k}"' + " = tmp." + f'"{k}"' for k in pk_columns]
        )
        match_count = await conn.fetchval(
            f"SELECT COUNT(*) FROM {q_tmp} tmp JOIN {q_table} t ON {pk_cond}"
        )

        # Perform update of matching rows (only non-PK columns)
        update_cols = [c for c in common_columns if c not in pk_columns]
        if update_cols:
            # Alias the target as "t" so pk_cond's t."..." references resolve;
            # SET columns must not be table-qualified in PostgreSQL.
            set_clauses = ", ".join(
                [f'"{c}" = tmp."{c}"' for c in update_cols]
            )
            await conn.execute(
                f"UPDATE {q_table} t SET {set_clauses} FROM {q_tmp} tmp WHERE {pk_cond}"
            )
            updated = match_count

        # Insert non-existing rows
        insert_sql2 = f"INSERT INTO {q_table} ({col_names_str}) SELECT {col_names_str} FROM {q_tmp} tmp WHERE NOT EXISTS (SELECT 1 FROM {q_table} t WHERE {pk_cond})"
        await conn.execute(insert_sql2)

        # Approximate inserted count
        inserted = tmp_count - match_count

    except Exception as e:
        diagnostics = await save_failed_rows(
            dump_basename, table_name, 0, rows_iterable[:100], str(e)
        )
        errors += 1
    finally:
        # Drop temp table
        try:
            await conn.execute(f"DROP TABLE IF EXISTS {q_tmp}")
        except Exception:
            pass

    return inserted + updated, errors, diagnostics
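
# Shape of the SQL issued by the temp-table fallback (illustrative, single "id" PK):
#   CREATE TEMP TABLE "tmp_upsert_t_1234" (LIKE "t" INCLUDING ALL);
#   -- bulk INSERT the chunk into the temp table
#   UPDATE "t" t SET "col_a" = tmp."col_a", ... FROM "tmp_upsert_t_1234" tmp
#       WHERE t."id" = tmp."id";
#   INSERT INTO "t" (...) SELECT ... FROM "tmp_upsert_t_1234" tmp
#       WHERE NOT EXISTS (SELECT 1 FROM "t" t WHERE t."id" = tmp."id");
#   DROP TABLE IF EXISTS "tmp_upsert_t_1234";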


async def validate_sample_rows(
    sqlite_db: aiosqlite.Connection,
    pg_col_types: List[str],
    col_names_str: str,
    table_name: str,
    dump_basename: str,
    sample_n: int = 1000,
) -> tuple[int, int, Optional[str]]:
    """Validate a sample of rows by attempting to coerce values and detecting likely mismatches.

    Returns (fail_count, total_checked, diagnostics_path_or_None).
    """
    fail_count = 0
    total = 0
    failing_rows = []

    async with sqlite_db.execute(
        f"SELECT {col_names_str} FROM {table_name} LIMIT {sample_n};"
    ) as cursor:
        rows = await cursor.fetchall()

    for i, row in enumerate(rows):
        total += 1
        # Coerce and check types
        for val, pg_type in zip(row, pg_col_types):
            coerced = coerce_value_for_pg(val, pg_type)
            # Heuristics for likely failure
            pg_type_l = (pg_type or "").lower()
            if "timestamp" in pg_type_l or "date" in pg_type_l or "time" in pg_type_l:
                if coerced is None or not isinstance(coerced, datetime):
                    fail_count += 1
                    failing_rows.append(row)
                    break
            elif "boolean" in pg_type_l:
                if not isinstance(coerced, bool):
                    fail_count += 1
                    failing_rows.append(row)
                    break
            elif any(k in pg_type_l for k in ("int", "bigint", "smallint")):
                if not isinstance(coerced, int):
                    # allow numeric strings that can be int
                    if not (isinstance(coerced, str) and coerced.isdigit()):
                        fail_count += 1
                        failing_rows.append(row)
                        break
            # else: assume text/json ok

    diag = None
    if failing_rows:
        rows_to_save = [tuple(r) for r in failing_rows[:100]]
        diag = await save_failed_rows(
            dump_basename,
            f"{table_name}.sample",
            0,
            rows_to_save,
            "Validation failures",
        )
    return fail_count, total, diag


def determine_conflict_columns(
    pg_schema: Dict[str, Any],
    sqlite_schema: List[Dict[str, Any]],
    table_name: str,
) -> List[str]:
    """Determine which columns to use for conflict detection.

    Priority:
    1. PostgreSQL primary key
    2. PostgreSQL unique constraint
    3. SQLite primary key (only if PostgreSQL has a matching unique/PK constraint)
    """
    # 1) PostgreSQL primary key
    if pg_schema["pk_columns"]:
        return pg_schema["pk_columns"]

    # 2) PostgreSQL unique constraint
    if pg_schema["unique_constraints"]:
        first_constraint = list(pg_schema["unique_constraints"].values())[0]
        return first_constraint

    # 3) Consider SQLite primary key only if PostgreSQL has a matching constraint
    # (columns ordered by their position in the SQLite primary key)
    sqlite_pk_columns = [
        col["name"]
        for col in sorted(
            [c for c in sqlite_schema if c["pk"] > 0], key=lambda x: x["pk"]
        )
    ]

    if sqlite_pk_columns:
        # Check if any PG unique constraint matches these columns (order-insensitive)
        for cols in pg_schema["unique_constraints"].values():
            if set(cols) == set(sqlite_pk_columns):
                return sqlite_pk_columns
        # Also accept if PG has a primary key covering same columns (already handled), so if we reach here
        # PG has no matching constraint — cannot safely perform ON CONFLICT based upsert
        raise ValueError(
            f"Table {table_name} has a SQLite primary key ({', '.join(sqlite_pk_columns)}) but no matching unique/PK constraint in PostgreSQL. Cannot perform upsert."
        )

    # Nothing found
    raise ValueError(
        f"Table {table_name} has no primary key or unique constraint in either PostgreSQL or SQLite. Cannot perform upsert."
    )
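
# Example of the priority order (illustrative): if PostgreSQL reports
# pk_columns == [] and unique_constraints == {"t_key": ["a", "b"]}, the upsert
# conflicts on (a, b); a SQLite PK of (a, b) is only used when PostgreSQL has a
# matching unique/PK constraint, otherwise a ValueError is raised and upsert_table()
# falls back to the existence-based or temp-table path.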


def build_upsert_sql(
    table_name: str,
    columns: List[str],
    conflict_columns: List[str],
    update_columns: Optional[List[str]] = None,
) -> str:
    """Build PostgreSQL upsert SQL with $1, $2, etc. placeholders."""
    col_names_str = ", ".join([f'"{col}"' for col in columns])
    placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
    conflict_str = ", ".join([f'"{col}"' for col in conflict_columns])

    if update_columns is None:
        update_columns = [c for c in columns if c not in conflict_columns]

    if not update_columns:
        return f"INSERT INTO {table_name} ({col_names_str}) VALUES ({placeholders}) ON CONFLICT ({conflict_str}) DO NOTHING"

    set_clauses = [f'"{col}" = EXCLUDED."{col}"' for col in update_columns]
    set_str = ", ".join(set_clauses)

    return f"INSERT INTO {table_name} ({col_names_str}) VALUES ({placeholders}) ON CONFLICT ({conflict_str}) DO UPDATE SET {set_str}"


async def upsert_table(
    sqlite_db: aiosqlite.Connection,
    pg_pool: asyncpg.Pool,
    table_name: str,
    dry_run: bool = False,
    verbose: bool = False,
    dump_basename: Optional[str] = None,
) -> TableReport:
    """Upsert data from SQLite table to PostgreSQL table."""
    import time

    start_time = time.time()
    report = TableReport(table_name=table_name)

    # Get schemas
    sqlite_schema = await get_sqlite_schema(sqlite_db, table_name)
    pg_schema = await get_pg_schema(pg_pool, table_name)

    if not pg_schema["columns"]:
        report.status = "skipped"
        report.error_message = "Table does not exist in PostgreSQL"
        print(f" ⚠ Table {table_name} does not exist in PostgreSQL, skipping...")
        return report

    # Match columns
    sqlite_columns = [col["name"] for col in sqlite_schema]
    pg_columns = [col["name"] for col in pg_schema["columns"]]
    common_columns = [col for col in sqlite_columns if col in pg_columns]

    if not common_columns:
        report.status = "skipped"
        report.error_message = "No common columns"
        print(f" ⚠ No common columns for table {table_name}")
        return report

    if verbose:
        print(f" Common columns: {', '.join(common_columns)}")
        print(f" PG Primary keys: {pg_schema['pk_columns']}")

    # Determine conflict columns (tries PG first, then SQLite)
    use_temp_fallback = False
    use_exists_upsert = False
    sqlite_pk_columns = []
    pk_column = None

    try:
        conflict_columns = determine_conflict_columns(
            pg_schema, sqlite_schema, table_name
        )
    except ValueError as e:
        sqlite_pk_columns = [col["name"] for col in sqlite_schema if col["pk"] > 0]
        if sqlite_pk_columns:
            # Prefer existence-based upsert for single-column PKs; fallback to temp-table for composite PKs
            if len(sqlite_pk_columns) == 1:
                print(
                    f" ⚠ Table {table_name} has a SQLite primary key ({sqlite_pk_columns[0]}) but no matching unique/PK constraint in PostgreSQL. Will use existence-based update/insert fallback."
                )
                use_exists_upsert = True
                pk_column = sqlite_pk_columns[0]
                conflict_columns = [pk_column]
            else:
                print(
                    f" ⚠ Table {table_name} has a SQLite primary key ({', '.join(sqlite_pk_columns)}) but no matching unique/PK constraint in PostgreSQL. Will attempt safer temp-table fallback."
                )
                use_temp_fallback = True
                conflict_columns = sqlite_pk_columns
        else:
            report.status = "skipped"
            report.error_message = str(e)
            print(f" ⚠ {e}")
            return report

    if verbose:
        print(f" Using conflict columns: {conflict_columns}")

    # Build upsert SQL for standard ON CONFLICT approach
    upsert_sql = build_upsert_sql(table_name, common_columns, conflict_columns)

    # Get counts
    async with sqlite_db.execute(f"SELECT COUNT(*) FROM {table_name};") as cursor:
        row = await cursor.fetchone()
        report.sqlite_rows = row[0] if row else 0

    async with pg_pool.acquire() as conn:
        report.pg_rows_before = await conn.fetchval(
            f"SELECT COUNT(*) FROM {table_name};"
        )

    print(
        f" SQLite rows: {report.sqlite_rows:,}, PG rows before: {report.pg_rows_before:,}"
    )

    if report.sqlite_rows == 0:
        report.status = "success"
        report.pg_rows_after = report.pg_rows_before
        print(" Table is empty, nothing to upsert.")
        return report

    # Prepare column query
    col_names_str = ", ".join([f'"{col}"' for col in common_columns])

    # Prepare per-column PG types for validation and coercion
    pg_col_type_map = {c["name"]: c["type"] for c in pg_schema["columns"]}
    pg_col_types = [pg_col_type_map.get(col, "text") for col in common_columns]

    # Local chunk size (adaptive)
    chunk_size = CHUNK_SIZE

    # Pre-validate a sample of rows to catch common type mismatches early and adapt
    sample_n = min(1000, chunk_size, report.sqlite_rows)
    if sample_n > 0:
        fail_count, total_checked, sample_diag = await validate_sample_rows(
            sqlite_db,
            pg_col_types,
            col_names_str,
            table_name,
            dump_basename or "unknown_dump",
            sample_n,
        )
        if fail_count:
            pct = (fail_count / total_checked) * 100
            print(
                f" ⚠ Validation: {fail_count}/{total_checked} sample rows ({pct:.1f}%) appear problematic; reducing chunk size and writing diagnostics"
            )
            if sample_diag:
                print(f" ⚠ Sample diagnostics: {sample_diag}")
            # Reduce chunk size to be safer
            old_chunk = chunk_size
            chunk_size = max(1000, chunk_size // 4)
            print(f" ⚠ Adjusted chunk_size: {old_chunk} -> {chunk_size}")

    # If we determined earlier that we need to use the temp-table fallback, handle that path now
    if use_temp_fallback:
        print(
            " Using temp-table fallback upsert (may be slower but safe on missing PG constraints)"
        )

        # Process in chunks using cursor iteration but push to temp upsert helper
        offset = 0
        async with pg_pool.acquire() as conn:
            while offset < report.sqlite_rows:
                async with sqlite_db.execute(
                    f"SELECT {col_names_str} FROM {table_name} LIMIT {chunk_size} OFFSET {offset};"
                ) as cursor:
                    rows = await cursor.fetchall()

                if not rows:
                    break

                # Coerce
                cleaned_rows = []
                for row in rows:
                    coerced = []
                    for val, pg_type in zip(row, pg_col_types):
                        coerced.append(coerce_value_for_pg(val, pg_type))
                    cleaned_rows.append(tuple(coerced))

                # Attempt fallback upsert
                try:
                    async with conn.transaction():
                        upsert_count, errs, diag = await fallback_upsert_using_temp(
                            conn,
                            table_name,
                            common_columns,
                            sqlite_pk_columns,
                            cleaned_rows,
                            dump_basename or "unknown_dump",
                        )
                except Exception as e:
                    upsert_count, errs, diag = (
                        0,
                        1,
                        await save_failed_rows(
                            dump_basename or "unknown_dump",
                            table_name,
                            offset,
                            cleaned_rows[:100],
                            str(e),
                        ),
                    )

                report.rows_affected += upsert_count
                report.errors += errs
                if diag:
                    print(f"\n ⚠ Diagnostics written: {diag}")

                if report.errors > 10:
                    print(" ✗ Too many errors, aborting table")
                    report.status = "failed"
                    report.error_message = f"Too many errors ({report.errors})"
                    break

                offset += chunk_size

        # Set final counts
        async with pg_pool.acquire() as conn:
            report.pg_rows_after = await conn.fetchval(
                f"SELECT COUNT(*) FROM {table_name};"
            )

        report.duration_seconds = time.time() - start_time

        if report.errors == 0:
            report.status = "success"
        elif report.errors <= 10:
            report.status = "success"
        else:
            report.status = "failed"

        return report

    if use_exists_upsert:
        print(
            " Using existence-based upsert (INSERT for new rows, UPDATE for existing rows)"
        )

        # Process in chunks using cursor iteration
        offset = 0
        async with pg_pool.acquire() as conn:
            while offset < report.sqlite_rows:
                async with sqlite_db.execute(
                    f"SELECT {col_names_str} FROM {table_name} LIMIT {chunk_size} OFFSET {offset};"
                ) as cursor:
                    rows = await cursor.fetchall()

                if not rows:
                    break

                # Coerce
                cleaned_rows = []
                for row in rows:
                    coerced = []
                    for val, pg_type in zip(row, pg_col_types):
                        coerced.append(coerce_value_for_pg(val, pg_type))
                    cleaned_rows.append(tuple(coerced))

                # Build list of PK values
                pk_index = common_columns.index(pk_column)
                pk_values = [r[pk_index] for r in cleaned_rows]

                # Query existing PKs
                existing = set()
                try:
                    rows_found = await conn.fetch(
                        f"SELECT {pk_column} FROM {table_name} WHERE {pk_column} = ANY($1)",
                        pk_values,
                    )
                    existing = set([r[pk_column] for r in rows_found])
                except Exception as e:
                    # On error, save diagnostics and abort
                    diag = await save_failed_rows(
                        dump_basename or "unknown_dump",
                        table_name,
                        offset,
                        cleaned_rows[:100],
                        str(e),
                    )
                    print(f"\n ⚠ Diagnostics written: {diag}")
                    report.errors += 1
                    break

                insert_rows = [r for r in cleaned_rows if r[pk_index] not in existing]
                update_rows = [r for r in cleaned_rows if r[pk_index] in existing]

                # Perform inserts
                if insert_rows:
                    col_names_csv = ", ".join([f'"{c}"' for c in common_columns])
                    placeholders = ", ".join(
                        [f"${i + 1}" for i in range(len(common_columns))]
                    )
                    insert_sql = f"INSERT INTO {table_name} ({col_names_csv}) VALUES ({placeholders})"
                    try:
                        await conn.executemany(insert_sql, insert_rows)
                        report.rows_affected += len(insert_rows)
                    except Exception as e:
                        diag = await save_failed_rows(
                            dump_basename or "unknown_dump",
                            table_name,
                            offset,
                            insert_rows[:100],
                            str(e),
                        )
                        print(f"\n ⚠ Diagnostics written: {diag}")
                        report.errors += 1

                # Perform updates (per-row executemany)
                if update_rows:
                    update_cols = [c for c in common_columns if c != pk_column]
                    set_clause = ", ".join(
                        [f'"{c}" = ${i + 1}' for i, c in enumerate(update_cols)]
                    )
                    # For executemany we'll pass values ordered as [col1, col2, ..., pk]
                    update_sql = f'UPDATE {table_name} SET {set_clause} WHERE "{pk_column}" = ${len(update_cols) + 1}'
                    update_values = []
                    for r in update_rows:
                        vals = [r[common_columns.index(c)] for c in update_cols]
                        vals.append(r[pk_index])
                        update_values.append(tuple(vals))
                    try:
                        await conn.executemany(update_sql, update_values)
                        report.rows_affected += len(update_rows)
                    except Exception as e:
                        diag = await save_failed_rows(
                            dump_basename or "unknown_dump",
                            table_name,
                            offset,
                            update_rows[:100],
                            str(e),
                        )
                        print(f"\n ⚠ Diagnostics written: {diag}")
                        report.errors += 1

                if report.errors > 10:
                    print(" ✗ Too many errors, aborting table")
                    report.status = "failed"
                    report.error_message = f"Too many errors ({report.errors})"
                    break

                offset += chunk_size

        # Set final counts
        async with pg_pool.acquire() as conn:
            report.pg_rows_after = await conn.fetchval(
                f"SELECT COUNT(*) FROM {table_name};"
            )

        report.duration_seconds = time.time() - start_time

        if report.errors == 0:
            report.status = "success"
        elif report.errors <= 10:
            report.status = "success"
        else:
            report.status = "failed"

        return report
|
|||
|
|
|
|||
|
|
    # Process in chunks using cursor iteration
    offset = 0
    async with pg_pool.acquire() as conn:
        while offset < report.sqlite_rows:
            try:
                # Fetch chunk from SQLite
                async with sqlite_db.execute(
                    f"SELECT {col_names_str} FROM {table_name} LIMIT {chunk_size} OFFSET {offset};"
                ) as cursor:
                    rows = await cursor.fetchall()

                if not rows:
                    break

                # Coerce data per-column based on PG types
                pg_col_types = []
                pg_col_type_map = {c["name"]: c["type"] for c in pg_schema["columns"]}
                for col in common_columns:
                    pg_col_types.append(pg_col_type_map.get(col, "text"))

                cleaned_rows = []
                for row in rows:
                    coerced = []
                    for val, pg_type in zip(row, pg_col_types):
                        coerced.append(coerce_value_for_pg(val, pg_type))
                    cleaned_rows.append(tuple(coerced))

                if not dry_run:
                    # Try inserting chunk with adaptive recovery
                    dump_basename_local = dump_basename or "unknown_dump"
                    success_count, err_count, diag = await attempt_insert_chunk(
                        conn,
                        upsert_sql,
                        cleaned_rows,
                        table_name,
                        offset,
                        dump_basename_local,
                    )

                    report.rows_affected += success_count
                    report.errors += err_count

                    if diag:
                        print(f"\n ⚠ Diagnostics written: {diag}")
                        # Record in report for visibility
                        report.error_message = (
                            (report.error_message + "; ")
                            if report.error_message
                            else ""
                        ) + f"Diagnostics: {diag}"

                    if err_count > 0:
                        print(
                            f"\n ⚠ Error at offset {offset}: {err_count} problematic rows (see diagnostics)"
                        )

                    if report.errors > 10:
                        print(" ✗ Too many errors, aborting table")
                        report.status = "failed"
                        report.error_message = f"Too many errors ({report.errors})" + (
                            f"; {report.error_message}" if report.error_message else ""
                        )
                        break

                # Advance by the same chunk_size used in the SELECT above
                offset += chunk_size

                progress = min(offset, report.sqlite_rows)
                pct = (progress / report.sqlite_rows) * 100
                print(
                    f"  Progress: {progress:,}/{report.sqlite_rows:,} ({pct:.1f}%)",
                    end="\r",
                )

            except Exception as e:
                print(f"\n ⚠ Error at offset {offset}: {e}")
                report.errors += 1
                offset += chunk_size

                if report.errors > 10:
                    print(" ✗ Too many errors, aborting table")
                    report.status = "failed"
                    report.error_message = f"Too many errors ({report.errors})"
                    break

    print()  # Newline after progress

    # Get final count
    async with pg_pool.acquire() as conn:
        report.pg_rows_after = await conn.fetchval(
            f"SELECT COUNT(*) FROM {table_name};"
        )

    report.duration_seconds = time.time() - start_time

    # Up to 10 chunk-level errors are tolerated and the table still counts as success
    report.status = "success" if report.errors <= 10 else "failed"

    return report

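
# Illustrative sketch (added for clarity, never called): the chunked path above hands
# a prebuilt `upsert_sql` plus coerced rows to attempt_insert_chunk(). The helper below
# shows the general shape such an INSERT ... ON CONFLICT statement can take; the exact
# statement built elsewhere in this script may differ, and the column/constraint names
# passed in are whatever the caller supplies.
def _example_build_upsert_sql(
    table_name: str, columns: List[str], pk_column: str
) -> str:
    """Build an example INSERT ... ON CONFLICT DO UPDATE statement (sketch only)."""
    cols_csv = ", ".join(f'"{c}"' for c in columns)
    placeholders = ", ".join(f"${i + 1}" for i in range(len(columns)))
    updates = ", ".join(f'"{c}" = EXCLUDED."{c}"' for c in columns if c != pk_column)
    return (
        f"INSERT INTO {table_name} ({cols_csv}) VALUES ({placeholders}) "
        f'ON CONFLICT ("{pk_column}") DO UPDATE SET {updates}'
    )
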
async def upsert_database(
    sqlite_path: str,
    tables: Optional[List[str]] = None,
    dry_run: bool = False,
    verbose: bool = False,
    dump_date: Optional[str] = None,
) -> UpsertReport:
    """Main upsert function - process SQLite dump into PostgreSQL."""
    report = UpsertReport(sqlite_source=sqlite_path, dump_date=dump_date)
    report.start_time = datetime.now()

    print(f"{'[DRY RUN] ' if dry_run else ''}SQLite → PostgreSQL Upsert (Async)")
    print(f"SQLite: {sqlite_path}")
    print(f"PostgreSQL: {PG_CONFIG['database']}@{PG_CONFIG['host']}")
    print(f"Chunk size: {CHUNK_SIZE:,} rows\n")

    # Connect to databases
    print("Connecting to databases...")
    sqlite_db = await aiosqlite.connect(sqlite_path)
    pg_pool = await asyncpg.create_pool(
        host=PG_CONFIG["host"],
        port=PG_CONFIG["port"],
        database=PG_CONFIG["database"],
        user=PG_CONFIG["user"],
        password=PG_CONFIG["password"],
        min_size=2,
        max_size=10,
    )

    try:
        # Get tables
        available_tables = await get_sqlite_tables(sqlite_db)

        if tables:
            process_tables = [t for t in tables if t in available_tables]
            missing = [t for t in tables if t not in available_tables]
            if missing:
                print(f"⚠ Tables not found in SQLite: {', '.join(missing)}")
        else:
            process_tables = available_tables

        print(f"Tables to process: {', '.join(process_tables)}\n")

        for i, table_name in enumerate(process_tables, 1):
            print(f"[{i}/{len(process_tables)}] Processing table: {table_name}")

            try:
                dump_basename = Path(sqlite_path).name
                table_report = await upsert_table(
                    sqlite_db, pg_pool, table_name, dry_run, verbose, dump_basename
                )
                report.tables.append(table_report)

                if table_report.status == "success":
                    print(
                        f"  ✓ Completed: +{table_report.rows_inserted:,} inserted, ~{table_report.rows_updated:,} updated"
                    )
                elif table_report.status == "skipped":
                    print("  ○ Skipped")
                else:
                    print("  ✗ Failed")

            except Exception as e:
                table_report = TableReport(
                    table_name=table_name, status="failed", error_message=str(e)
                )
                report.tables.append(table_report)
                print(f"  ✗ Failed: {e}")

            print()

    finally:
        await sqlite_db.close()
        await pg_pool.close()

    report.end_time = datetime.now()
    return report

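
# Example usage (illustrative, never invoked by this script): a verbose dry-run upsert
# of an already-extracted dump. The path below is a hypothetical placeholder.
async def _example_dry_run_upsert() -> UpsertReport:
    report = await upsert_database(
        sqlite_path="/tmp/lrclib-db-dump.sqlite3",  # hypothetical path
        tables=None,  # or a subset of SQLite table names, if known
        dry_run=True,
        verbose=True,
    )
    for t in report.tables:
        print(t.table_name, t.status, t.rows_inserted, t.rows_updated)
    return report
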
async def test_pg_connection() -> bool:
    """Test PostgreSQL connection before proceeding with download."""
    print("Testing PostgreSQL connection...")
    try:
        conn = await asyncpg.connect(
            host=PG_CONFIG["host"],
            port=PG_CONFIG["port"],
            database=PG_CONFIG["database"],
            user=PG_CONFIG["user"],
            password=PG_CONFIG["password"],
        )
        version = await conn.fetchval("SELECT version();")
        await conn.close()
        print(f"  ✓ Connected to PostgreSQL: {version[:50]}...")
        return True
    except Exception as e:
        print(f"  ✗ PostgreSQL connection failed: {e}")
        return False

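
# Convenience wrapper (example only, never called): run the connection test
# synchronously, e.g. from a REPL, without going through check_and_fetch_latest().
def _example_run_connection_test() -> bool:
    """Return True if PostgreSQL is reachable with the configured credentials."""
    return asyncio.run(test_pg_connection())
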
async def check_and_fetch_latest(
    force: bool = False,
    dest_dir: Optional[str] = None,
    dry_run: bool = False,
    verbose: bool = False,
    notify: bool = False,
) -> Optional[UpsertReport]:
    """Check for new dumps on lrclib.net and upsert if newer than last run."""

    # Send start notification
    if notify:
        await notify_start()

    # Test PostgreSQL connection FIRST before downloading large files
    if not await test_pg_connection():
        error_msg = "Cannot connect to PostgreSQL database"
        print(f"Aborting: {error_msg}")
        if notify:
            await notify_failure(
                error_msg,
                "Connection Test",
                f"Host: {PG_CONFIG['host']}:{PG_CONFIG['port']}",
            )
        return None

    # Get latest dump info
    dump_info = await fetch_latest_dump_info()
    if not dump_info:
        error_msg = "Could not fetch latest dump info from lrclib.net"
        print(error_msg)
        if notify:
            await notify_failure(
                error_msg, "Fetch Dump Info", f"URL: {LRCLIB_DUMPS_URL}"
            )
        return None

    dump_date = dump_info["date"]
    dump_date_str = dump_date.strftime("%Y-%m-%d %H:%M:%S")

    # Check against last upsert
    state = load_state()
    last_upsert_str = state.get("last_dump_date")  # Full datetime string

    if last_upsert_str:
        try:
            # Try parsing with time first
            last_date = datetime.strptime(last_upsert_str, "%Y-%m-%d %H:%M:%S")
        except ValueError:
            # Fall back to date only
            last_date = datetime.strptime(last_upsert_str, "%Y-%m-%d")

        print(f"Last upsert dump date: {last_upsert_str}")

        if dump_date <= last_date and not force:
            print(
                f"No new dump available (latest: {dump_date_str}, last upsert: {last_upsert_str})"
            )
            print("Use --force to upsert anyway")
            if notify:
                await notify_no_update(dump_date_str, last_upsert_str)
            return None

        print(f"New dump available: {dump_date_str} > {last_upsert_str}")
    else:
        print("No previous upsert recorded")

    if dry_run:
        print(f"[DRY RUN] Would download and upsert {dump_info['filename']}")
        return None

    # Notify about new dump found
    if notify:
        await notify_new_dump_found(dump_info["filename"], dump_date_str)

    # Download and extract
    sqlite_path, download_error = await download_and_extract_dump(
        dump_info["url"], dest_dir
    )
    if not sqlite_path:
        error_msg = download_error or "Unknown download/extract error"
        print(f"Failed: {error_msg}")
        if notify:
            await notify_failure(
                error_msg, "Download/Extract", f"URL: {dump_info['url']}"
            )
        return None

    # Keep files until we confirm import success to allow resume/inspection
    gz_path = sqlite_path + ".gz"  # same path the cleanup code below uses
    print(f"Downloaded and extracted: {sqlite_path}")
    print(f"Keeping source compressed file for resume/inspection: {gz_path}")

    # Define SQLITE_DB_PATH dynamically after extraction (persist globally)
    global SQLITE_DB_PATH
    SQLITE_DB_PATH = sqlite_path  # Set to the extracted SQLite file path
    print(f"SQLite database path set to: {SQLITE_DB_PATH}")

    upsert_successful = False
    try:
        # Perform upsert
        report = await upsert_database(
            sqlite_path=sqlite_path,
            tables=None,
            dry_run=False,
            verbose=verbose,
            dump_date=dump_date_str,
        )

        # If we had no errors and at least one table was considered, determine if any rows were upserted
        if report.total_errors == 0 or len(report.successful_tables) > 0:
            # Determine if any table actually had rows inserted or updated
            any_upserted = any(
                (t.status == "success" and (t.rows_inserted > 0 or t.rows_updated > 0))
                for t in report.tables
            )

            # Mark overall upsert success (only true if rows were actually upserted)
            upsert_successful = any_upserted

            if upsert_successful:
                # Save state only when an actual upsert occurred to avoid skipping future attempts
                state["last_dump_date"] = dump_date_str  # Date from the dump filename
                state["last_upsert_time"] = datetime.now().isoformat()  # When we ran
                state["last_dump_url"] = dump_info["url"]
                state["last_dump_filename"] = dump_info["filename"]
                save_state(state)
                print(f"\nState saved: last_dump_date = {dump_date_str}")

            if notify:
                if any_upserted:
                    await notify_success(report)
                else:
                    # All tables skipped or empty (no actual upsert)
                    await discord_notify(
                        title="⚠️ LRCLib DB Update - No Data Upserted",
                        description="Upsert completed, but no rows were inserted or updated. All tables may have been skipped due to missing primary keys or were empty. The dump files have been retained for inspection/resume.",
                        color=DiscordColor.WARNING,
                        fields=[
                            {
                                "name": "Tables Processed",
                                "value": str(len(report.tables)),
                                "inline": True,
                            },
                            {
                                "name": "Successful Tables",
                                "value": str(len(report.successful_tables)),
                                "inline": True,
                            },
                            {
                                "name": "Rows Inserted",
                                "value": str(report.total_rows_inserted),
                                "inline": True,
                            },
                            {
                                "name": "Rows Updated",
                                "value": str(report.total_rows_updated),
                                "inline": True,
                            },
                        ],
                        footer=f"Source: {report.sqlite_source}",
                    )
        else:
            # All tables failed
            if notify:
                failed_tables = ", ".join(report.failed_tables[:5])
                await notify_failure(
                    "All tables failed to upsert",
                    "Upsert",
                    f"Failed tables: {failed_tables}",
                )

        return report

    except Exception as e:
        error_msg = str(e)
        print(f"Upsert error: {error_msg}")
        if notify:
            import traceback

            tb = traceback.format_exc()
            await notify_failure(error_msg, "Upsert", tb[:1000])
        raise

    finally:
        # Only remove files when a successful upsert occurred. Keep them otherwise for resume/debug.
        if upsert_successful:
            # Clean up downloaded SQLite file
            if sqlite_path and os.path.exists(sqlite_path):
                print(f"Cleaning up SQLite dump (success): {sqlite_path}")
                os.unlink(sqlite_path)

            # Also remove the .gz file
            if sqlite_path:
                gz_path = sqlite_path + ".gz"
                if os.path.exists(gz_path):
                    print(f"Cleaning up compressed file (success): {gz_path}")
                    os.unlink(gz_path)
        else:
            # Keep files for inspection/resume
            if sqlite_path and os.path.exists(sqlite_path):
                print(f"Retaining SQLite dump for inspection/resume: {sqlite_path}")
            if sqlite_path:
                gz_path = sqlite_path + ".gz"
                if os.path.exists(gz_path):
                    print(f"Retaining compressed file for inspection/resume: {gz_path}")

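
# Shape of the JSON state file used above (illustrative values only; the keys are the
# ones check_and_fetch_latest() reads and writes via load_state()/save_state()).
_EXAMPLE_STATE = {
    "last_dump_date": "2024-01-01 00:00:00",  # dump date, parsed with or without time
    "last_upsert_time": "2024-01-01T06:00:00",  # when the upsert ran (isoformat)
    "last_dump_url": "https://lrclib.net/db-dumps/example.sqlite3.gz",  # hypothetical
    "last_dump_filename": "example.sqlite3.gz",  # hypothetical
}
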
async def create_new_schema(pg_conn: asyncpg.Connection, schema_name: str) -> None:
    """Create a new schema for the migration."""
    await pg_conn.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")

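
# Note (added): schema_name is interpolated directly into the statement above. That is
# fine for the fixed "staging" name used below; if the name could ever be user-supplied,
# a quoted identifier would be safer, e.g. (sketch):
#   await pg_conn.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')
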
async def migrate_to_new_schema(
    sqlite_conn: aiosqlite.Connection,
    pg_conn: asyncpg.Connection,
    schema_name: str,
    table_name: str,
    columns: List[str],
) -> None:
    """Migrate data to a new schema."""
    staging_table = f"{schema_name}.{table_name}"

    # Drop and recreate staging table in the new schema
    await pg_conn.execute(f"DROP TABLE IF EXISTS {staging_table}")
    await pg_conn.execute(
        f"CREATE TABLE {staging_table} (LIKE {table_name} INCLUDING ALL)"
    )

    # Import data into the staging table
    sqlite_cursor = await sqlite_conn.execute(
        f"SELECT {', '.join(columns)} FROM {table_name}"
    )
    rows = await sqlite_cursor.fetchall()
    await sqlite_cursor.close()

    # Convert rows to a list to ensure compatibility with len()
    rows = list(rows)
    print(f"Fetched {len(rows)} rows from SQLite table {table_name}")

    # Build the column list separately to avoid nesting double quotes inside the f-string
    col_list = ", ".join(f'"{col}"' for col in columns)
    placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
    insert_sql = f"INSERT INTO {staging_table} ({col_list}) VALUES ({placeholders})"

    await pg_conn.executemany(insert_sql, rows)
    print(f"Imported {len(rows)} rows into {staging_table}")

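
# Memory note (sketch, not used by migrate_to_new_schema above): fetchall() loads the
# entire table into memory. For very large tables a batched copy keeps memory bounded;
# a minimal sketch assuming the caller supplies the same SELECT/INSERT statements:
async def _example_copy_in_batches(
    sqlite_conn: aiosqlite.Connection,
    pg_conn: asyncpg.Connection,
    select_sql: str,
    insert_sql: str,
    batch_size: int = 10_000,
) -> int:
    """Stream rows from SQLite into PostgreSQL in fixed-size batches (example only)."""
    copied = 0
    cursor = await sqlite_conn.execute(select_sql)
    try:
        while True:
            rows = await cursor.fetchmany(batch_size)
            if not rows:
                break
            await pg_conn.executemany(insert_sql, [tuple(r) for r in rows])
            copied += len(rows)
    finally:
        await cursor.close()
    return copied
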
async def swap_schemas(pg_conn: asyncpg.Connection, new_schema: str) -> None:
    """Swap the new schema with the public schema, then drop the old public data."""
    async with pg_conn.transaction():
        # Rename public schema to a backup schema
        await pg_conn.execute("ALTER SCHEMA public RENAME TO backup")
        # Rename new schema to public
        await pg_conn.execute(f"ALTER SCHEMA {new_schema} RENAME TO public")
        # Drop the backup schema
        await pg_conn.execute("DROP SCHEMA backup CASCADE")

    print(f"Swapped schema {new_schema} with public")

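
# Caution (added note): swap_schemas() drops the old public schema in the same
# transaction that promotes the staging schema. A more conservative variant would keep
# the old data until the new schema has been verified, e.g. (sketch):
#   await pg_conn.execute("ALTER SCHEMA public RENAME TO public_backup")
#   await pg_conn.execute(f"ALTER SCHEMA {new_schema} RENAME TO public")
#   # ... verify row counts, then drop public_backup manually ...
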
async def migrate_database(sqlite_path: Optional[str] = None):
    """Main migration function.

    If `sqlite_path` is provided, use it. Otherwise fall back to the global `SQLITE_DB_PATH`.
    """
    path = sqlite_path or SQLITE_DB_PATH
    if not path:
        raise ValueError(
            "No SQLite path provided. Call migrate_database(sqlite_path=...) or set SQLITE_DB_PATH before invoking."
        )

    sqlite_conn = await aiosqlite.connect(path)
    pg_conn = await asyncpg.connect(**PG_CONFIG)

    new_schema = "staging"

    try:
        # Create a new schema for the migration
        await create_new_schema(pg_conn, new_schema)

        sqlite_tables = await sqlite_conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table'"
        )
        tables = [row[0] for row in await sqlite_tables.fetchall()]
        await sqlite_tables.close()

        for table_name in tables:
            sqlite_cursor = await sqlite_conn.execute(
                f"PRAGMA table_info({table_name})"
            )
            columns = [row[1] for row in await sqlite_cursor.fetchall()]
            await sqlite_cursor.close()

            # Migrate data to the new schema
            await migrate_to_new_schema(
                sqlite_conn, pg_conn, new_schema, table_name, columns
            )

        # Swap schemas after successful migration
        await swap_schemas(pg_conn, new_schema)

    finally:
        await sqlite_conn.close()
        await pg_conn.close()

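
# Example invocation (illustrative): run the schema-swap migration directly from Python
# instead of via the --migrate CLI flag; the path below is hypothetical.
#
#   asyncio.run(migrate_database(sqlite_path="/tmp/lrclib-db-dump.sqlite3"))
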
# Ensure SQLITE_DB_PATH is defined globally before usage
SQLITE_DB_PATH = None  # Initialize as None to avoid NameError


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="LRCLib DB upsert / migration utility")
    parser.add_argument(
        "--migrate",
        action="store_true",
        help="Migrate a local SQLite file into a new staging schema and swap",
    )
    parser.add_argument(
        "--sqlite",
        type=str,
        help="Path to local SQLite file to migrate (required for --migrate)",
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Check lrclib.net for latest dump and upsert if newer (this is also the default action)",
    )
    parser.add_argument(
        "--force", action="store_true", help="Force upsert even if dump is not newer"
    )
    parser.add_argument(
        "--dest", type=str, help="Destination directory for downloaded dumps"
    )
    parser.add_argument(
        "--dry-run", action="store_true", help="Do not perform writes; just simulate"
    )
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument(
        "--notify", action="store_true", help="Send Discord notifications"
    )

    args = parser.parse_args()

    if args.migrate:
        if not args.sqlite:
            parser.error("--migrate requires --sqlite PATH")
        print(f"Starting schema migration from local SQLite: {args.sqlite}")
        asyncio.run(migrate_database(sqlite_path=args.sqlite))
    else:
        # Default behavior: check lrclib.net and upsert if needed (--check is implied)
        print("Checking for latest dump and performing upsert (if applicable)...")
        asyncio.run(
            check_and_fetch_latest(
                force=args.force,
                dest_dir=args.dest,
                dry_run=args.dry_run,
                verbose=args.verbose,
                notify=args.notify,
            )
        )
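
# Example command lines (illustrative; the script filename is an assumption):
#   python upsert_sqlite_dump.py --check --notify
#   python upsert_sqlite_dump.py --force --dest /data/dumps --verbose
#   python upsert_sqlite_dump.py --migrate --sqlite /path/to/dump.sqlite3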