#!/usr/bin/env python
"""
Fix collation mismatches across all tables: convert every table whose default collation is utf8mb4_unicode_ci to utf8mb4 / utf8mb4_0900_ai_ci
"""
import pymysql
import os
from dotenv import load_dotenv

load_dotenv()

# Connection parameters are read from the environment (optionally populated
# from a .env file by load_dotenv above). Rows come back as dicts.
_connect_kwargs = dict(
    host=os.getenv('DB_HOST', 'localhost'),
    user=os.getenv('DB_USER'),
    password=os.getenv('DB_PASSWORD'),
    database=os.getenv('DB_NAME'),
    charset='utf8mb4',
    cursorclass=pymysql.cursors.DictCursor,
)
connection = pymysql.connect(**_connect_kwargs)

# Header banner: a rule, the title, and another rule, one per line.
_banner_rule = "=" * 80
print("\n".join([_banner_rule, "FIXING COLLATION MISMATCHES", _banner_rule]))

# Tables whose default collation contains SOURCE_COLLATION are converted to
# TARGET_CHARSET / TARGET_COLLATION.
SOURCE_COLLATION = 'utf8mb4_unicode_ci'
TARGET_CHARSET = 'utf8mb4'
TARGET_COLLATION = 'utf8mb4_0900_ai_ci'


def _quote_identifier(name):
    """Return *name* as a backtick-quoted MySQL identifier.

    Embedded backticks are doubled, per MySQL identifier-quoting rules, so an
    unusual table name cannot break out of the quoted identifier.
    """
    return "`" + name.replace("`", "``") + "`"


def _needs_fix(collation):
    """Return True if a table's default *collation* is one this script converts.

    *collation* is None for views (they have no collation), which are skipped.
    """
    return bool(collation) and SOURCE_COLLATION in collation


try:
    with connection.cursor() as cursor:
        # Without a selected database, DATABASE() is NULL and the
        # information_schema query below would silently return nothing.
        cursor.execute("SELECT DATABASE() AS db")
        if cursor.fetchone()['db'] is None:
            raise RuntimeError("No database selected (is DB_NAME set?)")

        # One query for every table name and its default collation. This
        # replaces per-table `SHOW TABLE STATUS LIKE '<name>'`: LIKE treats
        # '_' and '%' as wildcards, so that lookup could return another
        # table's row, and quotes in a name broke the statement. Views are
        # included (their collation is NULL) so the count matches SHOW TABLES.
        cursor.execute(
            "SELECT TABLE_NAME AS table_name, TABLE_COLLATION AS table_collation "
            "FROM information_schema.TABLES "
            "WHERE TABLE_SCHEMA = DATABASE() "
            "ORDER BY TABLE_NAME"
        )
        rows = cursor.fetchall()

        print(f"\nFound {len(rows)} tables")
        print("\nChecking collations...")

        tables_to_fix = []

        for row in rows:
            table = row['table_name']
            collation = row['table_collation']

            if _needs_fix(collation):
                tables_to_fix.append(table)
                print(f"  ✗ {table}: {collation} (needs fixing)")
            elif collation:
                print(f"  ✓ {table}: {collation}")

        if tables_to_fix:
            print(f"\n{len(tables_to_fix)} tables need fixing")
            print("\nFixing collations...")

            failed = []

            # Tables are converted one at a time, so a foreign key between a
            # converted and a not-yet-converted table is briefly mismatched;
            # MySQL rejects that while foreign_key_checks is enabled. Turn the
            # check off for this session only and restore its prior value.
            cursor.execute("SELECT @@SESSION.foreign_key_checks AS fk_checks")
            previous_fk_checks = int(cursor.fetchone()['fk_checks'])
            cursor.execute("SET SESSION foreign_key_checks = 0")
            try:
                for table in tables_to_fix:
                    try:
                        cursor.execute(
                            f"ALTER TABLE {_quote_identifier(table)} "
                            f"CONVERT TO CHARACTER SET {TARGET_CHARSET} "
                            f"COLLATE {TARGET_COLLATION}"
                        )
                        print(f"  ✓ Fixed {table}")
                    except pymysql.MySQLError as e:
                        # Best effort: record the failure and keep converting
                        # the remaining tables.
                        failed.append(table)
                        print(f"  ✗ Error fixing {table}: {str(e)}")
            finally:
                cursor.execute(
                    "SET SESSION foreign_key_checks = %s", (previous_fk_checks,)
                )

            # ALTER TABLE commits implicitly in MySQL; this is kept as a
            # harmless explicit end to the unit of work.
            connection.commit()

            print("\n" + "=" * 80)
            if failed:
                print(
                    f"✗ COLLATION FIX INCOMPLETE: {len(failed)} of "
                    f"{len(tables_to_fix)} tables failed"
                )
                for table in failed:
                    print(f"  - {table}")
            else:
                print("✓ COLLATION FIX COMPLETE")
            print("=" * 80)
        else:
            print("\n✓ All tables already have correct collation")

except Exception as e:
    # Top-level boundary of the script: report and roll back any open work.
    print(f"\n✗ Error: {str(e)}")
    connection.rollback()
finally:
    connection.close()
