import os
import django
from django.db import connection

# Point Django at the production settings module unless the caller has
# already exported DJANGO_SETTINGS_MODULE (setdefault never overrides an
# existing value), then run django.setup() before any database access below.
os.environ.setdefault(
    'DJANGO_SETTINGS_MODULE',
    'branch_system.settings_production',
)
django.setup()

def _quote_ident(name):
    """Return *name* as a backtick-quoted MySQL identifier.

    Table, column, index and constraint names cannot be passed as query
    parameters, so DDL statements quote them instead. Embedded backticks are
    doubled so a name cannot terminate the quoting early.
    """
    return "`" + str(name).replace("`", "``") + "`"


def _table_exists(cursor, table):
    """Return True if *table* exists in the connection's current database."""
    cursor.execute(
        """
        SELECT COUNT(*)
        FROM information_schema.tables
        WHERE table_schema = DATABASE()
        AND table_name = %s
        """,
        [table],
    )
    return cursor.fetchone()[0] != 0


def _find_legacy_column(cursor, table):
    """Return the name of the ``customuser_id``/``customuser`` column, or None.

    If both columns exist, whichever row the database returns first is used.
    """
    cursor.execute(
        """
        SELECT COLUMN_NAME
        FROM information_schema.COLUMNS
        WHERE TABLE_SCHEMA = DATABASE()
        AND TABLE_NAME = %s
        AND COLUMN_NAME IN ('customuser_id', 'customuser')
        """,
        [table],
    )
    row = cursor.fetchone()
    return row[0] if row else None


def _drop_column_indexes(cursor, table, column):
    """Best-effort drop of every non-PRIMARY index that includes *column*.

    A failed drop (for example, an index that a foreign key still needs) is
    printed as a warning and does not stop the run.

    NOTE(review): this also drops composite and UNIQUE indexes containing
    *column*. On a many-to-many table that may include the unique key on the
    user/other-id pair, and nothing in this script recreates them. Confirm
    the drops are really needed before running against production.
    """
    cursor.execute(
        """
        SELECT INDEX_NAME
        FROM information_schema.STATISTICS
        WHERE TABLE_SCHEMA = DATABASE()
        AND TABLE_NAME = %s
        AND COLUMN_NAME = %s
        AND INDEX_NAME != 'PRIMARY'
        """,
        [table, column],
    )
    # fetchall() materialises the result before the cursor is reused below.
    for (index_name,) in cursor.fetchall():
        try:
            cursor.execute(
                f"DROP INDEX {_quote_ident(index_name)} ON {_quote_ident(table)}"
            )
            print(f"Dropped index {index_name}")
        except Exception as e:
            print(f"Warning: Could not drop index {index_name}: {e}")


def _user_id_column_info(cursor, table):
    """Return ``(COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE)`` for ``table.user_id``, or None."""
    cursor.execute(
        """
        SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE
        FROM information_schema.COLUMNS
        WHERE TABLE_SCHEMA = DATABASE()
        AND TABLE_NAME = %s
        AND COLUMN_NAME = 'user_id'
        """,
        [table],
    )
    return cursor.fetchone()


def _ensure_user_fk(cursor, table):
    """Add a FOREIGN KEY from ``table.user_id`` to ``users(id)`` unless one exists.

    Failures are printed as warnings rather than raised, matching the
    best-effort handling of the index steps.
    """
    cursor.execute(
        """
        SELECT COUNT(*)
        FROM information_schema.KEY_COLUMN_USAGE
        WHERE TABLE_SCHEMA = DATABASE()
        AND TABLE_NAME = %s
        AND COLUMN_NAME = 'user_id'
        AND REFERENCED_TABLE_NAME = 'users'
        """,
        [table],
    )
    if cursor.fetchone()[0] != 0:
        return

    # Only create the supporting index when no index already leads with
    # user_id. Otherwise a re-run after an earlier failed FK attempt would
    # fail on the duplicate index name and never reach ADD CONSTRAINT.
    cursor.execute(
        """
        SELECT COUNT(*)
        FROM information_schema.STATISTICS
        WHERE TABLE_SCHEMA = DATABASE()
        AND TABLE_NAME = %s
        AND COLUMN_NAME = 'user_id'
        AND SEQ_IN_INDEX = 1
        """,
        [table],
    )
    needs_index = cursor.fetchone()[0] == 0

    quoted_table = _quote_ident(table)
    fk_name = _quote_ident(f"{table}_user_id_fk")
    try:
        if needs_index:
            cursor.execute(
                f"ALTER TABLE {quoted_table} ADD INDEX {fk_name} (user_id)"
            )
        cursor.execute(
            f"ALTER TABLE {quoted_table} ADD CONSTRAINT {fk_name} "
            "FOREIGN KEY (user_id) REFERENCES users(id)"
        )
        print("Added foreign key constraint")
    except Exception as e:
        print(f"Warning: Could not add foreign key: {e}")


def fix_all_customuser_references(tables=None):
    """
    Rename legacy ``customuser_id``/``customuser`` columns to ``user_id``.

    Each listed table that exists in the current database goes through
    three steps:

    1. If a legacy column is present, drop the non-PRIMARY indexes on it
       (best effort), then rename it to ``user_id char(32) NOT NULL``.
    2. If the table has a ``user_id`` column, make sure it has a foreign key
       to ``users(id)``. This step runs even when no rename happened, so
       re-running the script retries a foreign key that failed earlier.
    3. Print the resulting ``user_id`` column type and nullability, or an
       error line if the column is missing.

    The SQL uses ``DATABASE()``, ``CHANGE COLUMN`` and backtick quoting, so
    it assumes a MySQL-compatible backend.

    Args:
        tables: Optional iterable of table names to process. Defaults to
            ``users_groups``, ``users_user_permissions`` and
            ``report_schedules_recipients``.

    Raises:
        Exception: any database error outside the best-effort index and
            foreign-key steps is printed and then re-raised.
    """
    if tables is None:
        tables = (
            'users_groups',
            'users_user_permissions',
            'report_schedules_recipients',
        )

    with connection.cursor() as cursor:
        try:
            for table in tables:
                print(f"\nFixing table: {table}")

                if not _table_exists(cursor, table):
                    print(f"Table {table} does not exist, skipping...")
                    continue

                old_column_name = _find_legacy_column(cursor, table)
                if old_column_name:
                    print(f"Found column {old_column_name}")
                    _drop_column_indexes(cursor, table, old_column_name)
                    # NOTE(review): assumes users.id is char(32) (e.g. a hex
                    # UUID) so the FK column types match; confirm.
                    cursor.execute(
                        f"ALTER TABLE {_quote_ident(table)} "
                        f"CHANGE COLUMN {_quote_ident(old_column_name)} "
                        "user_id char(32) NOT NULL"
                    )
                    print(f"Renamed column {old_column_name} to user_id")

                # Verify the fix; only a table that has user_id gets an FK.
                result = _user_id_column_info(cursor, table)
                if not result:
                    print(f"✗ Error: {table} structure is incorrect")
                    continue

                _ensure_user_fk(cursor, table)
                print(f"✓ Verification: {table} has correct user_id column")
                print(f"  Type: {result[1]}, Nullable: {result[2]}")

        except Exception as e:
            print(f"Error during fix: {e}")
            raise

if __name__ == "__main__":
    # Script entry point: announce the run, apply the fix, then confirm.
    # If the fix raises, the exception propagates and the completion line
    # below is never printed.
    print("Starting comprehensive customuser fix...")
    fix_all_customuser_references()
    print("\nFix completed.")
