#!/usr/bin/env python3
"""
Production Fix Script for Branch System Issues
This script fixes the OperationalError: Unknown column 'users.branch_id' in 'field list'
and ensures all branch-related functionality works correctly.
"""

import os
import sys
import django
from django.db import connection, transaction
from django.core.management import execute_from_command_line
from django.conf import settings

# Setup Django environment.
# This must run before the `users.models` import below: model modules can only
# be imported once settings are configured and the app registry is populated.
# setdefault() means an already-exported DJANGO_SETTINGS_MODULE takes precedence.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'branch_system.settings')
django.setup()

from django.core.management.base import BaseCommand
from users.models import CustomUser, Branch
from django.db import models
from django.core.management import call_command
import uuid

def log_message(message, level='INFO'):
    """Print ``message`` to stdout, prefixed by a local timestamp and ``level``.

    Each call prints exactly one line shaped like
    ``[2024-01-31 13:45:00] INFO: some text``.
    """
    from datetime import datetime as _datetime

    stamp = _datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("[{}] {}: {}".format(stamp, level, message))

def check_database_connection():
    """Probe the default database connection with a trivial ``SELECT 1``.

    Returns:
        True if the probe succeeds; False if anything raises. The failure is
        logged at ERROR level instead of being propagated.
    """
    try:
        with connection.cursor() as cur:
            cur.execute("SELECT 1")
            log_message("Database connection successful")
    except Exception as exc:
        log_message(f"Database connection failed: {exc}", 'ERROR')
        return False
    return True

def check_table_exists(table_name):
    """Return True if ``table_name`` exists in the connection's current schema.

    Looks the name up in ``INFORMATION_SCHEMA.TABLES`` restricted to
    ``DATABASE()`` (the schema of the active MySQL connection). The name is
    sent as a bound query parameter rather than interpolated into the SQL, so
    quotes or other special characters in it cannot break or alter the query.

    Args:
        table_name: Exact table name to look for.

    Returns:
        True if the table exists. False if it does not, or if the lookup
        itself raised (the error is logged, not re-raised). Callers therefore
        cannot distinguish "missing" from "lookup failed".
    """
    try:
        with connection.cursor() as cursor:
            cursor.execute(
                """
                SELECT COUNT(*)
                FROM INFORMATION_SCHEMA.TABLES
                WHERE TABLE_SCHEMA = DATABASE()
                  AND TABLE_NAME = %s
                """,
                [table_name],
            )
            return cursor.fetchone()[0] > 0
    except Exception as e:
        log_message(f"Error checking table {table_name}: {e}", 'ERROR')
        return False

def check_column_exists(table_name, column_name):
    """Return True if ``table_name.column_name`` exists in the current schema.

    Looks the pair up in ``INFORMATION_SCHEMA.COLUMNS`` restricted to
    ``DATABASE()`` (the schema of the active MySQL connection). Both names
    are sent as bound query parameters rather than interpolated into the SQL.

    Args:
        table_name: Exact name of the table to inspect.
        column_name: Exact name of the column to look for.

    Returns:
        True if the column exists. False if it does not, or if the lookup
        itself raised (the error is logged, not re-raised).
    """
    try:
        with connection.cursor() as cursor:
            cursor.execute(
                """
                SELECT COUNT(*)
                FROM INFORMATION_SCHEMA.COLUMNS
                WHERE TABLE_SCHEMA = DATABASE()
                  AND TABLE_NAME = %s
                  AND COLUMN_NAME = %s
                """,
                [table_name, column_name],
            )
            return cursor.fetchone()[0] > 0
    except Exception as e:
        log_message(f"Error checking column {table_name}.{column_name}: {e}", 'ERROR')
        return False

def create_branch_table():
    """Create the ``users_branch`` table with raw MySQL DDL if it is missing.

    Existence is checked with ``check_table_exists``. When the table is
    already present, nothing is changed. The schema is hand-written here
    rather than produced by a Django migration.

    Returns:
        True if the table was created now or was reported as already
        existing. False if the DDL raised (the error is logged, not re-raised).
    """
    try:
        if not check_table_exists('users_branch'):
            log_message("Creating Branch table...")
            with connection.cursor() as cursor:
                # `id` is char(32): ids inserted by create_main_branch are
                # 32-char hex UUIDs (no dashes). Presumably this matches how the
                # Branch model's primary key is stored — TODO confirm.
                # NOTE(review): the table is created outside Django migrations.
                # If the migration that creates users_branch is not recorded as
                # applied, the later `migrate` step may try to create it again
                # and fail — confirm (e.g. fake that migration) before running.
                cursor.execute("""
                    CREATE TABLE `users_branch` (
                        `id` char(32) NOT NULL,
                        `name` varchar(100) NOT NULL,
                        `code` varchar(20) NOT NULL UNIQUE,
                        `address` longtext,
                        `phone_number` varchar(20),
                        `email` varchar(254),
                        `is_main_branch` tinyint(1) NOT NULL DEFAULT 0,
                        `is_active` tinyint(1) NOT NULL DEFAULT 1,
                        `created_at` datetime(6) NOT NULL,
                        `updated_at` datetime(6) NOT NULL,
                        PRIMARY KEY (`id`)
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """)
            log_message("Branch table created successfully")
        else:
            log_message("Branch table already exists")
        return True
    except Exception as e:
        log_message(f"Error creating Branch table: {e}", 'ERROR')
        return False

def add_branch_fields_to_users():
    """Add the branch FK column and the accessible-branches link table via raw DDL.

    Two steps, each skipped when its target already exists:

    1. If ``users.branch_id`` is missing, add it as a nullable ``char(32)``
       with FK ``users_branch_id_fk`` to ``users_branch.id``
       (``ON DELETE SET NULL``).
    2. If ``users_customuser_accessible_branches`` is missing, create it with
       FKs to ``users.id`` and ``users_branch.id`` (both ``ON DELETE CASCADE``)
       and a unique key on ``(customuser_id, branch_id)``.

    Returns:
        True if both steps succeeded or were skipped. False if any statement
        raised (the error is logged, not re-raised).
    """
    try:
        # Step 1: nullable FK column on the users table.
        # NOTE(review): MySQL requires an FK's char column to have the same
        # charset/collation as the referenced column. The new column takes the
        # `users` table default, while users_branch is created as
        # utf8mb4_unicode_ci in create_branch_table. Confirm they match, or
        # this ALTER will fail.
        if not check_column_exists('users', 'branch_id'):
            log_message("Adding branch_id column to users table...")
            with connection.cursor() as cursor:
                cursor.execute("""
                    ALTER TABLE `users` 
                    ADD COLUMN `branch_id` char(32) NULL,
                    ADD CONSTRAINT `users_branch_id_fk` 
                    FOREIGN KEY (`branch_id`) REFERENCES `users_branch` (`id`) 
                    ON DELETE SET NULL
                """)
            log_message("branch_id column added successfully")
        else:
            log_message("branch_id column already exists")

        # Step 2: many-to-many link table between users and branches.
        # assumes users.id is char(32) (same width as the branch ids), since
        # customuser_id below must match it for the FK — TODO confirm against
        # CustomUser's primary key.
        # NOTE(review): Django's default auto-created M2M table name is
        # "<model db_table>_<field name>". If CustomUser is stored in `users`
        # (as the ALTER above implies), the default would be
        # `users_accessible_branches`, not this name, unless the field sets
        # db_table. Confirm against the model, or the ORM will not see these rows.
        if not check_table_exists('users_customuser_accessible_branches'):
            log_message("Creating accessible_branches many-to-many table...")
            with connection.cursor() as cursor:
                cursor.execute("""
                    CREATE TABLE `users_customuser_accessible_branches` (
                        `id` bigint(20) NOT NULL AUTO_INCREMENT,
                        `customuser_id` char(32) NOT NULL,
                        `branch_id` char(32) NOT NULL,
                        PRIMARY KEY (`id`),
                        UNIQUE KEY `users_customuser_accessi_customuser_id_branch_id_unique` (`customuser_id`, `branch_id`),
                        KEY `users_customuser_accessible_branches_branch_id_fk` (`branch_id`),
                        CONSTRAINT `users_customuser_accessible_branches_customuser_id_fk` 
                        FOREIGN KEY (`customuser_id`) REFERENCES `users` (`id`) ON DELETE CASCADE,
                        CONSTRAINT `users_customuser_accessible_branches_branch_id_fk` 
                        FOREIGN KEY (`branch_id`) REFERENCES `users_branch` (`id`) ON DELETE CASCADE
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """)
            log_message("accessible_branches table created successfully")
        else:
            log_message("accessible_branches table already exists")

        return True
    except Exception as e:
        log_message(f"Error adding branch fields to users table: {e}", 'ERROR')
        return False

def create_main_branch():
    """Return the id of the main branch, inserting a default one if none exists.

    A "main" branch is any ``users_branch`` row with ``is_main_branch = 1``.
    If such a row already exists, its id is returned unchanged. Otherwise a row
    with placeholder contact details and a fresh 32-character hex UUID id is
    inserted.

    Returns:
        The branch id, or None when the ``users_branch`` table is missing or a
        database error occurs (errors are logged, not re-raised). Both failure
        cases return the same falsy value.
    """
    # timezone.now() is an aware UTC datetime when USE_TZ is on, and a naive
    # local datetime otherwise.
    from django.utils import timezone

    try:
        if not check_table_exists('users_branch'):
            log_message("Branch table doesn't exist, skipping branch creation", 'WARNING')
            return None

        with connection.cursor() as cursor:
            cursor.execute("SELECT id FROM users_branch WHERE is_main_branch = 1 LIMIT 1")
            existing = cursor.fetchone()
            if existing is not None:
                branch_id = existing[0]
                log_message(f"Main branch already exists with ID: {branch_id}")
                return branch_id

            log_message("Creating main branch...")
            branch_id = uuid.uuid4().hex
            # Let the backend convert the timestamp the same way the ORM does
            # (into the connection's time zone when USE_TZ is on). A plain
            # datetime.now() string would store server-local time instead.
            now = connection.ops.adapt_datetimefield_value(timezone.now())

            cursor.execute(
                """
                INSERT INTO users_branch
                (id, name, code, address, phone_number, email, is_main_branch, is_active, created_at, updated_at)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                """,
                (
                    branch_id,
                    'Main Branch',
                    'MAIN',
                    'Head Office',
                    '+254700000000',
                    'info@branchbusinessadvance.co.ke',
                    1,
                    1,
                    now,
                    now,
                ),
            )
            log_message(f"Main branch created with ID: {branch_id}")
            return branch_id

    except Exception as e:
        log_message(f"Error creating main branch: {e}", 'ERROR')
        return None

def assign_users_to_main_branch(main_branch_id):
    """Attach branchless users to the main branch and grant admins every branch.

    1. Every ``users`` row whose ``branch_id`` is NULL is set to
       ``main_branch_id``.
    2. Every user with ``role = 'admin'`` gets a row in
       ``users_customuser_accessible_branches`` for each branch they are not
       already linked to.

    Args:
        main_branch_id: Primary key of the branch to assign. A falsy value is
            rejected.

    Returns:
        True on success. False if ``main_branch_id`` is falsy or a database
        error occurs (errors are logged, not re-raised).
    """
    try:
        if not main_branch_id:
            log_message("No main branch ID provided", 'ERROR')
            return False

        with connection.cursor() as cursor:
            # Count users without branch
            cursor.execute("SELECT COUNT(*) FROM users WHERE branch_id IS NULL")
            users_without_branch = cursor.fetchone()[0]

            if users_without_branch > 0:
                log_message(f"Assigning {users_without_branch} users to main branch...")
                cursor.execute(
                    "UPDATE users SET branch_id = %s WHERE branch_id IS NULL",
                    [main_branch_id],
                )
                log_message(f"Successfully assigned {users_without_branch} users to main branch")
            else:
                log_message("All users already have a branch assigned")

            cursor.execute("SELECT COUNT(*) FROM users WHERE role = 'admin'")
            admin_count = cursor.fetchone()[0]

            # Grant every admin every branch in one set-based INSERT, instead
            # of a SELECT COUNT + INSERT round trip per (admin, branch) pair.
            # The LEFT JOIN ... IS NULL anti-join skips pairs that already
            # exist, so re-running the script never inserts duplicates. MySQL
            # allows the target table in the SELECT's FROM clause.
            cursor.execute("""
                INSERT INTO users_customuser_accessible_branches (customuser_id, branch_id)
                SELECT u.id, b.id
                FROM users AS u
                CROSS JOIN users_branch AS b
                LEFT JOIN users_customuser_accessible_branches AS ab
                    ON ab.customuser_id = u.id AND ab.branch_id = b.id
                WHERE u.role = 'admin'
                  AND ab.customuser_id IS NULL
            """)

            log_message(f"Gave {admin_count} admin users access to all branches")
            return True

    except Exception as e:
        log_message(f"Error assigning users to main branch: {e}", 'ERROR')
        return False

def run_pending_migrations():
    """Apply every unapplied Django migration without prompting.

    Returns:
        True once ``migrate`` finishes. False if anything raised; the error is
        logged at ERROR level rather than propagated.
    """
    migrate_options = {'verbosity': 1, 'interactive': False}
    try:
        log_message("Running pending migrations...")
        call_command('migrate', **migrate_options)
        log_message("Migrations completed successfully")
    except Exception as err:
        log_message(f"Error running migrations: {err}", 'ERROR')
        return False
    return True

def verify_fix():
    """Smoke-test the repaired schema with raw SQL, then with the Django ORM.

    Raw SQL checks: ``users.branch_id`` can be filtered on, and
    ``users_branch`` and ``users_customuser_accessible_branches`` can be
    counted.

    ORM checks:
    - filtering ``CustomUser`` on ``branch_id``;
    - loading one full ``CustomUser`` row and one full ``Branch`` row, which
      puts every concrete model column in the SELECT field list (the kind of
      query that raised the original "Unknown column ... in 'field list'"
      error);
    - counting ``Branch`` rows.

    Returns:
        True if every check ran without error. False otherwise; the failure
        is logged, not re-raised.
    """
    try:
        log_message("Verifying fix...")

        # Test basic queries that were failing
        with connection.cursor() as cursor:
            cursor.execute("SELECT COUNT(*) FROM users WHERE branch_id IS NOT NULL")
            users_with_branch = cursor.fetchone()[0]
            log_message(f"Users with branch assigned: {users_with_branch}")

            cursor.execute("SELECT COUNT(*) FROM users_branch")
            total_branches = cursor.fetchone()[0]
            log_message(f"Total branches: {total_branches}")

            cursor.execute("SELECT COUNT(*) FROM users_customuser_accessible_branches")
            accessible_relations = cursor.fetchone()[0]
            log_message(f"Accessible branch relations: {accessible_relations}")
    except Exception as e:
        log_message(f"Verification failed: {e}", 'ERROR')
        return False

    # CustomUser and Branch come from the module-level import of users.models.
    try:
        # Filtering by branch_id (this was causing the original error)
        test_users = CustomUser.objects.filter(branch_id__isnull=False).count()
        log_message(f"Django ORM: Users with branch (via branch_id filter): {test_users}")

        # .count() never selects model columns, so it cannot detect other
        # model fields missing from the table. Materialising one row of each
        # model does. Both queries are safe on empty tables (they return None).
        CustomUser.objects.first()
        Branch.objects.first()

        test_branches = Branch.objects.count()
        log_message(f"Django ORM: Total branches: {test_branches}")
    except Exception as orm_error:
        log_message(f"Django ORM test failed: {orm_error}", 'ERROR')
        return False

    log_message("✅ All verification tests passed!")
    return True

class _FixStepFailed(Exception):
    """Raised inside ``main``'s atomic block when a repair step reports failure.

    Leaving ``transaction.atomic()`` through an exception makes Django roll the
    transaction back. Leaving it through ``return`` would instead commit
    whatever the earlier steps had already written.
    """


def main():
    """Run every repair step, then verify the result.

    Steps 2-6 run inside a single ``transaction.atomic()`` block. If any of
    them reports failure, ``_FixStepFailed`` is raised so the block rolls back
    instead of committing partial data changes.

    Returns:
        True if all steps and the verification succeeded, False otherwise.
        Errors are logged rather than raised.
    """
    log_message("=" * 60)
    log_message("BRANCH SYSTEM PRODUCTION FIX SCRIPT")
    log_message("=" * 60)

    # Step 1: Check database connection
    if not check_database_connection():
        log_message("Cannot proceed without database connection", 'ERROR')
        return False

    try:
        # NOTE: on MySQL, DDL (the CREATE/ALTER TABLE in steps 2-3 and any
        # schema changes made by migrations in step 4) commits implicitly. A
        # rollback here therefore only undoes statements issued after the most
        # recent DDL statement, e.g. the main-branch INSERT and user UPDATEs of
        # steps 5-6. Schema changes already applied are not reverted.
        with transaction.atomic():
            # Step 2: Create branch table
            if not create_branch_table():
                raise _FixStepFailed("Failed to create branch table")

            # Step 3: Add branch fields to users table
            if not add_branch_fields_to_users():
                raise _FixStepFailed("Failed to add branch fields to users table")

            # Step 4: Run pending migrations
            if not run_pending_migrations():
                raise _FixStepFailed("Failed to run migrations")

            # Step 5: Create main branch
            main_branch_id = create_main_branch()
            if not main_branch_id:
                raise _FixStepFailed("Failed to create main branch")

            # Step 6: Assign users to main branch
            if not assign_users_to_main_branch(main_branch_id):
                raise _FixStepFailed("Failed to assign users to main branch")

        # Step 7: Verify the fix (runs after the atomic block has committed)
        if not verify_fix():
            log_message("Fix verification failed", 'ERROR')
            return False

        log_message("=" * 60)
        log_message("✅ PRODUCTION FIX COMPLETED SUCCESSFULLY!")
        log_message("=" * 60)
        log_message("")
        log_message("Next steps:")
        log_message("1. Restart your Django application server")
        log_message("2. Clear any cached data if applicable")
        log_message("3. Test the login functionality")
        log_message("4. Monitor logs for any remaining issues")
        log_message("")
        return True

    except _FixStepFailed as step_error:
        # The atomic block has already rolled back by the time we get here.
        log_message(str(step_error), 'ERROR')
        return False
    except Exception as e:
        log_message(f"Critical error during fix: {e}", 'ERROR')
        return False

if __name__ == '__main__':
    # Translate main()'s boolean into a process exit status: 0 = success, 1 = failure.
    sys.exit(0 if main() else 1)