﻿"""
Database Schema Comparison Script
Compares current Django database schema with target Grazuri schema
"""

import json
import os
import re
from collections import defaultdict
from typing import Dict, List, Set, Tuple

import mysql.connector

# Database connection settings for the current (Django) database.
# Every value can be overridden through an environment variable so real
# credentials do not have to live in this file; the literals below are only
# the fallbacks used when the variable is not set.
CURRENT_DB_CONFIG = {
    'host': os.environ.get('GRAZ_DB_HOST', 'localhost'),
    'user': os.environ.get('GRAZ_DB_USER', 'root'),
    'password': os.environ.get('GRAZ_DB_PASSWORD', 'password'),
    'database': os.environ.get('GRAZ_DB_NAME', 'xygbfpsg_graz'),
}

# SQL dump describing the target schema (overridable via GRAZ_TARGET_SQL_FILE).
TARGET_SQL_FILE = os.environ.get(
    'GRAZ_TARGET_SQL_FILE',
    r'c:\Users\Teamjoint company\Desktop\branchsystem\xygbfpsg_loans.sql',
)


# Regexes used by parse_sql_file; compiled once at import time.

# A whole CREATE TABLE statement: group 1 is the table name, group 2 the text
# between the outer parentheses. IF NOT EXISTS is optional so dumps generated
# with that option are parsed instead of silently skipped.
_CREATE_TABLE_RE = re.compile(
    r'CREATE TABLE\s+(?:IF NOT EXISTS\s+)?`?(\w+)`?\s*\((.*?)\)\s*ENGINE',
    re.DOTALL | re.IGNORECASE,
)

# Table-level key/constraint lines. The trailing \b keeps unquoted column names
# such as key_id or index_no from being mistaken for constraints.
_CONSTRAINT_RE = re.compile(
    r'(?:PRIMARY\s+KEY|FOREIGN\s+KEY|UNIQUE|KEY|INDEX|CONSTRAINT|FULLTEXT|SPATIAL|CHECK)\b',
    re.IGNORECASE,
)

# Head of a column definition: name (back-quoted or bare), base type keyword,
# optional parenthesised arguments, and numeric modifiers that are part of
# the type (so "int(10) unsigned" is kept whole).
_COLUMN_RE = re.compile(
    r'(?:`(?P<quoted>[^`]+)`|(?P<bare>\S+))\s+'
    r'(?P<base>\w+)'
    r'(?:\s*\((?P<params>[^)]*)\))?'
    r'(?P<mods>(?:\s+(?:UNSIGNED|SIGNED|ZEROFILL)\b)*)',
    re.IGNORECASE,
)

# Single- or double-quoted SQL string literal, honouring backslash escapes and
# doubled quotes ('it''s').
_QUOTED_RE = re.compile(r"'(?:[^'\\]|\\.|'')*'|\"(?:[^\"\\]|\\.|\"\")*\"")
_DEFAULT_KEYWORD_RE = re.compile(r'\bDEFAULT\s+', re.IGNORECASE)
_DEFAULT_VALUE_RE = re.compile(_QUOTED_RE.pattern + r'|\S+')
_NOT_NULL_RE = re.compile(r'\bNOT\s+NULL\b', re.IGNORECASE)
_AUTO_INCREMENT_RE = re.compile(r'\bAUTO_INCREMENT\b', re.IGNORECASE)


def _mask_literals(text: str) -> str:
    """Blank out the contents of quoted literals, keeping quotes and length.

    Keyword checks run on the masked text so words inside a DEFAULT value or a
    COMMENT string are ignored, while character positions still line up with
    the original text.
    """
    return _QUOTED_RE.sub(
        lambda m: m.group(0)[0] + ' ' * (len(m.group(0)) - 2) + m.group(0)[-1],
        text,
    )


def _parse_column(match, line: str) -> Tuple[str, Dict]:
    """Build ``(column_name, attributes)`` from a ``_COLUMN_RE`` match on *line*."""
    name = match.group('quoted') or match.group('bare')

    col_type = match.group('base').upper()
    params = match.group('params')
    if params is not None:
        if "'" not in params and '"' not in params:
            # "decimal(10, 2)" -> "decimal(10,2)", the spacing DESCRIBE uses.
            # Enum/set value lists are left untouched.
            params = re.sub(r'\s+', '', params)
        col_type += f'({params.upper()})'
    modifiers = match.group('mods').split()
    if modifiers:
        col_type += ' ' + ' '.join(modifiers).upper()

    masked = _mask_literals(line)

    default = None
    keyword = _DEFAULT_KEYWORD_RE.search(masked)
    if keyword:
        # Read the value from the original line so quoted defaults keep their text.
        value = _DEFAULT_VALUE_RE.match(line, keyword.end())
        if value:
            default = value.group(0)

    return name, {
        'type': col_type,
        'nullable': not _NOT_NULL_RE.search(masked),
        'default': default,
        'auto_increment': bool(_AUTO_INCREMENT_RE.search(masked)),
        'definition': line,
    }


def parse_sql_file(sql_file_path: str) -> Dict:
    """Parse the CREATE TABLE statements of a MySQL SQL file.

    Only statements shaped like ``CREATE TABLE [IF NOT EXISTS] name (...) ENGINE``
    are recognised, and the table body is read one line per column/constraint.

    Args:
        sql_file_path: Path to the .sql file; read as UTF-8 with undecodable
            bytes dropped.

    Returns:
        ``{table_name: {'columns': {...}, 'constraints': [...]}}``. Each column
        maps to a dict with 'type' (upper-cased, including size and any
        UNSIGNED/SIGNED/ZEROFILL), 'nullable', 'default' (raw SQL text with
        quotes, or None), 'auto_increment' and 'definition' (the source line).
        'constraints' holds the raw key/constraint lines found inside the
        CREATE TABLE body.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(sql_file_path, 'r', encoding='utf-8', errors='ignore') as f:
        sql_content = f.read()

    tables = {}
    for match in _CREATE_TABLE_RE.finditer(sql_content):
        table_name = match.group(1)
        columns = {}
        constraints = []

        for raw_line in match.group(2).split('\n'):
            line = raw_line.strip().rstrip(',')
            if not line or line.startswith(('--', '#', '/*')):
                continue

            if _CONSTRAINT_RE.match(line):
                constraints.append(line)
                continue

            column_match = _COLUMN_RE.match(line)
            if column_match:
                col_name, attributes = _parse_column(column_match, line)
                columns[col_name] = attributes

        tables[table_name] = {
            'columns': columns,
            'constraints': constraints,
        }

    return tables


def _to_text(value):
    """Return *value* as ``str`` when it is ``bytes``/``bytearray``; else unchanged.

    NOTE(review): some mysql-connector-python versions may hand back bytes for
    SHOW TABLES / DESCRIBE columns; the rest of this script (e.g. the regex
    calls in normalize_type) expects ``str``. ``None`` passes through.
    """
    if isinstance(value, (bytes, bytearray)):
        return value.decode('utf-8', errors='replace')
    return value


def get_current_schema(db_config: Dict) -> Dict:
    """Read table, column and foreign-key metadata from a live MySQL database.

    Args:
        db_config: Keyword arguments for ``mysql.connector.connect``; its
            'database' entry is also used to filter INFORMATION_SCHEMA.

    Returns:
        ``{table_name: {'columns': {...}, 'foreign_keys': [...]}}``. Each column
        maps to 'type' (upper-cased DESCRIBE type), 'nullable', 'default',
        'auto_increment' and 'key'. Each foreign key is a dict with 'column',
        'referenced_table' and 'referenced_column'.

    Raises:
        Whatever ``mysql.connector`` raises on connection or query failure.
        The cursor and connection are closed in every case.
    """
    tables = {}
    conn = mysql.connector.connect(**db_config)
    try:
        cursor = conn.cursor()
        try:
            cursor.execute("SHOW TABLES")
            table_names = [_to_text(row[0]) for row in cursor.fetchall()]

            for table_name in table_names:
                # Identifiers cannot be bound as parameters; escape back-quotes.
                quoted_name = table_name.replace('`', '``')
                cursor.execute(f"DESCRIBE `{quoted_name}`")

                columns = {}
                for row in cursor.fetchall():
                    col_name, col_type, null_flag, key, default, extra = (
                        _to_text(value) for value in row[:6]
                    )
                    columns[col_name] = {
                        'type': col_type.upper(),
                        'nullable': null_flag == 'YES',
                        'default': default,
                        'auto_increment': 'auto_increment' in (extra or '').lower(),
                        'key': key,
                    }

                # Values are bound as parameters rather than interpolated.
                cursor.execute(
                    """
                    SELECT
                        COLUMN_NAME,
                        REFERENCED_TABLE_NAME,
                        REFERENCED_COLUMN_NAME
                    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                    WHERE TABLE_SCHEMA = %s
                    AND TABLE_NAME = %s
                    AND REFERENCED_TABLE_NAME IS NOT NULL
                    """,
                    (db_config['database'], table_name),
                )
                foreign_keys = [
                    {
                        'column': _to_text(row[0]),
                        'referenced_table': _to_text(row[1]),
                        'referenced_column': _to_text(row[2]),
                    }
                    for row in cursor.fetchall()
                ]

                tables[table_name] = {
                    'columns': columns,
                    'foreign_keys': foreign_keys,
                }
        finally:
            cursor.close()
    finally:
        conn.close()

    return tables


def normalize_type(type_str: str) -> str:
    """Return a canonical spelling of a MySQL column type for comparison.

    The result is upper-cased with whitespace collapsed. TINYINT(1) and BOOL
    become BOOLEAN. Size/precision suffixes such as (11), (10,2) or (10, 2)
    are removed. The word INT becomes INTEGER and DATETIME becomes TIMESTAMP.

    Examples:
        'int(11)' -> 'INTEGER', 'bigint(20) unsigned' -> 'BIGINT UNSIGNED',
        'tinyint(1)' -> 'BOOLEAN', 'decimal(10, 2)' -> 'DECIMAL'.
    """
    type_str = ' '.join(type_str.upper().split())

    # Map TINYINT(1) before sizes are stripped; afterwards it would be
    # indistinguishable from any other TINYINT.
    type_str = re.sub(r'\bTINYINT\s*\(\s*1\s*\)', 'BOOLEAN', type_str)
    type_str = re.sub(r'\bBOOL\b', 'BOOLEAN', type_str)

    # Remove size / precision specifications for comparison.
    type_str = re.sub(r'\s*\(\s*\d+\s*(?:,\s*\d+\s*)?\)', '', type_str)

    # Whole-word replacements: a plain substring replace would turn BIGINT
    # into BIGINTEGER and INTEGER into INTEGERGER.
    type_str = re.sub(r'\bINT\b', 'INTEGER', type_str)
    type_str = re.sub(r'\bDATETIME\b', 'TIMESTAMP', type_str)

    return ' '.join(type_str.split())


def compare_schemas(current: Dict, target: Dict) -> Dict:
    """Diff the live schema against the target schema.

    Both arguments map table name -> {'columns': {column name -> attrs}}, where
    attrs carry a 'type' and, optionally, a 'nullable' entry.

    Returns a dict with these keys, in this order:
      * 'missing_tables' / 'extra_tables': sorted table names present only in
        *target* / only in *current*.
      * 'missing_columns' / 'extra_columns': defaultdict(list) keyed by table,
        holding sorted column names present only in target / only in current.
      * 'type_differences': per table, dicts with 'column', 'current_type' and
        'target_type' for shared columns whose normalize_type() results differ.
      * 'nullable_differences': per table, dicts with 'column',
        'current_nullable' and 'target_nullable'.
      * 'summary': counts of missing/extra tables and of tables appearing in
        each per-table section.
    """
    current_names = set(current)
    target_names = set(target)

    missing_columns = defaultdict(list)
    extra_columns = defaultdict(list)
    type_differences = defaultdict(list)
    nullable_differences = defaultdict(list)

    for table in sorted(current_names & target_names):
        live = current[table]['columns']
        wanted = target[table]['columns']
        live_cols = set(live)
        wanted_cols = set(wanted)

        only_in_target = sorted(wanted_cols - live_cols)
        if only_in_target:
            missing_columns[table] = only_in_target

        only_in_current = sorted(live_cols - wanted_cols)
        if only_in_current:
            extra_columns[table] = only_in_current

        for column in sorted(live_cols & wanted_cols):
            have = live[column]
            want = wanted[column]

            if normalize_type(have['type']) != normalize_type(want['type']):
                type_differences[table].append({
                    'column': column,
                    'current_type': have['type'],
                    'target_type': want['type'],
                })

            have_null = have.get('nullable')
            want_null = want.get('nullable')
            if have_null != want_null:
                nullable_differences[table].append({
                    'column': column,
                    'current_nullable': have_null,
                    'target_nullable': want_null,
                })

    missing_tables = sorted(target_names - current_names)
    extra_tables = sorted(current_names - target_names)

    return {
        'missing_tables': missing_tables,
        'extra_tables': extra_tables,
        'missing_columns': missing_columns,
        'extra_columns': extra_columns,
        'type_differences': type_differences,
        'nullable_differences': nullable_differences,
        'summary': {
            'total_missing_tables': len(missing_tables),
            'total_extra_tables': len(extra_tables),
            'total_tables_with_missing_columns': len(missing_columns),
            'total_tables_with_extra_columns': len(extra_columns),
            'total_tables_with_type_differences': len(type_differences),
            'total_tables_with_nullable_differences': len(nullable_differences),
        },
    }


_REPORT_RULE = "=" * 80


def _report_banner(title: str) -> List[str]:
    """Return a section heading framed by 80-character rules plus a blank line."""
    return [_REPORT_RULE, title, _REPORT_RULE, ""]


def _table_list_section(title: str, intro: List[str], tables: List[str]) -> List[str]:
    """Render a heading, explanatory lines and a 1-based numbered table list."""
    lines = _report_banner(title)
    lines.extend(intro)
    lines.append("")
    lines.extend(f"{i}. {table}" for i, table in enumerate(tables, 1))
    lines.append("")
    return lines


def _per_table_section(title: str, label: str, groups: Dict, render_item) -> List[str]:
    """Render ``{table: [item, ...]}`` as per-table blocks; *render_item* yields lines."""
    lines = _report_banner(title)
    for table_name in sorted(groups):
        items = groups[table_name]
        lines.append(f"Table: {table_name}")
        lines.append(f"{label} ({len(items)}):")
        for item in items:
            lines.extend(render_item(item))
        lines.append("")
    return lines


def _render_column(col) -> List[str]:
    """One bullet line for a missing/extra column name."""
    return [f"  - {col}"]


def _render_type_diff(diff: Dict) -> List[str]:
    """Lines describing one entry of comparison['type_differences']."""
    return [
        f"  Column: {diff['column']}",
        f"    Current: {diff['current_type']}",
        f"    Target:  {diff['target_type']}",
    ]


def _render_nullable_diff(diff: Dict) -> List[str]:
    """Lines describing one entry of comparison['nullable_differences']."""
    current_null = "NULL" if diff['current_nullable'] else "NOT NULL"
    target_null = "NULL" if diff['target_nullable'] else "NOT NULL"
    return [
        f"  Column: {diff['column']}",
        f"    Current: {current_null}",
        f"    Target:  {target_null}",
    ]


def generate_report(comparison: Dict, output_file: str, *,
                    current_label: str = "xygbfpsg_graz (Django)",
                    target_label: str = "xygbfpsg_loans.sql (Grazuri)") -> str:
    """Render *comparison* as a plain-text report, write it, and return it.

    Args:
        comparison: The dict produced by compare_schemas (keys
            'missing_tables', 'extra_tables', 'missing_columns',
            'extra_columns', 'type_differences', 'nullable_differences',
            'summary').
        output_file: Path the report is written to (UTF-8, overwritten).
        current_label: Description of the live database shown in the header.
        target_label: Description of the target schema shown in the header.

    Returns:
        The report text exactly as written to *output_file*.
    """
    summary = comparison['summary']

    lines = _report_banner("DATABASE SCHEMA COMPARISON REPORT")
    lines += [
        f"Current Database: {current_label}",
        f"Target Schema: {target_label}",
        "",
    ]

    lines += _report_banner("SUMMARY")
    for label, key in (
        ("Missing Tables in Current Schema", 'total_missing_tables'),
        ("Extra Tables in Current Schema", 'total_extra_tables'),
        ("Tables with Missing Columns", 'total_tables_with_missing_columns'),
        ("Tables with Extra Columns", 'total_tables_with_extra_columns'),
        ("Tables with Type Differences", 'total_tables_with_type_differences'),
        ("Tables with Nullable Differences", 'total_tables_with_nullable_differences'),
    ):
        lines.append(f"{label}: {summary[key]}")
    lines.append("")

    # Detail sections are emitted only when they have content.
    if comparison['missing_tables']:
        lines += _table_list_section(
            "MISSING TABLES IN CURRENT SCHEMA",
            ["The following tables exist in the target Grazuri schema but are missing",
             "in the current Django database:"],
            comparison['missing_tables'],
        )

    if comparison['extra_tables']:
        lines += _table_list_section(
            "EXTRA TABLES IN CURRENT SCHEMA",
            ["The following tables exist in the current Django database but are not",
             "in the target Grazuri schema:"],
            comparison['extra_tables'],
        )

    if comparison['missing_columns']:
        lines += _per_table_section("MISSING COLUMNS IN EXISTING TABLES", "Missing Columns",
                                    comparison['missing_columns'], _render_column)

    if comparison['extra_columns']:
        lines += _per_table_section("EXTRA COLUMNS IN EXISTING TABLES", "Extra Columns",
                                    comparison['extra_columns'], _render_column)

    if comparison['type_differences']:
        lines += _per_table_section("DATA TYPE DIFFERENCES", "Type Differences",
                                    comparison['type_differences'], _render_type_diff)

    if comparison['nullable_differences']:
        lines += _per_table_section("NULLABLE CONSTRAINT DIFFERENCES", "Nullable Differences",
                                    comparison['nullable_differences'], _render_nullable_diff)

    lines += _report_banner("RECOMMENDATIONS")
    lines += [
        "1. Create Django models for all missing tables",
        "2. Add missing columns to existing Django models",
        "3. Generate Django migrations for all schema changes",
        "4. Review data type differences and adjust Django field types",
        "5. Review nullable constraints and update Django field definitions",
        "6. Test migrations in development environment before production",
        "",
    ]

    report_content = '\n'.join(lines)
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(report_content)

    return report_content


def main():
    """Run the schema comparison end to end.

    Parses TARGET_SQL_FILE, reads the live schema described by
    CURRENT_DB_CONFIG and compares the two. It then writes a text report and a
    JSON dump to the hard-coded paths below and prints a summary. If the SQL
    file cannot be read or the database query fails, it prints a message and
    returns early.
    """
    print("Starting database schema comparison...")
    print()

    # Parse target SQL file
    print("Parsing target Grazuri schema from SQL file...")
    try:
        target_schema = parse_sql_file(TARGET_SQL_FILE)
    except OSError as e:
        print(f"Error reading target SQL file: {e}")
        print("Please ensure the SQL file path is correct and readable.")
        return
    print(f"Found {len(target_schema)} tables in target schema")
    if not target_schema:
        # With nothing parsed, every live table would be reported as "extra";
        # warn so a dump-format mismatch is not mistaken for a real difference.
        print("Warning: no CREATE TABLE statements were recognised in the SQL file.")
    print()

    # Get current database schema
    print("Connecting to current Django database...")
    try:
        current_schema = get_current_schema(CURRENT_DB_CONFIG)
    except mysql.connector.Error as e:
        print(f"Error connecting to database: {e}")
        print("Please ensure MySQL is running and credentials are correct.")
        return
    print(f"Found {len(current_schema)} tables in current database")
    print()

    # Compare schemas
    print("Comparing schemas...")
    comparison = compare_schemas(current_schema, target_schema)
    print()

    # Generate report
    output_file = r'c:\Users\Teamjoint company\Desktop\branchsystem\schema_comparison_report.txt'
    print("Generating comparison report...")
    generate_report(comparison, output_file)
    print(f"Report saved to: {output_file}")
    print()

    # Save JSON for programmatic access; default=str is a fallback for any
    # value json cannot encode natively.
    json_file = r'c:\Users\Teamjoint company\Desktop\branchsystem\schema_comparison.json'
    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump(comparison, f, indent=2, default=str)
    print(f"JSON data saved to: {json_file}")
    print()

    # Print summary
    summary = comparison['summary']
    print("=" * 80)
    print("COMPARISON SUMMARY")
    print("=" * 80)
    for label, key in (
        ("Missing Tables", 'total_missing_tables'),
        ("Extra Tables", 'total_extra_tables'),
        ("Tables with Missing Columns", 'total_tables_with_missing_columns'),
        ("Tables with Extra Columns", 'total_tables_with_extra_columns'),
        ("Tables with Type Differences", 'total_tables_with_type_differences'),
        ("Tables with Nullable Differences", 'total_tables_with_nullable_differences'),
    ):
        print(f"{label}: {summary[key]}")
    print()

    print("Schema comparison complete!")


if __name__ == "__main__":
    # Run the comparison only when executed as a script, never on import.
    main()
