#!/usr/bin/env python3
"""
Comprehensive MySQL Collation Fix Script
Fixes collation conflicts by standardizing all tables to utf8mb4_unicode_ci
"""

import os
import sys
import django
from django.conf import settings
from django.db import connection, transaction
import logging

# Setup Django
# Point Django at the project's settings module unless the environment already
# names one, then initialise Django before any function below uses `connection`.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'branch_system.settings')
django.setup()

# Setup logging
# Records at INFO and above are written both to collation_fix.log (a relative
# path) and to a StreamHandler using its default stream.
# NOTE(review): logging.basicConfig does nothing when the root logger already
# has handlers -- confirm django.setup() does not install any via the project's
# LOGGING setting, otherwise this configuration is silently skipped.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('collation_fix.log'),
        logging.StreamHandler()
    ]
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)

def get_database_tables():
    """Return the names of all base tables in the current database.

    Views are deliberately left out: they have no table collation and cannot
    be the target of ``ALTER TABLE ... CONVERT TO``, so passing them on to the
    collation fix would only produce errors.

    Returns:
        A list of table-name strings.
    """
    with connection.cursor() as cursor:
        # Plain SHOW TABLES also lists views; the FULL modifier adds a
        # Table_type column that lets us keep only real tables.
        cursor.execute("SHOW FULL TABLES WHERE Table_type = 'BASE TABLE'")
        return [row[0] for row in cursor.fetchall()]

def get_table_collation(table_name):
    """Look up the default collation recorded for one table.

    Queries information_schema.TABLES for the current database and returns
    the TABLE_COLLATION value, or None when no matching row is found.
    """
    with connection.cursor() as cur:
        cur.execute("""
            SELECT TABLE_COLLATION 
            FROM information_schema.TABLES 
            WHERE TABLE_SCHEMA = DATABASE() 
            AND TABLE_NAME = %s
        """, [table_name])
        row = cur.fetchone()

    # No row means the table name is not present in the current schema.
    if not row:
        return None
    return row[0]

def get_column_collations(table_name):
    """Fetch collation details for every column of a table that has one.

    Each returned row is
    (COLUMN_NAME, COLLATION_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT),
    read from information_schema.COLUMNS for the current database. Columns
    whose COLLATION_NAME is NULL are filtered out by the query.
    """
    with connection.cursor() as cur:
        cur.execute("""
            SELECT COLUMN_NAME, COLLATION_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT
            FROM information_schema.COLUMNS 
            WHERE TABLE_SCHEMA = DATABASE() 
            AND TABLE_NAME = %s 
            AND COLLATION_NAME IS NOT NULL
        """, [table_name])
        rows = cur.fetchall()
    return rows

def _quote_identifier(name):
    """Wrap a MySQL identifier in backticks, doubling any embedded backtick."""
    return "`" + str(name).replace("`", "``") + "`"


def _build_modify_column_sql(table_name, column_name, column_type, is_nullable, default_val):
    """Return (sql, params) that re-declares one column with utf8mb4_unicode_ci."""
    nullable = "NULL" if is_nullable == "YES" else "NOT NULL"
    sql = (
        f"ALTER TABLE {_quote_identifier(table_name)} "
        f"MODIFY COLUMN {_quote_identifier(column_name)} {column_type} "
        f"CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci {nullable}"
    )
    if default_val is None:
        # No parameters: the statement is sent as-is, '%' needs no escaping.
        return sql, None

    # The default is passed as a parameter so the driver quotes it. Once
    # parameters are supplied, '%' marks a placeholder, so any literal '%'
    # (e.g. inside an enum member) has to be doubled first.
    # NOTE(review): COLUMN_DEFAULT is treated as a plain literal here --
    # confirm no affected column uses an expression default.
    return sql.replace("%", "%%") + " DEFAULT %s", [default_val]


def fix_table_collation(table_name):
    """Convert one table, and any straggling columns, to utf8mb4_unicode_ci.

    The table is first converted with ``ALTER TABLE ... CONVERT TO``. Every
    text column that still reports a different collation afterwards is then
    re-declared individually with ``MODIFY COLUMN``.

    Args:
        table_name: Name of a table in the current database.

    Returns:
        True when every statement succeeded; False when any statement raised.
        The exception is logged, not propagated.
    """
    logger.info(f"Fixing collation for table: {table_name}")

    try:
        with connection.cursor() as cursor:
            # First, convert the table to utf8mb4_unicode_ci
            cursor.execute(
                f"ALTER TABLE {_quote_identifier(table_name)} "
                "CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
            )
            logger.info(f"✓ Converted table {table_name} to utf8mb4_unicode_ci")

            # Re-read the columns after conversion. COLUMN_TYPE is used rather
            # than DATA_TYPE because it carries the full declaration
            # (varchar(255), enum('a','b'), ...); DATA_TYPE alone would drop
            # lengths and enum/set members from the MODIFY COLUMN statement.
            cursor.execute(
                """
                SELECT COLUMN_NAME, COLLATION_NAME, COLUMN_TYPE, IS_NULLABLE, COLUMN_DEFAULT
                FROM information_schema.COLUMNS
                WHERE TABLE_SCHEMA = DATABASE()
                AND TABLE_NAME = %s
                AND COLLATION_NAME IS NOT NULL
                """,
                [table_name],
            )
            columns = cursor.fetchall()

            # Fix any remaining column collation issues
            for column_name, collation, column_type, is_nullable, default_val in columns:
                if collation == 'utf8mb4_unicode_ci':
                    continue

                sql, params = _build_modify_column_sql(
                    table_name, column_name, column_type, is_nullable, default_val
                )
                cursor.execute(sql, params)
                logger.info(f"✓ Fixed column {table_name}.{column_name}")

    except Exception as e:
        # Best effort per table: report the failure and let the caller move on.
        logger.error(f"✗ Error fixing table {table_name}: {str(e)}")
        return False

    return True

def fix_database_collation():
    """Set the current database's default charset/collation to utf8mb4_unicode_ci.

    Returns True on success. Any exception is logged and turned into a False
    return value instead of being raised.
    """
    try:
        with connection.cursor() as db_cursor:
            # Ask the server which database this connection is using.
            db_cursor.execute("SELECT DATABASE()")
            first_row = db_cursor.fetchone()
            db_name = first_row[0]

            # Change the database-level defaults.
            alter_statement = f"ALTER DATABASE `{db_name}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
            db_cursor.execute(alter_statement)
            logger.info(f"✓ Set database {db_name} default collation to utf8mb4_unicode_ci")
    except Exception as exc:
        logger.error(f"✗ Error setting database collation: {str(exc)}")
        return False
    else:
        return True

def create_missing_tables():
    """Create any tables referenced in the errors that do not exist yet.

    Each known table name is looked up in information_schema.TABLES for the
    current database; when no row is found the table is created with
    utf8mb4 / utf8mb4_unicode_ci. A failure is logged for that table only and
    does not stop the remaining tables from being processed.
    """
    # Table name -> DDL used to create it when it is missing.
    table_definitions = {
        'portfolio_performance': """
            CREATE TABLE `portfolio_performance` (
                `id` bigint NOT NULL AUTO_INCREMENT,
                `date` date NOT NULL,
                `total_loans` int DEFAULT 0,
                `active_loans` int DEFAULT 0,
                `overdue_loans` int DEFAULT 0,
                `total_amount` decimal(15,2) DEFAULT 0.00,
                `collected_amount` decimal(15,2) DEFAULT 0.00,
                `outstanding_amount` decimal(15,2) DEFAULT 0.00,
                `created_at` datetime(6) NOT NULL,
                `updated_at` datetime(6) NOT NULL,
                PRIMARY KEY (`id`),
                UNIQUE KEY `date` (`date`)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """,
    }

    for table_name, create_sql in table_definitions.items():
        try:
            with connection.cursor() as cursor:
                # Exact, parameterized name match. SHOW TABLES LIKE treats
                # '_' as a single-character wildcard, so a pattern such as
                # 'portfolio_performance' could match a different table.
                cursor.execute(
                    """
                    SELECT 1
                    FROM information_schema.TABLES
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_NAME = %s
                    """,
                    [table_name],
                )
                if cursor.fetchone():
                    continue

                logger.info(f"Creating missing table: {table_name}")
                cursor.execute(create_sql)
                logger.info(f"✓ Created table {table_name}")

        except Exception as e:
            # Best effort: report and continue with the next table.
            logger.error(f"✗ Error creating table {table_name}: {str(e)}")

def main():
    """Run the whole collation fix and log a summary.

    Steps: set the database default collation, create missing tables, list
    the tables, then convert every table whose default collation or any of
    whose column collations is not utf8mb4_unicode_ci. Per-table failures are
    counted and reported in the summary; an unexpected exception outside the
    per-table handling is logged and ends the process with exit status 1.
    """
    target_collation = 'utf8mb4_unicode_ci'
    logger.info("Starting comprehensive collation fix...")

    try:
        # Fix database default collation
        logger.info("Step 1: Fixing database collation...")
        fix_database_collation()

        # Create missing tables
        logger.info("Step 2: Creating missing tables...")
        create_missing_tables()

        # Get all tables
        logger.info("Step 3: Getting all database tables...")
        tables = get_database_tables()
        logger.info(f"Found {len(tables)} tables to process")

        # Fix each table
        logger.info("Step 4: Fixing table collations...")
        success_count = 0
        error_count = 0

        for table in tables:
            current_collation = get_table_collation(table)
            logger.info(f"Processing {table} (current: {current_collation})")

            # A table's default collation can already be correct while
            # individual columns still use another one, so the columns are
            # checked too. Index 1 of each row is COLLATION_NAME. The column
            # query is skipped when the table default already needs fixing.
            needs_fix = current_collation != target_collation or any(
                column[1] != target_collation
                for column in get_column_collations(table)
            )

            if not needs_fix:
                logger.info(f"✓ Table {table} already has correct collation")
                success_count += 1
            elif fix_table_collation(table):
                success_count += 1
            else:
                error_count += 1

        # Summary
        logger.info("="*50)
        logger.info("COLLATION FIX SUMMARY")
        logger.info("="*50)
        logger.info(f"Total tables processed: {len(tables)}")
        logger.info(f"Successfully fixed: {success_count}")
        logger.info(f"Errors encountered: {error_count}")

        if error_count == 0:
            logger.info("✓ All collation issues have been resolved!")
            logger.info("✓ Your Django application should now work without collation errors")
        else:
            logger.warning(f"⚠ {error_count} tables had issues. Check the log for details.")

    except Exception as e:
        # Anything not handled per table (e.g. a failed metadata query) is
        # fatal for the run.
        logger.error(f"Critical error during collation fix: {str(e)}")
        sys.exit(1)

# Run the fix only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()