﻿"""
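Sync the local database schema with the production database.

Reads the CREATE TABLE statements from a production SQL dump, creates any
tables that are missing locally, and adds missing columns to tables that
already exist. Existing tables and columns are never dropped or modified.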
"""
import os
import django
import re

# Django must be configured before anything touches django.db (the
# `connection` import below). setdefault() only applies this settings module
# when DJANGO_SETTINGS_MODULE is not already set in the environment.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'branch_system.settings')
django.setup()

from django.db import connection

def execute_sql(sql, description=""):
    """Run a single SQL statement and report the outcome on stdout.

    Prints "✓ <description>" when the statement runs, or
    "✗ <description>: <error>" when anything raises. Errors are caught so
    the caller can carry on with the next statement.

    Returns:
        True if the statement executed without raising, otherwise False.
    """
    ok = True
    try:
        with connection.cursor() as cur:
            cur.execute(sql)
        print(f"✓ {description}")
    except Exception as exc:
        ok = False
        print(f"✗ {description}: {exc!s}")
    return ok

def parse_create_table_statements(filepath):
    """Extract ``CREATE TABLE`` statements from a SQL dump file.

    Args:
        filepath: Path to the SQL file, read as UTF-8.

    Returns:
        Dict mapping table name -> the complete ``CREATE TABLE ... ;``
        statement text. If a table is defined more than once, the last
        definition wins.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    # Accept any storage engine. Requiring ENGINE=InnoDB let the lazy ``.*?``
    # run past a non-InnoDB table (e.g. ENGINE=MyISAM) into the next InnoDB
    # table, so the first table's entry absorbed the second's definition and
    # the second table was lost entirely.
    pattern = re.compile(r'CREATE TABLE `([^`]+)` \((.*?)\) ENGINE=\w+[^;]*;', re.DOTALL)
    return {match.group(1): match.group(0) for match in pattern.finditer(content)}

def get_current_tables():
    """Return the names of all tables in the local database as a set.

    Uses ``SHOW TABLES`` and takes the first field of each result row.
    """
    with connection.cursor() as cur:
        cur.execute("SHOW TABLES")
        rows = cur.fetchall()
        names = {row[0] for row in rows}
    return names

def get_table_columns(table_name):
    """Return the set of column names of *table_name* in the local database.

    Runs ``DESCRIBE `table_name``` and collects the first field of each
    result row. If the query fails, a warning is printed and an empty set is
    returned so the caller can continue with other tables.
    """
    try:
        with connection.cursor() as cursor:
            cursor.execute(f"DESCRIBE `{table_name}`")
            return {row[0] for row in cursor.fetchall()}
    except Exception as exc:
        # Best-effort like execute_sql(), but not a bare ``except:`` (which
        # would also swallow KeyboardInterrupt/SystemExit) and not silent:
        # an empty result makes every production column look "missing".
        print(f"✗ Could not describe {table_name}: {exc}")
        return set()

def extract_columns_from_create(create_statement):
    """Return the column-definition lines of a ``CREATE TABLE`` statement.

    Only the text between ``CREATE TABLE `name` (`` and ``) ENGINE=`` is
    examined. Each line in that span is stripped. Lines that start with a
    backtick are column definitions and are kept. Index and constraint lines
    (``PRIMARY KEY``, ``UNIQUE KEY``, ``KEY``, ``CONSTRAINT``) start with their
    keyword instead, so they are skipped. A trailing comma is removed from
    each kept line.

    Returns:
        A list such as ``["`id` int(11) NOT NULL", ...]``; empty when the
        statement does not match the expected layout.
    """
    match = re.search(r'CREATE TABLE `[^`]+` \((.*?)\) ENGINE=', create_statement, re.DOTALL)
    if not match:
        return []

    # Classify by the line's first character only. Searching the whole line
    # for KEY/CONSTRAINT also dropped real columns whose definition merely
    # mentions those words (e.g. COMMENT 'API key for ...').
    stripped_lines = (raw.strip() for raw in match.group(1).split('\n'))
    return [line.rstrip(',') for line in stripped_lines if line.startswith('`')]

def _create_missing_tables(production_tables, local_tables):
    """Create every production table that is absent locally; return (created, failed)."""
    missing_tables = set(production_tables) - local_tables

    print("=" * 80)
    print(f"CREATING MISSING TABLES: {len(missing_tables)}")
    print("=" * 80)

    created_count = 0
    failed_count = 0
    # Tables are created alphabetically, not in dependency order, so an inline
    # FOREIGN KEY may reference a table that does not exist yet. Disable FK
    # checks for this session while creating, and always restore them.
    execute_sql("SET FOREIGN_KEY_CHECKS=0", "Disabled foreign key checks")
    try:
        for table_name in sorted(missing_tables):
            print(f"\nCreating table: {table_name}")
            if execute_sql(production_tables[table_name], f"Created {table_name}"):
                created_count += 1
            else:
                failed_count += 1
    finally:
        execute_sql("SET FOREIGN_KEY_CHECKS=1", "Re-enabled foreign key checks")

    print(f"\n✓ Created {created_count} tables")
    return created_count, failed_count


def _add_missing_columns(production_tables, local_tables):
    """Add production columns missing from tables present on both sides; return (added, failed)."""
    print("\n" + "=" * 80)
    print("CHECKING FOR MISSING COLUMNS IN EXISTING TABLES")
    print("=" * 80)

    # Splits "`name` rest-of-definition" into the column name and its definition.
    column_pattern = re.compile(r'`([^`]+)`\s+(.+)')
    common_tables = local_tables & set(production_tables)
    columns_added = 0
    failed_count = 0

    for table_name in sorted(common_tables):
        local_cols = get_table_columns(table_name)

        # Column name -> everything after the name (type, NULL-ability, default, ...).
        prod_cols = {}
        for col_def in extract_columns_from_create(production_tables[table_name]):
            match = column_pattern.match(col_def)
            if match:
                prod_cols[match.group(1)] = match.group(2)

        missing_cols = set(prod_cols) - local_cols
        if not missing_cols:
            continue

        print(f"\n📋 Table: {table_name} - Missing {len(missing_cols)} columns")
        for col_name in sorted(missing_cols):
            alter_sql = f"ALTER TABLE `{table_name}` ADD COLUMN `{col_name}` {prod_cols[col_name]}"
            if execute_sql(alter_sql, f"  Added column: {col_name}"):
                columns_added += 1
            else:
                failed_count += 1

    return columns_added, failed_count


def main(sql_file=r"C:\Users\PC\Desktop\branch-system\users\xygbfpsg_graz (1).sql"):
    """Bring the local database schema up to the production dump's schema.

    Creates tables that exist only in production, then adds columns that
    exist in production but not locally. Nothing is dropped or modified.

    Args:
        sql_file: Path to the production SQL dump. Defaults to the location
            this script has always used.
    """
    print("=" * 80)
    print("SYNCING LOCAL DATABASE TO PRODUCTION SCHEMA")
    print("=" * 80)
    print()

    print("Parsing production SQL file...")
    production_tables = parse_create_table_statements(sql_file)
    print(f"Found {len(production_tables)} tables in production\n")

    print("Getting current local tables...")
    local_tables = get_current_tables()
    print(f"Found {len(local_tables)} tables in local database\n")

    created_count, tables_failed = _create_missing_tables(production_tables, local_tables)
    # Uses the pre-creation snapshot on purpose: freshly created tables come
    # straight from the production definition and need no column check.
    columns_added, columns_failed = _add_missing_columns(production_tables, local_tables)

    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"✓ Tables created: {created_count}")
    print(f"✓ Columns added: {columns_added}")
    if tables_failed or columns_failed:
        # Don't claim success when some statements failed.
        print(f"✗ Tables failed: {tables_failed}")
        print(f"✗ Columns failed: {columns_failed}")
        print("\n⚠️  Sync incomplete - review the ✗ lines above before continuing.")
    else:
        print("\n✅ Local database synced with production schema!")
    print("\nNext steps:")
    print("1. Run: python manage.py makemigrations")
    print("2. Run: python manage.py migrate --fake")
    print("3. Restart your development server")
    print("=" * 80)

if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        import sys
        import traceback
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        # Exit non-zero so the shell or CI can detect that the sync failed.
        sys.exit(1)
