#!/usr/bin/env python3
"""
Comprehensive fix for missing foreign key columns in the database.
This addresses the production errors with missing loan_id and repayment_id columns.
"""

import os
import django
import sys
from django.db import connection, transaction

# Set up Django so the database connection can be used from this standalone
# script. setdefault() only applies when DJANGO_SETTINGS_MODULE is not already
# present in the environment, so an externally supplied value takes precedence.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'branch_system.settings')
django.setup()

# These model imports deliberately come after django.setup(): Django model
# modules cannot be imported until the app registry has been populated.
# NOTE(review): none of these four names is referenced by the functions visible
# in this script; presumably they are imported to confirm the apps load --
# confirm before removing.
from loans.models import Loan, Repayment, MpesaTransaction
from utils.models import Receipt

def check_and_add_missing_columns():
    """Add any required foreign key column that is missing from its table.

    For each table listed below, the current column names are read with
    ``DESCRIBE`` and every required column that is absent is added as a
    nullable ``CHAR(36)`` column. Columns that already exist are left
    untouched, so running the function again does not re-add them.

    The statements use MySQL-style syntax (``DESCRIBE``). Any error raised by
    the database driver (for example, for a table that does not exist)
    propagates to the caller.
    """
    # (table, columns that must exist on it), in the order they are processed.
    required_columns = (
        ('repayments', ('loan_id',)),
        ('mpesa_transactions', ('loan_id', 'repayment_id')),
        ('receipts', ('repayment_id',)),
    )

    with connection.cursor() as cursor:
        print("=== Checking and Adding Missing Columns ===")

        for table, columns in required_columns:
            # DESCRIBE yields one row per column; the first field is its name.
            # The table is described once, before any ALTER on it.
            cursor.execute(f"DESCRIBE {table}")
            existing = {row[0] for row in cursor.fetchall()}

            for column in columns:
                if column in existing:
                    print(f"- {column} column already exists in {table}")
                    continue

                print(f"Adding {column} column to {table} table...")
                # Identifiers come from the constant tuple above, never from
                # external input, so interpolating them into the SQL is safe.
                cursor.execute(
                    f"ALTER TABLE {table} ADD COLUMN {column} CHAR(36) NULL"
                )
                print(f"✓ Added {column} column to {table}")

def populate_foreign_keys():
    """Backfill the foreign key columns from data already in the tables.

    Two backfills are attempted:

    * ``mpesa_transactions.loan_id`` is copied from ``repayments.loan_id`` of
      the repayment each transaction references, for rows where it is NULL.
    * ``receipts.repayment_id`` is set to the id of the repayment that has the
      same ``receipt_number``, for rows where it is NULL.

    The first backfill is skipped with a warning unless every foreign key
    column its UPDATE touches exists; the second is skipped when
    ``receipts.repayment_id`` is missing. ``repayments.loan_id`` itself is not
    populated here; a reminder about manual migration is printed instead.
    """
    with connection.cursor() as cursor:
        print("\n=== Populating Foreign Key Data ===")

        def columns_of(table):
            # DESCRIBE yields one row per column; the first field is its name.
            cursor.execute(f"DESCRIBE {table}")
            return {row[0] for row in cursor.fetchall()}

        mpesa_columns = columns_of('mpesa_transactions')
        repayments_columns = columns_of('repayments')

        # The UPDATE below joins on mt.repayment_id, writes mt.loan_id and
        # reads r.loan_id, so all three must exist -- not just repayment_id.
        # Otherwise the statement fails with an unknown-column error instead
        # of being skipped.
        needed = (
            ('mpesa_transactions.repayment_id', 'repayment_id' in mpesa_columns),
            ('mpesa_transactions.loan_id', 'loan_id' in mpesa_columns),
            ('repayments.loan_id', 'loan_id' in repayments_columns),
        )
        missing = [name for name, present in needed if not present]

        if not missing:
            # For mpesa_transactions, populate loan_id from repayment->loan relationship
            cursor.execute("""
                UPDATE mpesa_transactions mt
                INNER JOIN repayments r ON mt.repayment_id = r.id
                SET mt.loan_id = r.loan_id
                WHERE mt.loan_id IS NULL AND mt.repayment_id IS NOT NULL AND r.loan_id IS NOT NULL
            """)
            print(f"✓ Updated {cursor.rowcount} mpesa_transaction records with loan_id")
        else:
            print(
                "⚠️  Skipping mpesa_transactions loan_id population - "
                f"missing column(s): {', '.join(missing)}"
            )

        # For receipts, populate repayment_id from the repayment relationship
        receipts_columns = columns_of('receipts')

        if 'repayment_id' in receipts_columns:
            cursor.execute("""
                UPDATE receipts rec
                INNER JOIN repayments r ON rec.receipt_number = r.receipt_number
                SET rec.repayment_id = r.id
                WHERE rec.repayment_id IS NULL
            """)
            print(f"✓ Updated {cursor.rowcount} receipt records with repayment_id")
        else:
            print("⚠️  Skipping receipts repayment_id population - column missing")

        # Note about manual data migration for repayments
        print("\nNote: Manual data migration may be required for existing repayments")
        print("You may need to update repayments.loan_id based on business logic")

def add_indexes_and_constraints():
    """Add an index on each foreign key column managed by this script.

    Despite the name, only indexes are created; no FOREIGN KEY constraints
    are added. Index creation is best-effort: a failure on one index is
    reported and the remaining indexes are still attempted. An error whose
    message contains "Duplicate key name" is treated as "index already
    exists" and reported as such rather than as a warning.
    """
    # (table, index name, indexed column), in the order they are created.
    index_specs = (
        ('repayments', 'idx_repayments_loan_id', 'loan_id'),
        ('mpesa_transactions', 'idx_mpesa_loan_id', 'loan_id'),
        ('mpesa_transactions', 'idx_mpesa_repayment_id', 'repayment_id'),
        ('receipts', 'idx_receipts_repayment_id', 'repayment_id'),
    )

    with connection.cursor() as cursor:
        print("\n=== Adding Indexes and Constraints ===")

        for table, index_name, column in index_specs:
            target = f"{table}.{column}"
            try:
                cursor.execute(
                    f"ALTER TABLE {table} ADD INDEX {index_name} ({column})"
                )
            except Exception as e:
                # Deliberately broad: a failed index must not abort the run.
                if "Duplicate key name" in str(e):
                    # Previously swallowed silently; say so, as the column
                    # step does for columns that already exist.
                    print(f"- Index for {target} already exists")
                else:
                    print(f"Warning: Could not add {target} index: {e}")
            else:
                print(f"✓ Added index for {target}")

def verify_schema():
    """Report each affected table's columns and confirm the foreign keys exist.

    Prints the column list of ``repayments``, ``mpesa_transactions`` and
    ``receipts``, then checks that every required ``loan_id`` /
    ``repayment_id`` column is present.

    Returns:
        bool: True when every required column exists, False otherwise.
    """
    # (table, label used in the printed report, columns required on the table)
    expectations = (
        ('repayments', 'Repayments', ('loan_id',)),
        ('mpesa_transactions', 'MpesaTransaction', ('loan_id', 'repayment_id')),
        ('receipts', 'Receipts', ('repayment_id',)),
    )
    absent = []

    with connection.cursor() as cursor:
        print("\n=== Schema Verification ===")

        for table, label, required in expectations:
            # DESCRIBE yields one row per column; the first field is its name.
            cursor.execute(f"DESCRIBE {table}")
            names = [row[0] for row in cursor.fetchall()]
            print(f"{label} columns: {names}")

            # Record anything required but not found, as "table.column".
            absent.extend(
                f"{table}.{column}" for column in required if column not in names
            )

        if absent:
            print(f"\n❌ Still missing columns: {', '.join(absent)}")
            return False

        print("\n✅ All required foreign key columns are present")
        return True

def main():
    """Run the full schema fix and report the outcome.

    Adds the missing columns, backfills them, adds indexes, then verifies the
    resulting schema and prints next steps.

    Returns:
        bool: True when every step ran and verification passed; False when
        verification found missing columns or any step raised an exception.
    """
    import traceback  # local import: only used on the failure path

    try:
        # NOTE(review): the steps below issue MySQL-style DDL (ALTER TABLE),
        # and MySQL implicitly commits on such statements, so this atomic
        # block presumably cannot roll back columns or indexes that were
        # already added if a later step fails. The steps skip columns and
        # indexes that already exist, so re-running is the recovery path.
        with transaction.atomic():
            check_and_add_missing_columns()
            populate_foreign_keys()
            add_indexes_and_constraints()

        # Verify outside of transaction
        schema_ok = verify_schema()
    except Exception as e:
        print(f"\n❌ Error during database fix: {e}")
        # The message alone does not show which step or statement failed;
        # print the traceback so the failure can be located.
        traceback.print_exc()
        return False

    if not schema_ok:
        print("\n⚠️  Schema verification failed. Some columns may still be missing.")
        return False

    print("\n✅ Database schema fix completed successfully!")
    print("\n🔄 Please restart your Django server to apply the changes.")
    print("\n📝 Next steps:")
    print("   1. Restart Django server")
    print("   2. Test repayment functionality")
    print("   3. Verify payments now appear correctly")
    return True

if __name__ == '__main__':
    # Exit with status 1 when main() reports failure so that shells and
    # automation can detect it; a successful run falls through and exits 0.
    if not main():
        sys.exit(1)