#!/usr/bin/env python3
"""
Production Database Fix: User Permissions Table

This script fixes the database schema issue with the user_permissions table
that's causing the "Unknown column 'user_permissions.module'" error.

The issue occurs when the database table structure doesn't match the Django model
definition. This script ensures the table has the correct schema.

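Usage:
    python production_database_fix.py [--dry-run] [--no-backup] [--no-migrate]

A backup table is created and Django migrations are run unless the
corresponding --no-* flag is given.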

Author: AI Assistant
Date: 2024
"""

import os
import sys
import logging
import argparse
import subprocess
from datetime import datetime
from pathlib import Path

# Logging: INFO and above, mirrored to database_fix.log (opened relative to
# the current working directory, appended to by FileHandler's default mode)
# and to stdout, so a run is both visible live and kept on disk.
logging.basicConfig(
    handlers=[
        logging.FileHandler('database_fix.log'),
        logging.StreamHandler(sys.stdout),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class ProductionDatabaseFixer:
    """Bring the MySQL ``user_permissions`` table in line with the expected schema.

    If the table is missing it is created from scratch. If it exists, missing
    columns are added, a too-short ``module``/``action`` VARCHAR is widened and
    missing secondary indexes are added. Optionally the table is backed up
    first and ``manage.py migrate users`` is run afterwards.

    The SQL used here (``DESCRIBE``, ``SHOW INDEX``, ``ENGINE=InnoDB``,
    ``information_schema``) is MySQL-specific.

    Args:
        dry_run: Log intended changes without executing any DDL.
        backup: Copy the existing table to a timestamped backup table before
            altering it.
        run_migrations: Run ``manage.py migrate users`` after the schema fix.
    """

    # Expected column definitions (MySQL DDL fragments); ``id`` is not checked.
    REQUIRED_COLUMNS = {
        'user_id': 'BIGINT NOT NULL',
        'module': 'VARCHAR(50) NOT NULL',
        'action': 'VARCHAR(30) NOT NULL',
        'is_allowed': 'BOOLEAN NOT NULL DEFAULT FALSE',
        'granted_by_id': 'BIGINT NULL',
        'reason': 'TEXT NULL',
        'expires_at': 'DATETIME NULL',
        'created_at': 'DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP',
        'updated_at': 'DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP',
    }

    # Minimum VARCHAR lengths; existing columns shorter than this are widened
    # with MODIFY COLUMN (never shrunk).
    MIN_VARCHAR_LENGTHS = {'module': 50, 'action': 30}

    # Secondary indexes added when absent, keyed by index name.
    REQUIRED_INDEXES = {
        'unique_user_permission': 'UNIQUE KEY unique_user_permission (user_id, module, action)',
        'idx_user_expires': 'KEY idx_user_expires (user_id, expires_at)',
        'idx_module_action': 'KEY idx_module_action (module, action)',
    }

    def __init__(self, dry_run=False, backup=True, run_migrations=True):
        self.dry_run = dry_run
        self.backup = backup
        self.run_migrations = run_migrations
        # Human-readable outcome of each step, printed in the final summary.
        self.fix_log = []
        self.start_time = datetime.now()

    def setup_django(self):
        """Initialise Django using ``branch_system.settings`` unless
        ``DJANGO_SETTINGS_MODULE`` is already set.

        Returns:
            True on success, False (after logging) on any failure.
        """
        try:
            import django

            os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'branch_system.settings')
            django.setup()

            logger.info("Django environment setup successful")
            return True

        except Exception as e:
            logger.error(f"Django setup failed: {str(e)}")
            return False

    def check_database_connection(self):
        """Return True if a trivial ``SELECT 1`` succeeds on the default connection."""
        try:
            from django.db import connection

            with connection.cursor() as cursor:
                cursor.execute("SELECT 1")
                if cursor.fetchone():
                    logger.info("Database connection successful")
                    return True

        except Exception as e:
            logger.error(f"Database connection failed: {str(e)}")
            return False

        return False

    def get_table_info(self, table_name):
        """Describe *table_name* in the current database.

        Returns:
            dict with ``exists`` (bool) and ``columns`` (rows from MySQL
            ``DESCRIBE``: Field, Type, Null, Key, Default, Extra). If the
            lookup itself fails, ``exists`` is False and an extra ``error``
            key holds the message, so callers can tell "missing" from
            "could not check".
        """
        try:
            from django.db import connection

            with connection.cursor() as cursor:
                # Exact-name lookup; SHOW TABLES LIKE would treat '_' as a wildcard.
                cursor.execute(
                    "SELECT COUNT(*) FROM information_schema.tables "
                    "WHERE table_schema = DATABASE() AND table_name = %s",
                    [table_name],
                )
                if not cursor.fetchone()[0]:
                    return {'exists': False, 'columns': []}

                cursor.execute(f"DESCRIBE `{table_name}`")
                columns = cursor.fetchall()

                return {'exists': True, 'columns': columns}

        except Exception as e:
            logger.error(f"Error getting table info for {table_name}: {str(e)}")
            return {'exists': False, 'columns': [], 'error': str(e)}

    def backup_table(self, table_name):
        """Copy *table_name* into ``<table_name>_backup_<YYYYmmdd_HHMMSS>``.

        ``CREATE TABLE ... AS SELECT`` copies rows and column types only, not
        indexes or foreign keys.

        Returns:
            True when backups are disabled, in dry-run mode, or on success;
            False if the copy failed.
        """
        if not self.backup:
            return True

        backup_table_name = f"{table_name}_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # A dry run must not create anything, including the backup table.
        if self.dry_run:
            logger.info(f"[DRY RUN] Would create backup table: {backup_table_name}")
            self.fix_log.append(f"Backup: Would be created as {backup_table_name}")
            return True

        try:
            from django.db import connection

            with connection.cursor() as cursor:
                cursor.execute(
                    f"CREATE TABLE `{backup_table_name}` AS SELECT * FROM `{table_name}`"
                )
                logger.info(f"Created backup table: {backup_table_name}")
                self.fix_log.append(f"Backup created: {backup_table_name}")
                return True

        except Exception as e:
            logger.error(f"Error creating backup for {table_name}: {str(e)}")
            return False

    def create_user_permissions_table(self):
        """Create ``user_permissions`` with the full expected schema.

        Returns:
            True on success or in dry-run mode, False if CREATE TABLE failed.
        """
        logger.info("Creating user_permissions table...")

        if self.dry_run:
            logger.info("[DRY RUN] Would create user_permissions table")
            self.fix_log.append("user_permissions table: Would be created")
            return True

        try:
            from django.db import connection

            with connection.cursor() as cursor:
                # NOTE(review): assumes the user model's table is named `users`
                # and its `id` is BIGINT; otherwise the FOREIGN KEYs fail — confirm.
                create_sql = """
                CREATE TABLE user_permissions (
                    id BIGINT AUTO_INCREMENT PRIMARY KEY,
                    user_id BIGINT NOT NULL,
                    module VARCHAR(50) NOT NULL,
                    action VARCHAR(30) NOT NULL,
                    is_allowed BOOLEAN NOT NULL DEFAULT FALSE,
                    granted_by_id BIGINT NULL,
                    reason TEXT NULL,
                    expires_at DATETIME NULL,
                    created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                    UNIQUE KEY unique_user_permission (user_id, module, action),
                    KEY idx_user_expires (user_id, expires_at),
                    KEY idx_module_action (module, action),
                    CONSTRAINT fk_user_permissions_user 
                        FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE,
                    CONSTRAINT fk_user_permissions_granted_by 
                        FOREIGN KEY (granted_by_id) REFERENCES users(id) ON DELETE SET NULL
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
                """

                cursor.execute(create_sql)
                logger.info("Created user_permissions table successfully")
                self.fix_log.append("user_permissions table: Created successfully")
                return True

        except Exception as e:
            logger.error(f"Error creating user_permissions table: {str(e)}")
            return False

    def _find_column_problems(self, columns):
        """Compare ``DESCRIBE`` rows against REQUIRED_COLUMNS.

        Args:
            columns: rows from ``DESCRIBE user_permissions``; only the first
                two fields (column name, column type) are used.

        Returns:
            ``(missing, incorrect)``: lists of ``(column_name, column_ddl)``
            for absent columns and for VARCHARs shorter than
            MIN_VARCHAR_LENGTHS.
        """
        import re

        # str() tolerates drivers that return the type as bytes.
        existing_types = {col[0]: str(col[1]).lower() for col in columns}

        missing, incorrect = [], []
        for col_name, col_def in self.REQUIRED_COLUMNS.items():
            if col_name not in existing_types:
                missing.append((col_name, col_def))
                continue
            min_length = self.MIN_VARCHAR_LENGTHS.get(col_name)
            match = re.search(r'varchar\((\d+)\)', existing_types[col_name])
            # Compare lengths numerically: a substring test such as
            # "'varchar(20)' in type" also matches varchar(200), and the
            # MODIFY below would then shrink the column and truncate data.
            if min_length is not None and match and int(match.group(1)) < min_length:
                incorrect.append((col_name, col_def))
        return missing, incorrect

    def _add_missing_indexes(self, cursor):
        """Add any REQUIRED_INDEXES not present on user_permissions.

        Index creation is best-effort (e.g. the unique key fails on duplicate
        rows), so failures are logged rather than raised.

        Returns:
            Names of the indexes that could not be added.
        """
        cursor.execute("SHOW INDEX FROM user_permissions")
        # SHOW INDEX rows carry the index name in the third column (Key_name).
        existing = {row[2] for row in cursor.fetchall()}

        failed = []
        for index_name, index_ddl in self.REQUIRED_INDEXES.items():
            if index_name in existing:
                continue
            try:
                cursor.execute(f"ALTER TABLE user_permissions ADD {index_ddl}")
                logger.info(f"Added index: {index_name}")
            except Exception as e:
                logger.warning(f"Could not add index {index_name}: {str(e)}")
                failed.append(index_name)
        return failed

    def fix_existing_table(self, table_info):
        """Add missing columns, widen short VARCHARs and add missing indexes.

        Args:
            table_info: result of :meth:`get_table_info` for an existing table.

        Returns:
            True if the structure was already correct, would be fixed (dry
            run), or was fixed; False if any required column could not be
            added or modified. Index failures are logged and recorded in
            fix_log but do not fail the fix.
        """
        logger.info("Fixing existing user_permissions table...")

        missing_columns, incorrect_columns = self._find_column_problems(table_info['columns'])

        if not missing_columns and not incorrect_columns:
            logger.info("user_permissions table structure is correct")
            self.fix_log.append("user_permissions table: Structure is correct")
            return True

        if self.dry_run:
            logger.info(f"[DRY RUN] Would fix {len(missing_columns)} missing columns and {len(incorrect_columns)} incorrect columns")
            self.fix_log.append("user_permissions table: Would be fixed")
            return True

        try:
            from django.db import connection

            failed_columns = []
            with connection.cursor() as cursor:
                for col_name, col_def in missing_columns:
                    try:
                        cursor.execute(f"ALTER TABLE user_permissions ADD COLUMN {col_name} {col_def}")
                        logger.info(f"Added column: {col_name}")
                    except Exception as e:
                        logger.warning(f"Could not add column {col_name}: {str(e)}")
                        failed_columns.append(col_name)

                for col_name, col_def in incorrect_columns:
                    try:
                        cursor.execute(f"ALTER TABLE user_permissions MODIFY COLUMN {col_name} {col_def}")
                        logger.info(f"Fixed column: {col_name}")
                    except Exception as e:
                        logger.warning(f"Could not fix column {col_name}: {str(e)}")
                        failed_columns.append(col_name)

                failed_indexes = self._add_missing_indexes(cursor)

            if failed_indexes:
                self.fix_log.append(
                    f"user_permissions table: Could not add indexes {', '.join(failed_indexes)}"
                )

            # A required column that is still missing or too short means the
            # model will keep failing, so this must not be reported as fixed.
            if failed_columns:
                logger.error(f"Could not fix columns: {', '.join(failed_columns)}")
                self.fix_log.append(
                    f"user_permissions table: Failed to fix columns {', '.join(failed_columns)}"
                )
                return False

            logger.info("Fixed user_permissions table successfully")
            self.fix_log.append("user_permissions table: Fixed successfully")
            return True

        except Exception as e:
            logger.error(f"Error fixing user_permissions table: {str(e)}")
            return False

    def run_django_migrations(self):
        """Run ``manage.py migrate users`` in a subprocess.

        Returns:
            True when migrations are disabled, in dry-run mode, or when the
            command exits with status 0; False otherwise.
        """
        if not self.run_migrations:
            return True

        logger.info("Running Django migrations...")

        if self.dry_run:
            logger.info("[DRY RUN] Would run Django migrations")
            self.fix_log.append("Django migrations: Would be run")
            return True

        try:
            # NOTE(review): manage.py is resolved relative to the current working
            # directory, so this script must be run from the project root.
            # NOTE(review): if user_permissions was just created by hand and the
            # migration that creates it is unapplied, migrate will likely fail
            # with "table already exists"; `migrate users --fake` may be needed.
            result = subprocess.run(
                [sys.executable, 'manage.py', 'migrate', 'users', '--verbosity=0'],
                capture_output=True,
                text=True,
                cwd=Path.cwd(),
            )

            if result.returncode == 0:
                logger.info("Django migrations completed successfully")
                self.fix_log.append("Django migrations: Completed successfully")
                return True

            # Some failures are reported on stdout only; don't log an empty message.
            logger.error(f"Django migrations failed: {result.stderr or result.stdout}")
            self.fix_log.append("Django migrations: Failed")
            return False

        except Exception as e:
            logger.error(f"Error running Django migrations: {str(e)}")
            self.fix_log.append("Django migrations: Failed")
            return False

    def verify_fix(self):
        """Check that every required column of user_permissions can be queried.

        Returns:
            True if the query succeeds, False (after logging) otherwise.
        """
        logger.info("Verifying fix...")

        try:
            from django.db import connection

            with connection.cursor() as cursor:
                # Naming every expected column makes MySQL reject the query if
                # any of them is missing, not just ``module``.
                column_list = ', '.join(self.REQUIRED_COLUMNS)
                cursor.execute(
                    f"SELECT {column_list} FROM user_permissions WHERE module = %s LIMIT 1",
                    ['test'],
                )
                cursor.fetchall()
                logger.info("✅ user_permissions table is working correctly")
                self.fix_log.append("Verification: Passed")
                return True

        except Exception as e:
            logger.error(f"❌ Verification failed: {str(e)}")
            self.fix_log.append("Verification: Failed")
            return False

    def _log_summary(self):
        """Log fix_log entries, elapsed time and the final outcome banner."""
        end_time = datetime.now()
        fix_time = end_time - self.start_time

        logger.info("=" * 80)
        logger.info("FIX SUMMARY")
        logger.info("=" * 80)
        for log_entry in self.fix_log:
            logger.info(f"  {log_entry}")
        logger.info("=" * 80)
        logger.info(f"Fix time: {fix_time}")
        logger.info(f"End time: {end_time}")

        if self.dry_run:
            logger.info("DRY RUN COMPLETED - No changes were made")
        else:
            logger.info("DATABASE FIX COMPLETED SUCCESSFULLY!")
            logger.info("The user_permissions table should now work correctly.")

    def fix(self):
        """Run the whole repair: setup, inspect, back up, fix, migrate, verify.

        Returns:
            True when every step succeeded (or, in dry-run mode, would be
            attempted); False as soon as any step fails.
        """
        logger.info("=" * 80)
        logger.info("PRODUCTION DATABASE FIX: USER PERMISSIONS TABLE")
        logger.info("=" * 80)
        logger.info(f"Dry run: {self.dry_run}")
        logger.info(f"Backup: {self.backup}")
        logger.info(f"Run migrations: {self.run_migrations}")
        logger.info(f"Start time: {self.start_time}")
        logger.info("=" * 80)

        try:
            if not self.setup_django():
                return False

            if not self.check_database_connection():
                return False

            table_info = self.get_table_info('user_permissions')

            # "Could not inspect" must not be treated as "missing": that
            # would attempt CREATE TABLE on a table that may well exist.
            if table_info.get('error'):
                logger.error("Could not inspect user_permissions table; aborting")
                return False

            if not table_info['exists']:
                if not self.create_user_permissions_table():
                    return False
            else:
                # Never alter the production table if the requested backup failed.
                if not self.backup_table('user_permissions'):
                    logger.error("Backup failed; aborting without altering user_permissions")
                    return False

                if not self.fix_existing_table(table_info):
                    return False

            if not self.run_django_migrations():
                return False

            if self.dry_run:
                # Nothing was changed, so verifying would only re-detect the
                # pre-fix problems and fail the dry run spuriously.
                logger.info("[DRY RUN] Would verify user_permissions table")
                self.fix_log.append("Verification: Skipped (dry run)")
            elif not self.verify_fix():
                return False

            self._log_summary()
            return True

        except Exception as e:
            # logger.exception keeps the traceback for post-mortem debugging.
            logger.exception(f"Database fix failed: {str(e)}")
            return False


def main():
    """Parse command-line flags, run the fixer and exit with its status.

    Exits with code 0 when the fix (or dry run) succeeds and 1 otherwise,
    including when an unexpected exception escapes the fixer.
    """
    examples = '''
Examples:
  python production_database_fix.py                    # Fix with backup and migrations
  python production_database_fix.py --dry-run          # Test fix
  python production_database_fix.py --no-backup        # Fix without backup
  python production_database_fix.py --no-migrate       # Fix without running migrations
        '''
    parser = argparse.ArgumentParser(
        description='Fix user_permissions database table',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=examples,
    )

    # All options are simple on/off switches that default to False.
    switches = (
        ('--dry-run', 'Show what would be done without making changes'),
        ('--no-backup', 'Skip creating backup (not recommended)'),
        ('--no-migrate', 'Skip running Django migrations'),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action='store_true', help=help_text)

    options = parser.parse_args()

    try:
        succeeded = ProductionDatabaseFixer(
            dry_run=options.dry_run,
            backup=not options.no_backup,
            run_migrations=not options.no_migrate,
        ).fix()
    except Exception as e:
        logger.error(f"Script failed: {str(e)}")
        sys.exit(1)

    sys.exit(0 if succeeded else 1)


if __name__ == "__main__":
    main()
