#!/usr/bin/env python
"""
Comprehensive Database Deployment Script
Handles migrations, schema changes, and data integrity checks
"""

import os
import sys
import subprocess
import django
from pathlib import Path

# Setup Django.
# This must run before the Django imports and ORM calls further down.
# BASE_DIR (this script's directory) is put first on sys.path, presumably so
# that the `branch_system` package resolves from here — confirm against the
# project layout. The settings module defaults to branch_system.settings; a
# DJANGO_SETTINGS_MODULE already present in the environment takes precedence
# (setdefault).
BASE_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(BASE_DIR))
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'branch_system.settings')
django.setup()

from django.core.management import call_command
from django.db import connection, connections
from django.conf import settings

class DatabaseDeployment:
    """Run an ordered set of database deployment steps and summarise them.

    Every step method prints a section header, records what happened via
    ``log_success`` / ``log_warning`` / ``log_error`` and returns a bool.
    ``deploy`` runs the steps in order. It stops early only when a step named
    in ``CRITICAL_STEPS`` returns a falsy value or raises. It finally returns
    ``print_summary()``: ``False`` if any error was logged, otherwise ``True``.
    """

    # Step names whose failure (falsy return or exception) stops ``deploy``.
    CRITICAL_STEPS = frozenset({"Database Connection", "Apply Migrations"})

    # Tables that ``verify_critical_tables`` expects to find; a missing one is
    # reported as a warning, not an error.
    CRITICAL_TABLES = (
        'users_customuser',
        'loans_loan',
        'loans_loanrepayment',
        'payments_mpesamessage',
        'utils_branch',
    )

    # (table, column, SQL column definition) triples that ``fix_common_issues``
    # adds with raw DDL when the table exists but lacks the column.
    # TINYINT(1)/DATETIME are MySQL/MariaDB column types.
    COMMON_COLUMN_FIXES = (
        ('loans_loan', 'is_deleted', 'TINYINT(1) DEFAULT 0'),
        ('loans_loan', 'deleted_at', 'DATETIME NULL'),
        ('utils_branch', 'is_default', 'TINYINT(1) DEFAULT 0'),
    )

    def __init__(self):
        # Messages accumulated by the log_* helpers and replayed by print_summary.
        self.errors = []
        self.warnings = []
        self.success_messages = []

    def log_success(self, message):
        """Print *message* with a check mark and record it as a success."""
        print(f"✓ {message}")
        self.success_messages.append(message)

    def log_warning(self, message):
        """Print *message* as a warning and record it; warnings do not fail the run."""
        print(f"⚠ WARNING: {message}")
        self.warnings.append(message)

    def log_error(self, message):
        """Print *message* as an error and record it; any error fails the run."""
        print(f"✗ ERROR: {message}")
        self.errors.append(message)

    def check_database_connection(self):
        """Run ``SELECT 1`` on the default connection.

        Returns True on success. Returns False after logging an error when
        the connection or the query fails.
        """
        print("\n=== Checking Database Connection ===")
        try:
            with connection.cursor() as cursor:
                cursor.execute("SELECT 1")
        except Exception as e:
            self.log_error(f"Database connection failed: {e}")
            return False
        self.log_success("Database connection successful")
        return True

    @staticmethod
    def _build_mysqldump_command(db_settings, backup_file):
        """Return a shell command that dumps the MySQL database in *db_settings*.

        The password is deliberately omitted. A bare ``-p`` makes mysqldump
        prompt for it, so the credential never reaches stdout, CI logs or the
        deployment summary. Django stores an unset ``HOST`` as ``''``, so an
        empty host falls back to ``localhost``. A non-empty ``PORT`` is passed
        with ``-P``. Every argument is shell-quoted so the command can be
        pasted as-is.
        """
        import shlex

        args = ['mysqldump', '-h', str(db_settings.get('HOST') or 'localhost')]
        port = db_settings.get('PORT')
        if port:
            args += ['-P', str(port)]
        user = db_settings.get('USER')
        if user:
            args += ['-u', str(user)]
        args += ['-p', str(db_settings['NAME'])]
        command = ' '.join(shlex.quote(arg) for arg in args)
        return f"{command} > {shlex.quote(backup_file)}"

    def backup_database(self):
        """Log a manual backup command; never fails the deployment.

        No backup is taken automatically. For a MySQL/MariaDB engine a
        ready-to-run ``mysqldump`` command (without the password) is logged
        as a warning. Other engines get a "not implemented" warning.
        Always returns True.
        """
        print("\n=== Creating Database Backup ===")
        try:
            db_settings = settings.DATABASES['default']
            if 'mysql' in db_settings['ENGINE']:
                backup_file = f"backup_{db_settings['NAME']}_{os.getpid()}.sql"
                cmd = self._build_mysqldump_command(db_settings, backup_file)
                self.log_warning(f"Backup command ready (run manually if needed): {cmd}")
                self.log_warning("Automated backup skipped - run manually for production")
            else:
                self.log_warning("Backup not implemented for this database engine")
        except Exception as e:
            self.log_warning(f"Backup creation skipped: {e}")
        return True  # Backup problems never fail the deployment.

    def check_pending_migrations(self):
        """Report unapplied migrations, listing at most the first five.

        Returns True when migrations are pending. Returns False when none are
        pending, or when the check itself failed (that case is logged as an
        error).
        """
        print("\n=== Checking Pending Migrations ===")
        try:
            from django.db.migrations.executor import MigrationExecutor
            executor = MigrationExecutor(connection)
            plan = executor.migration_plan(executor.loader.graph.leaf_nodes())
        except Exception as e:
            self.log_error(f"Failed to check migrations: {e}")
            return False

        if not plan:
            self.log_success("No pending migrations")
            return False

        self.log_warning(f"Found {len(plan)} pending migrations")
        for migration, _backwards in plan[:5]:
            print(f"  - {migration.app_label}.{migration.name}")
        if len(plan) > 5:
            print(f"  ... and {len(plan) - 5} more")
        return True

    def apply_migrations(self):
        """Run ``migrate`` non-interactively; return False (logged as error) on failure."""
        print("\n=== Applying Migrations ===")
        try:
            call_command('migrate', '--noinput', verbosity=2)
        except Exception as e:
            self.log_error(f"Migration failed: {e}")
            return False
        self.log_success("Migrations applied successfully")
        return True

    def check_migration_conflicts(self):
        """Warn about conflicting migration leaf nodes; always returns True."""
        print("\n=== Checking Migration Conflicts ===")
        try:
            from django.db.migrations.loader import MigrationLoader
            conflicts = MigrationLoader(connection).detect_conflicts()
        except Exception as e:
            self.log_warning(f"Could not check conflicts: {e}")
            return True  # Best-effort check; never fails the deployment.

        if conflicts:
            self.log_warning(f"Migration conflicts detected: {conflicts}")
        else:
            self.log_success("No migration conflicts detected")
        return True

    @staticmethod
    def _find_missing_tables(required, existing):
        """Return the names in *required* that are absent from *existing*, in order."""
        present = set(existing)
        return [name for name in required if name not in present]

    def verify_critical_tables(self):
        """Warn about any ``CRITICAL_TABLES`` entry missing from the database.

        Uses Django's backend introspection rather than a vendor-specific
        ``SHOW TABLES``. Views are included, matching what ``SHOW TABLES``
        lists. Returns False (logged as an error) only if introspection fails.
        """
        print("\n=== Verifying Critical Tables ===")
        try:
            existing_tables = connection.introspection.table_names(include_views=True)
        except Exception as e:
            self.log_error(f"Table verification failed: {e}")
            return False

        missing_tables = self._find_missing_tables(self.CRITICAL_TABLES, existing_tables)
        if missing_tables:
            self.log_warning(f"Missing tables: {', '.join(missing_tables)}")
        else:
            self.log_success("All critical tables exist")
        return True

    def check_foreign_key_constraints(self):
        """Count foreign-key constraints in the current schema.

        This only counts constraints; it does not validate referential
        integrity. The query relies on ``information_schema`` and
        ``DATABASE()``. On a backend without them the query fails and the
        step is skipped with a warning. Always returns True.
        """
        print("\n=== Checking Foreign Key Constraints ===")
        try:
            with connection.cursor() as cursor:
                # COUNT(*) rather than fetching a LIMITed row set, so the
                # reported number is the real total.
                cursor.execute("""
                    SELECT COUNT(*)
                    FROM information_schema.TABLE_CONSTRAINTS
                    WHERE CONSTRAINT_TYPE = 'FOREIGN KEY'
                    AND TABLE_SCHEMA = DATABASE()
                """)
                (fk_count,) = cursor.fetchone()
        except Exception as e:
            self.log_warning(f"Foreign key check skipped: {e}")
            return True
        self.log_success(f"Found {fk_count} foreign key constraints")
        return True

    def fix_common_issues(self):
        """Add ``COMMON_COLUMN_FIXES`` columns that are missing from existing tables.

        Existing columns are detected with Django introspection before any
        DDL runs. This means:

        * the plain ``ADD COLUMN`` form works on MySQL as well as MariaDB
          (MySQL rejects ``ADD COLUMN IF NOT EXISTS``);
        * only columns that were actually added are reported as fixes;
        * a failing ALTER is logged as a warning instead of being silently
          swallowed.

        Entries whose table does not exist are skipped. Always returns True,
        because this step is best-effort.
        """
        print("\n=== Fixing Common Issues ===")

        fixes = []
        failures = 0
        try:
            with connection.cursor() as cursor:
                introspection = connection.introspection
                existing_tables = set(introspection.table_names(cursor))
                for table, column, definition in self.COMMON_COLUMN_FIXES:
                    if table not in existing_tables:
                        continue
                    existing_columns = {
                        info.name
                        for info in introspection.get_table_description(cursor, table)
                    }
                    if column in existing_columns:
                        continue
                    try:
                        # Identifiers come from the class constant, not user input.
                        cursor.execute(
                            f"ALTER TABLE {table} ADD COLUMN {column} {definition}"
                        )
                    except Exception as e:
                        failures += 1
                        self.log_warning(f"Could not add {table}.{column}: {e}")
                    else:
                        fixes.append(f"{column} column")
        except Exception as e:
            self.log_warning(f"Common fixes skipped: {e}")
            return True

        if fixes:
            self.log_success(f"Applied fixes: {', '.join(fixes)}")
        elif not failures:
            self.log_success("No common fixes needed")
        return True

    def collect_static_files(self):
        """Run ``collectstatic --clear`` non-interactively.

        ``--clear`` removes previously collected files first. Failure is
        logged as a warning only. Always returns True.
        """
        print("\n=== Collecting Static Files ===")
        try:
            call_command('collectstatic', '--noinput', '--clear', verbosity=0)
        except Exception as e:
            self.log_warning(f"Static collection failed: {e}")
            return True  # Don't fail deployment
        self.log_success("Static files collected")
        return True

    def run_system_checks(self):
        """Run Django's ``check`` command; problems are logged as warnings.

        ``check`` raises only for issues at or above its fail level. Those
        issues are downgraded to warnings here, so they never fail the
        deployment. Always returns True.
        """
        print("\n=== Running System Checks ===")
        try:
            call_command('check')
        except Exception as e:
            self.log_warning(f"System checks found issues: {e}")
            return True
        self.log_success("System checks passed")
        return True

    def print_summary(self):
        """Print all recorded messages and the overall verdict.

        Returns False if any error was logged, otherwise True (warnings alone
        still count as success).
        """
        print("\n" + "="*60)
        print("DEPLOYMENT SUMMARY")
        print("="*60)

        sections = (
            ("✓ Successes", self.success_messages),
            ("⚠ Warnings", self.warnings),
            ("✗ Errors", self.errors),
        )
        for title, messages in sections:
            if messages:
                print(f"\n{title} ({len(messages)}):")
                for msg in messages:
                    print(f"  - {msg}")

        print("\n" + "="*60)

        if self.errors:
            print("❌ DEPLOYMENT FAILED")
            return False
        if self.warnings:
            print("⚠️  DEPLOYMENT COMPLETED WITH WARNINGS")
        else:
            print("✅ DEPLOYMENT SUCCESSFUL")
        return True

    def deploy(self):
        """Run every deployment step in order and return ``print_summary()``.

        A step in ``CRITICAL_STEPS`` that returns a falsy value or raises
        stops the run. Exceptions from other steps are logged as errors and
        the run continues.
        """
        print("="*60)
        print("DATABASE DEPLOYMENT SCRIPT")
        print("="*60)

        steps = [
            ("Database Connection", self.check_database_connection),
            ("Database Backup", self.backup_database),
            ("Migration Conflicts", self.check_migration_conflicts),
            ("Pending Migrations", self.check_pending_migrations),
            ("Apply Migrations", self.apply_migrations),
            ("Verify Tables", self.verify_critical_tables),
            ("Foreign Keys", self.check_foreign_key_constraints),
            ("Common Fixes", self.fix_common_issues),
            ("Static Files", self.collect_static_files),
            ("System Checks", self.run_system_checks),
        ]

        for step_name, step_func in steps:
            critical = step_name in self.CRITICAL_STEPS
            try:
                result = step_func()
            except Exception as e:
                self.log_error(f"Step '{step_name}' crashed: {e}")
                if critical:
                    break
                continue
            if not result and critical:
                print(f"\n❌ Critical step '{step_name}' failed. Stopping deployment.")
                break

        return self.print_summary()

if __name__ == "__main__":
    # Map the deployment outcome to a process exit status:
    # 0 when no errors were logged, 1 otherwise.
    exit_code = 0 if DatabaseDeployment().deploy() else 1
    sys.exit(exit_code)
