#!/usr/bin/env python3
"""
Fix inconsistent Django migration history by ensuring earlier migrations are marked applied
before later ones. Intended to be run from the project root where `manage.py` lives.

What it does:
- Backs up the current rows from `django_migrations` to `django_migrations_backup.csv`.
- Parses `python manage.py showmigrations` output for each app.
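- For each app, if a migration later in the list is applied while an earlier one is not,
  the script marks the earlier one(s) as applied by inserting the missing rows into
  `django_migrations` via `python manage.py shell` (the same bookkeeping effect as
  `python manage.py migrate <app> <migration> --fake`). Unapplied dependencies declared in
  the migration files are recorded first.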
- Finally runs `python manage.py migrate`.

WARNING: This script changes migration history using --fake/--fake-initial. BACKUP your
database before running it. Review the printed plan before confirming.

Usage:
  python scripts/fix_migration_history.py  # interactive confirm
  python scripts/fix_migration_history.py --yes  # run without confirmation

"""
from __future__ import annotations

import csv
import os
import re
import shlex
import subprocess
import sys
from pathlib import Path

# Project root: the directory one level above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Django's management entry point, expected directly under ROOT.
MANAGE_PY = ROOT / "manage.py"

# Fail fast (exit status 2) at import time when manage.py cannot be found.
# NOTE(review): the lookup is based on this file's location, not on the current
# working directory, so the "run from the project root" hint may mislead — confirm.
if not MANAGE_PY.exists():
    print("Error: manage.py not found. Run this script from the project root.")
    sys.exit(2)


def run(cmd, check=True):
    """Run a command without a shell, echo its output, and return the result.

    Args:
        cmd: Either a list/tuple of arguments or a single command string. A
            string is tokenised with ``shlex.split``; no shell is involved.
        check: When true, a non-zero exit status raises ``CalledProcessError``.

    Returns:
        The ``subprocess.CompletedProcess`` with text ``stdout`` and ``stderr``.

    Raises:
        subprocess.CalledProcessError: If ``check`` is true and the command
            exits non-zero. The captured stdout/stderr are attached to the
            exception so callers can inspect them.
    """
    print(f"\n$ {cmd}")
    argv = cmd if isinstance(cmd, (list, tuple)) else shlex.split(cmd)
    res = subprocess.run(argv, capture_output=True, text=True)
    print(res.stdout, end="")
    if res.stderr:
        print(res.stderr, end="", file=sys.stderr)
    if check and res.returncode != 0:
        # Mirror subprocess.run(check=True): keep the captured streams on the
        # exception instead of discarding them.
        raise subprocess.CalledProcessError(
            res.returncode, cmd, output=res.stdout, stderr=res.stderr
        )
    return res


def backup_django_migrations(out_path: Path):
    """Dump the rows of ``django_migrations`` to a CSV file.

    The table is read by a short script executed through ``manage.py shell -c``.
    The child prints a marker line followed by CSV text; everything before the
    marker (e.g. start-up noise) is discarded.

    If ``out_path`` already exists it is renamed with a timestamp suffix first,
    so that re-running the script after a partial repair cannot overwrite the
    pre-repair backup.

    Args:
        out_path: Destination of the CSV backup.

    Returns:
        True when the backup file was written, False when the child process
        failed or produced unexpected output.
    """
    print(f"Backing up django_migrations to {out_path}")
    marker = "CSV-BEGIN"
    # Multi-line script so the cursor can be closed by a ``with`` block.
    script = "\n".join(
        [
            "import csv, sys",
            "from django.db import connection",
            "with connection.cursor() as cur:",
            '    cur.execute("SELECT app, name, applied FROM django_migrations ORDER BY app, name")',
            "    rows = cur.fetchall()",
            f"print({marker!r})",
            "writer = csv.writer(sys.stdout)",
            "writer.writerow(['app','name','applied'])",
            "writer.writerows(rows)",
        ]
    )
    # Call manage.py without going through a shell to avoid wrapper issues on some hosts
    cmd = [sys.executable, str(MANAGE_PY), "shell", "-c", script]
    res = subprocess.run(cmd, capture_output=True, text=True)
    if res.returncode != 0:
        print("Warning: could not backup django_migrations via manage.py shell.\n", file=sys.stderr)
        print(res.stdout)
        print(res.stderr, file=sys.stderr)
        return False

    _before, found, after = res.stdout.partition(marker)
    if not found:
        print("Unexpected output while backing up migrations. Aborting backup.", file=sys.stderr)
        return False
    csv_text = after.lstrip()

    if out_path.exists():
        # Keep the earlier backup: it may be the only copy of the pre-repair state.
        from datetime import datetime

        stamp = datetime.fromtimestamp(out_path.stat().st_mtime).strftime("%Y%m%d-%H%M%S")
        preserved = out_path.with_name(f"{out_path.stem}.{stamp}{out_path.suffix}")
        out_path.replace(preserved)
        print(f"Existing backup preserved as: {preserved}")

    out_path.write_text(csv_text, encoding="utf-8")
    print(f"Backup written: {out_path}")
    return True


def parse_showmigrations(output: str):
    """Parse ``showmigrations`` text into ``{app: [(migration_name, applied), ...]}``.

    Any non-blank line that does not begin with a space opens a new app section.
    Indented ``[X] name`` / ``[ ] name`` lines add an entry to the current
    section. Other indented lines (for example ``(no migrations)``) are ignored.
    Section order and entry order follow the input.
    """
    entry_pattern = re.compile(r"^\s*\[(?P<applied>[ X])\]\s+(?P<name>\S+)" )
    result = {}
    section = None  # list of entries for the app currently being read

    for raw in output.splitlines():
        if raw.strip() == "":
            continue

        if raw[0] != " ":
            # Header line: start (or restart) the entry list for this app.
            section = []
            result[raw.strip()] = section
            continue

        if section is None:
            # Indented line seen before any app header.
            continue

        hit = entry_pattern.match(raw)
        if hit is None:
            continue
        is_applied = hit.group("applied") == "X"
        section.append((hit.group("name").strip(), is_applied))

    return result


def get_showmigrations():
    """Return the stdout of ``manage.py showmigrations``.

    Raises:
        RuntimeError: If the command exits non-zero. Its stderr is echoed to
            this process's stderr first.
    """
    completed = subprocess.run(
        [sys.executable, str(MANAGE_PY), "showmigrations"],
        capture_output=True,
        text=True,
    )
    if completed.returncode == 0:
        return completed.stdout
    print(completed.stderr, file=sys.stderr)
    raise RuntimeError("showmigrations failed")


def plan_fixes(app_migs: dict[str, list[tuple[str, bool]]]):
    """List the (app, migration_name) pairs that need to be fake-applied.

    For each app, every unapplied migration that sits before the last applied
    one is reported. Apps are visited in mapping order and migrations in list
    order.
    """
    to_fake = []
    for app, migrations in app_migs.items():
        applied_positions = [
            pos for pos, (_name, done) in enumerate(migrations) if done
        ]
        # Nothing applied, or only the first entry applied: no gap is possible.
        if not applied_positions or applied_positions[-1] == 0:
            continue
        last_applied = applied_positions[-1]
        to_fake.extend(
            (app, name) for name, done in migrations[:last_applied] if not done
        )
    return to_fake


def migration_file_path(app: str, migration_name: str) -> Path | None:
    """Locate the source file of a migration under ``ROOT/<app>/migrations``.

    The exact ``<migration_name>.py`` is preferred. Otherwise the first file, in
    sorted order, whose name starts with ``migration_name`` and ends in ``.py``
    is used, so the result is the same on every run.

    Args:
        app: App label, used as the directory name directly under ``ROOT``.
        migration_name: Migration name without the ``.py`` extension.

    Returns:
        The path of the migration file, or None if nothing matches.
    """
    import glob

    migrations_dir = ROOT / app / "migrations"
    exact = migrations_dir / f"{migration_name}.py"
    if exact.exists():
        return exact

    # sometimes migration files are named with suffixes; try glob.
    # Escape the fixed parts so characters such as "[" in the project path or
    # the migration name are matched literally rather than as glob syntax.
    pattern = os.path.join(
        glob.escape(str(migrations_dir)),
        glob.escape(migration_name) + "*.py",
    )
    matches = sorted(glob.glob(pattern))
    return Path(matches[0]) if matches else None


def parse_migration_dependencies(mig_path: Path) -> list[tuple[str, str]]:
    """Statically extract ``(app, name)`` pairs from ``Migration.dependencies``.

    The file is parsed with ``ast`` and never imported. Only list literals
    assigned to ``dependencies`` directly in the body of a class named
    ``Migration`` are considered. Within them, only tuples whose first two
    elements are string constants are returned. Calls such as swappable
    dependencies are skipped. An unreadable or unparsable file yields ``[]``.
    """
    import ast

    try:
        tree = ast.parse(mig_path.read_text(encoding="utf-8"))
    except Exception:
        return []

    def literal_pairs(list_node):
        """Yield (app, name) for every tuple entry that starts with two string constants."""
        for item in list_node.elts:
            if not isinstance(item, ast.Tuple) or len(item.elts) < 2:
                continue
            first, second = item.elts[0], item.elts[1]
            if not (isinstance(first, ast.Constant) and isinstance(second, ast.Constant)):
                continue
            if isinstance(first.value, str) and isinstance(second.value, str):
                yield first.value, second.value

    deps: list[tuple[str, str]] = []
    migration_classes = (
        node
        for node in ast.walk(tree)
        if isinstance(node, ast.ClassDef) and node.name == 'Migration'
    )
    for cls in migration_classes:
        for stmt in cls.body:
            # Only plain assignments of a list literal are of interest.
            if not isinstance(stmt, ast.Assign) or not isinstance(stmt.value, ast.List):
                continue
            for target in stmt.targets:
                if isinstance(target, ast.Name) and target.id == 'dependencies':
                    deps.extend(literal_pairs(stmt.value))
    return deps


def build_dependency_plan(app_migs: dict[str, list[tuple[str, bool]]], initial_fixes: list[tuple[str, str]]):
    """Return an ordered list of (app, name) to fake-apply, dependencies first.

    The plan contains:
    - every entry of ``initial_fixes`` that is not already applied,
    - every unapplied migration that an applied migration declares as a
      dependency,
    - and, recursively, the unapplied dependencies of all of the above.

    Dependencies are read from the migration source files. A migration whose
    file cannot be located contributes no dependencies.

    The placeholder names ``__first__`` and ``__latest__`` are resolved to the
    first / last migration listed for that app in ``app_migs``. They are dropped
    when the app is unknown or has no migrations, so they are never returned as
    if they were real migration names.

    Args:
        app_migs: ``{app: [(migration_name, applied), ...]}`` in listing order.
        initial_fixes: (app, name) pairs already known to need fake-applying.

    Returns:
        A list of unique (app, name) pairs in which each migration appears
        after the dependencies it declares. The order is the same on every run.
    """
    applied = {
        (app, name)
        for app, migs in app_migs.items()
        for name, is_applied in migs
        if is_applied
    }

    def resolve(dep_app, dep_name):
        """Map a declared dependency to a concrete (app, name), or None to skip it."""
        if not dep_app or not dep_name:
            return None
        if dep_name not in ("__first__", "__latest__"):
            return dep_app, dep_name
        known = app_migs.get(dep_app)
        if not known:
            return None
        return dep_app, known[0 if dep_name == "__first__" else -1][0]

    def declared_dependencies(app, name):
        """Return the resolved dependencies declared in the migration's source file."""
        mig_path = migration_file_path(app, name)
        if not mig_path:
            return []
        resolved = (
            resolve(dep_app, dep_name)
            for dep_app, dep_name in parse_migration_dependencies(mig_path)
        )
        return [dep for dep in resolved if dep is not None]

    # Seed with the per-app gaps, then add whatever applied migrations rely on.
    needed = {(app, name) for app, name in initial_fixes}
    for app, migs in app_migs.items():
        for name, is_applied in migs:
            if not is_applied:
                continue
            needed.update(
                dep for dep in declared_dependencies(app, name) if dep not in applied
            )

    ordered = []
    visited = set()

    def add_with_deps(key):
        """Depth-first: append the dependencies of ``key`` before ``key`` itself."""
        if key in visited or key in applied:
            return
        # Marking before recursing also guards against dependency cycles.
        visited.add(key)
        for dep in declared_dependencies(*key):
            add_with_deps(dep)
        ordered.append(key)

    # Sorted so the printed plan is stable across runs (set order is not).
    for key in sorted(needed):
        add_with_deps(key)
    return ordered


def _insert_migration_rows(rows: list[tuple[str, str]]) -> bool:
    """Create missing ``django_migrations`` rows through ``manage.py shell -c``.

    Rows are inserted directly rather than via ``migrate --fake``, because
    Django may refuse to run ``migrate`` while the history is inconsistent.
    All inserts happen in one transaction, and existing rows are left alone.

    Args:
        rows: (app, migration_name) pairs to record as applied.

    Returns:
        True if there was nothing to do or the child process exited with 0.
    """
    if not rows:
        return True
    pairs = [(str(app), str(name)) for app, name in rows]
    script = "\n".join(
        [
            "from django.db import transaction",
            "from django.db.migrations.recorder import MigrationRecorder",
            # repr() of a list of (str, str) tuples is a valid Python literal, so
            # quotes or backslashes in a name cannot break or alter the script.
            f"rows = {pairs!r}",
            "Migration = MigrationRecorder.Migration",
            "created = 0",
            "with transaction.atomic():",
            "    for app, name in rows:",
            "        if not Migration.objects.filter(app=app, name=name).exists():",
            "            Migration.objects.create(app=app, name=name)",
            "            created += 1",
            'print("MIGRATION_ROWS_CREATED:", created)',
        ]
    )
    cmd = [sys.executable, str(MANAGE_PY), "shell", "-c", script]
    res = subprocess.run(cmd, capture_output=True, text=True)
    print(res.stdout, end="")
    if res.stderr:
        print(res.stderr, end="", file=sys.stderr)
    return res.returncode == 0


def _migrate_or_exit():
    """Run ``manage.py migrate``; exit with status 1 if it fails."""
    try:
        run([sys.executable, str(MANAGE_PY), "migrate"])
    except subprocess.CalledProcessError:
        print("Final migrate failed. Inspect output above.")
        sys.exit(1)


def main():
    """Entry point: back up, plan, confirm, record missing rows, then migrate.

    Steps:
    1. Unless ``-y``/``--yes`` is given, wait for the user to confirm.
    2. Back up ``django_migrations``; abort if the backup fails.
    3. Parse ``showmigrations`` and build the dependency-ordered plan.
    4. If the plan is empty, just run ``migrate``.
    5. Otherwise print the plan, confirm again, insert the missing rows and
       run ``migrate``.

    Exits with status 1 on any failure.
    """
    quiet = any(arg in ("-y", "--yes") for arg in sys.argv[1:])

    print("This script will attempt to repair inconsistent migration history by marking missing earlier\n"
          "migrations as applied (using --fake) where a later migration is already marked applied.)\n")
    if not quiet:
        print("Make a full database backup before proceeding. Press Ctrl-C to abort if you haven't.")
        input("Press Enter to continue or Ctrl-C to abort...")

    backup_path = ROOT / "django_migrations_backup.csv"
    try:
        ok = backup_django_migrations(backup_path)
    except Exception as e:
        print("Backup failed:", e, file=sys.stderr)
        ok = False
    if not ok:
        print("Warning: backup failed. Proceeding is risky. Aborting.")
        sys.exit(1)

    try:
        show = get_showmigrations()
    except RuntimeError as exc:
        # stderr of the failed command has already been echoed.
        print(f"Error: {exc}", file=sys.stderr)
        sys.exit(1)
    apps = parse_showmigrations(show)

    print("\nParsed migration status:")
    for app, migrations in apps.items():
        applied_count = sum(1 for _n, a in migrations if a)
        print(f" - {app}: {applied_count}/{len(migrations)} applied")

    initial_fixes = plan_fixes(apps)
    # build dependency-aware ordered fixes
    fixes = build_dependency_plan(apps, initial_fixes)

    if not fixes:
        print("No out-of-order applied migrations detected. Running regular migrate to ensure consistency.")
        _migrate_or_exit()
        print("Done.")
        return

    print("\nPlanned fake-apply operations in dependency order (will mark these migrations as applied):")
    for app, name in fixes:
        print(f" - {app}: {name}")

    if not quiet:
        input('\nPress Enter to perform the above fake-applies, or Ctrl-C to abort...')

    # Insert rows directly to satisfy dependency checks
    if not _insert_migration_rows(fixes):
        print("Failed to insert migration rows via manage.py shell. Aborting.")
        sys.exit(1)

    # Finally run normal migrate
    _migrate_or_exit()

    print("Migration history repair completed. If you still see errors, paste the full output here.")


# Run the repair only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
