#!/usr/bin/env python3
"""
Branch Filtering Production Script
Ensures all pages filter data by selected branch correctly
"""

import os
import sys
import django
import re
import argparse
from pathlib import Path
from collections import defaultdict
import ast
from datetime import datetime

# Setup Django
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'branch_system.settings')
django.setup()

from django.apps import apps
from django.urls import get_resolver
from django.db import models
from django.contrib.auth.models import User
from users.models import CustomUser, Branch
from loans.models import Loan, LoanApplication, Repayment
from utils.models import Notification, Document  # Fixed import - using Document instead of CustomerDocument
from reports.models import CustomerRequest

class BranchFilteringAuditor:
    def __init__(self):
        self.issues = []
        self.fixes = []
        self.urls_found = []
        self.views_checked = 0
        self.verified_views = []  # Add this missing attribute
        self.dry_run = False
        
        # Create logs directory if it doesn't exist
        os.makedirs('logs', exist_ok=True)
        
        # Setup logging
        import logging
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler('logs/branch_filtering_audit.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def log_issue(self, severity, component, message):
        """Log an issue found during audit"""
        issue = {
            'severity': severity,
            'component': component,
            'message': message,
            'timestamp': datetime.now().isoformat()
        }
        self.issues.append(issue)
        self.logger.warning(f"{severity.upper()}: {component} - {message}")

    def log_fix(self, file_path, description):
        """Log a fix applied during audit"""
        fix = {
            'file_path': file_path,
            'description': description,
            'timestamp': datetime.now().isoformat()
        }
        self.fixes.append(fix)
        self.logger.info(f"FIXED: {file_path} - {description}")

    def discover_all_urls(self):
        """Discover all URL patterns in the Django project"""
        print("🔍 Discovering URL patterns...")
        
        resolver = get_resolver()
        
        def extract_urls(url_patterns, prefix=''):
            urls = []
            for pattern in url_patterns:
                if hasattr(pattern, 'url_patterns'):
                    # This is an include() pattern
                    urls.extend(extract_urls(pattern.url_patterns, prefix + str(pattern.pattern)))
                else:
                    # This is a regular URL pattern
                    full_pattern = prefix + str(pattern.pattern)
                    if hasattr(pattern, 'callback'):
                        view_name = f"{pattern.callback.__module__}.{pattern.callback.__name__}"
                        urls.append({
                            'pattern': full_pattern,
                            'view': view_name,
                            'name': getattr(pattern, 'name', None)
                        })
            return urls
        
        self.urls_found = extract_urls(resolver.url_patterns)
        print(f"✅ Found {len(self.urls_found)} URL patterns")
        
        return self.urls_found

    def check_view_file(self, file_path):
        """Check if a view file has proper branch filtering"""
        if not os.path.exists(file_path):
            return {'has_filtering': False, 'issues': ['File not found']}
        
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            return {'has_filtering': False, 'issues': [f'Cannot read file: {str(e)}']}
        
        # Check for branch filtering patterns
        branch_patterns = [
            r'selected_branch_id\s*=\s*request\.session\.get\([\'"]selected_branch_id[\'"]',
            r'\.filter\([^)]*branch[^)]*\)',
            r'branch__id\s*=\s*selected_branch_id',
            r'branch_id\s*=\s*selected_branch_id',
            r'get_user_branch\(',
        ]
        
        has_filtering = any(re.search(pattern, content, re.IGNORECASE) for pattern in branch_patterns)
        
        issues = []
        if not has_filtering:
            issues.append('No branch filtering detected')
        
        # Check for common model queries that should be filtered
        model_patterns = [
            r'\.objects\.all\(\)',
            r'\.objects\.filter\(',
            r'\.objects\.get\(',
        ]
        
        for pattern in model_patterns:
            matches = re.findall(pattern, content)
            if matches and not has_filtering:
                issues.append(f'Found {len(matches)} model queries without branch filtering')
        
        return {
            'has_filtering': has_filtering,
            'issues': issues,
            'content': content
        }

    def fix_view_file(self, file_path):
        """Apply branch filtering fixes to a view file"""
        if self.dry_run:
            print(f"🔍 DRY-RUN: Would fix {file_path}")
            return
            
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            self.log_issue('error', file_path, f'Cannot read file: {str(e)}')
            return
        
        # Parse the file to understand its structure
        try:
            tree = ast.parse(content)
        except SyntaxError as e:
            self.log_issue('error', file_path, f'Syntax error in file: {str(e)}')
            return
        
        # Find function definitions that look like views
        view_functions = []
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                # Check if it's likely a view function
                if (len(node.args.args) >= 1 and 
                    node.args.args[0].arg == 'request'):
                    view_functions.append(node.name)
        
        if not view_functions:
            return
        
        # Apply fixes
        modified_content = content
        fixes_applied = []
        
        # Add branch filtering import if not present
        if 'from users.models import' not in content and 'from django.contrib.auth.decorators import login_required' in content:
            import_line = 'from users.models import CustomUser, Branch\n'
            modified_content = modified_content.replace(
                'from django.contrib.auth.decorators import login_required',
                f'{import_line}from django.contrib.auth.decorators import login_required'
            )
            fixes_applied.append('Added branch model imports')
        
        # Add branch filtering helper function
        helper_function = '''
def get_user_branch(request):
    """Get the selected branch for the current user"""
    selected_branch_id = request.session.get('selected_branch_id')
    if selected_branch_id:
        try:
            return Branch.objects.get(id=selected_branch_id)
        except Branch.DoesNotExist:
            pass
    
    # Fallback to user's default branch
    if hasattr(request.user, 'branch') and request.user.branch:
        return request.user.branch
    
    # Fallback to main branch
    return Branch.objects.filter(is_main=True).first()

'''
        
        if 'def get_user_branch(' not in content:
            # Insert helper function after imports
            import_end = content.rfind('from ')
            if import_end != -1:
                next_line = content.find('\n', import_end)
                if next_line != -1:
                    modified_content = (content[:next_line + 1] + 
                                      helper_function + 
                                      content[next_line + 1:])
                    fixes_applied.append('Added get_user_branch helper function')
        
        # Add branch filtering to view functions
        for func_name in view_functions:
            # Look for common patterns that need branch filtering
            func_pattern = rf'def {func_name}\([^)]*\):[^{{}}]*?return'
            func_match = re.search(func_pattern, modified_content, re.DOTALL)
            
            if func_match:
                func_content = func_match.group(0)
                
                # Check if branch filtering is already present
                if 'selected_branch_id' not in func_content:
                    # Add branch filtering at the beginning of the function
                    func_start = func_match.start()
                    func_def_end = modified_content.find(':', func_start) + 1
                    
                    branch_filter_code = '''
    # Get selected branch for filtering
    selected_branch_id = request.session.get('selected_branch_id')
    if not selected_branch_id and hasattr(request.user, 'branch'):
        selected_branch_id = request.user.branch.id if request.user.branch else None
'''
                    
                    # Insert after function definition
                    modified_content = (modified_content[:func_def_end] + 
                                      branch_filter_code + 
                                      modified_content[func_def_end:])
                    fixes_applied.append(f'Added branch filtering to {func_name}')
        
        # Apply model filtering fixes
        model_fixes = [
            # Fix common model queries
            (r'(\w+)\.objects\.all\(\)', r'\1.objects.filter(branch_id=selected_branch_id) if selected_branch_id else \1.objects.all()'),
            (r'(\w+)\.objects\.filter\(([^)]+)\)', r'\1.objects.filter(\2, branch_id=selected_branch_id) if selected_branch_id else \1.objects.filter(\2)'),
        ]
        
        for pattern, replacement in model_fixes:
            if re.search(pattern, modified_content):
                # Only apply if the model likely has a branch relationship
                modified_content = re.sub(pattern, replacement, modified_content)
                fixes_applied.append(f'Applied model filtering pattern: {pattern}')
        
        # Write the modified content back to file
        if fixes_applied:
            try:
                # Create backup
                backup_path = f"{file_path}.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                with open(backup_path, 'w', encoding='utf-8') as f:
                    f.write(content)
                
                # Write modified content
                with open(file_path, 'w', encoding='utf-8') as f:
                    f.write(modified_content)
                
                self.log_fix(file_path, f"Applied fixes: {', '.join(fixes_applied)}")
                print(f"✅ Fixed {file_path}: {', '.join(fixes_applied)}")
                
            except Exception as e:
                self.log_issue('error', file_path, f'Failed to write fixes: {str(e)}')

    def verify_model_relationships(self):
        """Verify that models have proper branch relationships"""
        print("🔍 Verifying model relationships...")
        
        # Get all models
        all_models = apps.get_models()
        
        models_with_branch = []
        models_without_branch = []
        
        for model in all_models:
            # Skip Django's built-in models
            if model._meta.app_label in ['admin', 'auth', 'contenttypes', 'sessions']:
                continue
            
            # Check if model has branch relationship
            has_branch = False
            for field in model._meta.get_fields():
                if (hasattr(field, 'related_model') and 
                    field.related_model and 
                    field.related_model.__name__ == 'Branch'):
                    has_branch = True
                    break
                elif field.name in ['branch', 'branch_id']:
                    has_branch = True
                    break
            
            if has_branch:
                models_with_branch.append(model.__name__)
            else:
                models_without_branch.append(model.__name__)
        
        print(f"✅ Models with branch relationship: {len(models_with_branch)}")
        print(f"⚠️  Models without branch relationship: {len(models_without_branch)}")
        
        if models_without_branch:
            self.log_issue('warning', 'models', 
                         f'Models without branch relationship: {", ".join(models_without_branch)}')
        
        return {
            'with_branch': models_with_branch,
            'without_branch': models_without_branch
        }

    def audit_all_views(self):
        """Audit all view files for branch filtering"""
        print("🔍 Auditing view files...")
        
        view_dirs = ['users', 'loans', 'reports', 'utils']
        
        for view_dir in view_dirs:
            views_file = os.path.join(view_dir, 'views.py')
            if os.path.exists(views_file):
                print(f"📁 Checking {views_file}...")
                
                result = self.check_view_file(views_file)
                self.views_checked += 1
                
                if result['has_filtering']:
                    print(f"✅ {views_file} has branch filtering")
                    self.verified_views.append(views_file)
                else:
                    print(f"❌ {views_file} missing branch filtering")
                    self.log_issue('high', views_file, 'Missing branch filtering')
                    
                    # Apply fixes
                    self.fix_view_file(views_file)
                
                # Log specific issues
                for issue in result.get('issues', []):
                    self.log_issue('medium', views_file, issue)
        
        print(f"📊 Audited {self.views_checked} view files")

    def create_branch_filtering_middleware(self):
        """Create middleware to ensure branch context is always available"""
        if self.dry_run:
            print("🔍 DRY-RUN: Would create branch filtering middleware")
            return
            
        middleware_content = '''
from django.utils.deprecation import MiddlewareMixin
from users.models import Branch

class BranchFilteringMiddleware(MiddlewareMixin):
    """
    Middleware to ensure branch filtering context is available in all requests
    """
    
    def process_request(self, request):
        if request.user.is_authenticated:
            # Ensure selected_branch_id is in session
            if 'selected_branch_id' not in request.session:
                if hasattr(request.user, 'branch') and request.user.branch:
                    request.session['selected_branch_id'] = request.user.branch.id
                else:
                    # Set to main branch as fallback
                    main_branch = Branch.objects.filter(is_main=True).first()
                    if main_branch:
                        request.session['selected_branch_id'] = main_branch.id
            
            # Make branch available in request
            selected_branch_id = request.session.get('selected_branch_id')
            if selected_branch_id:
                try:
                    request.selected_branch = Branch.objects.get(id=selected_branch_id)
                except Branch.DoesNotExist:
                    request.selected_branch = None
            else:
                request.selected_branch = None
        
        return None
    
    def process_template_response(self, request, response):
        """Add branch context to template responses"""
        if hasattr(response, 'context_data') and response.context_data is not None:
            if hasattr(request, 'selected_branch'):
                response.context_data['selected_branch'] = request.selected_branch
                response.context_data['selected_branch_id'] = request.session.get('selected_branch_id')
        
        return response
'''
        
        middleware_file = 'utils/branch_middleware.py'
        
        try:
            with open(middleware_file, 'w', encoding='utf-8') as f:
                f.write(middleware_content)
            
            self.log_fix(middleware_file, 'Created branch filtering middleware')
            print(f"✅ Created {middleware_file}")
            
            print("\n📝 To activate this middleware, add to your settings.py:")
            print("MIDDLEWARE = [")
            print("    # ... other middleware ...")
            print("    'utils.branch_middleware.BranchFilteringMiddleware',")
            print("]")
            
        except Exception as e:
            self.log_issue('error', middleware_file, f'Failed to create middleware: {str(e)}')

    def create_template_context_processor(self):
        """Create context processor for branch information in templates"""
        if self.dry_run:
            print("🔍 DRY-RUN: Would create template context processor")
            return
            
        context_processor_content = '''
from users.models import Branch

def branch_context(request):
    """
    Context processor to make branch information available in all templates
    """
    context = {
        'selected_branch': None,
        'selected_branch_id': None,
        'available_branches': [],
        'is_main_branch': False,
    }
    
    if request.user.is_authenticated:
        # Get selected branch
        selected_branch_id = request.session.get('selected_branch_id')
        if selected_branch_id:
            try:
                selected_branch = Branch.objects.get(id=selected_branch_id)
                context.update({
                    'selected_branch': selected_branch,
                    'selected_branch_id': selected_branch_id,
                    'is_main_branch': selected_branch.is_main,
                })
            except Branch.DoesNotExist:
                pass
        
        # Get available branches for user
        if hasattr(request.user, 'accessible_branches'):
            context['available_branches'] = request.user.accessible_branches.all()
        elif hasattr(request.user, 'branch') and request.user.branch:
            context['available_branches'] = [request.user.branch]
    
    return context
'''
        
        context_file = 'utils/context_processors.py'
        
        try:
            # Check if file exists and append or create
            if os.path.exists(context_file):
                with open(context_file, 'r', encoding='utf-8') as f:
                    existing_content = f.read()
                
                if 'def branch_context(' not in existing_content:
                    with open(context_file, 'a', encoding='utf-8') as f:
                        f.write('\n' + context_processor_content)
                    self.log_fix(context_file, 'Added branch context processor')
                else:
                    print(f"ℹ️  Branch context processor already exists in {context_file}")
            else:
                with open(context_file, 'w', encoding='utf-8') as f:
                    f.write(context_processor_content)
                self.log_fix(context_file, 'Created branch context processor')
            
            print(f"✅ Updated {context_file}")
            
            print("\n📝 To activate this context processor, add to your settings.py:")
            print("TEMPLATES = [")
            print("    {")
            print("        'OPTIONS': {")
            print("            'context_processors': [")
            print("                # ... other context processors ...")
            print("                'utils.context_processors.branch_context',")
            print("            ],")
            print("        },")
            print("    },")
            print("]")
            
        except Exception as e:
            self.log_issue('error', context_file, f'Failed to create context processor: {str(e)}')

    def run_comprehensive_test(self):
        """Run comprehensive tests to verify branch filtering works"""
        print("🧪 Running comprehensive tests...")
        
        test_results = {
            'passed': 0,
            'failed': 0,
            'tests': []
        }
        
        # Test 1: Check if Branch model exists and has required fields
        try:
            from users.models import Branch
            branch_fields = [f.name for f in Branch._meta.get_fields()]
            required_fields = ['name', 'code', 'is_main']
            
            missing_fields = [f for f in required_fields if f not in branch_fields]
            if missing_fields:
                test_results['tests'].append({
                    'name': 'Branch Model Fields',
                    'status': 'failed',
                    'message': f'Missing fields: {missing_fields}'
                })
                test_results['failed'] += 1
            else:
                test_results['tests'].append({
                    'name': 'Branch Model Fields',
                    'status': 'passed',
                    'message': 'All required fields present'
                })
                test_results['passed'] += 1
        except Exception as e:
            test_results['tests'].append({
                'name': 'Branch Model Fields',
                'status': 'failed',
                'message': f'Error: {str(e)}'
            })
            test_results['failed'] += 1
        
        # Test 2: Check if main branch exists
        try:
            main_branch = Branch.objects.filter(is_main=True).first()
            if main_branch:
                test_results['tests'].append({
                    'name': 'Main Branch Exists',
                    'status': 'passed',
                    'message': f'Main branch: {main_branch.name}'
                })
                test_results['passed'] += 1
            else:
                test_results['tests'].append({
                    'name': 'Main Branch Exists',
                    'status': 'failed',
                    'message': 'No main branch found'
                })
                test_results['failed'] += 1
        except Exception as e:
            test_results['tests'].append({
                'name': 'Main Branch Exists',
                'status': 'failed',
                'message': f'Error: {str(e)}'
            })
            test_results['failed'] += 1
        
        # Test 3: Check if users have branch relationships
        try:
            users_with_branch = CustomUser.objects.exclude(branch__isnull=True).count()
            total_users = CustomUser.objects.count()
            
            if users_with_branch > 0:
                test_results['tests'].append({
                    'name': 'User Branch Relationships',
                    'status': 'passed',
                    'message': f'{users_with_branch}/{total_users} users have branch assignments'
                })
                test_results['passed'] += 1
            else:
                test_results['tests'].append({
                    'name': 'User Branch Relationships',
                    'status': 'failed',
                    'message': 'No users have branch assignments'
                })
                test_results['failed'] += 1
        except Exception as e:
            test_results['tests'].append({
                'name': 'User Branch Relationships',
                'status': 'failed',
                'message': f'Error: {str(e)}'
            })
            test_results['failed'] += 1
        
        # Print test results
        print(f"\n📊 Test Results: {test_results['passed']} passed, {test_results['failed']} failed")
        for test in test_results['tests']:
            status_icon = "✅" if test['status'] == 'passed' else "❌"
            print(f"{status_icon} {test['name']}: {test['message']}")
        
        return test_results

    def generate_report(self):
        """Generate comprehensive audit report"""
        print("\n📋 Generating audit report...")
        
        report_content = f"""
# Branch Filtering Audit Report
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Summary
- URLs discovered: {len(self.urls_found)}
- View files checked: {self.views_checked}
- View files verified: {len(self.verified_views)}
- Issues found: {len(self.issues)}
- Fixes applied: {len(self.fixes)}

## Issues Found
"""
        
        for issue in self.issues:
            report_content += f"- **{issue['severity'].upper()}** [{issue['component']}]: {issue['message']}\n"
        
        report_content += "\n## Fixes Applied\n"
        for fix in self.fixes:
            report_content += f"- {fix['file_path']}: {fix['description']}\n"
        
        report_content += f"""
## URLs Discovered
Total: {len(self.urls_found)}

"""
        for url in self.urls_found[:10]:  # Show first 10
            report_content += f"- {url['pattern']} -> {url['view']}\n"
        
        if len(self.urls_found) > 10:
            report_content += f"... and {len(self.urls_found) - 10} more\n"
        
        # Write report to file
        report_file = f'logs/branch_filtering_audit_report_{datetime.now().strftime("%Y%m%d_%H%M%S")}.md'
        
        try:
            with open(report_file, 'w', encoding='utf-8') as f:
                f.write(report_content)
            
            print(f"✅ Report saved to: {report_file}")
            
            # Also write summary to console
            print("\n" + "="*60)
            print("AUDIT SUMMARY")
            print("="*60)
            print(f"  - URLs discovered: {len(self.urls_found)}")
            print(f"  - View files checked: {self.views_checked}")
            print(f"  - View files verified: {len(self.verified_views)}")
            print(f"  - Issues found: {len(self.issues)}")
            print(f"  - Fixes applied: {len(self.fixes)}")
            print("="*60)
            
        except Exception as e:
            self.log_issue('error', 'report', f'Failed to write report: {str(e)}')

def main():
    """Command-line entry point: confirm with the user, then run every audit step.

    Flags:
        --auto     skip the interactive confirmation prompt.
        --dry-run  report what would change without writing any file.

    Returns:
        Process exit status:

        * 0 on success or when the user declines.
        * 2 when confirmation is required but stdin is unavailable (a usage
          error, as with argparse).
        * 130 on Ctrl-C.
        * 1 when any audit step raises.
    """
    parser = argparse.ArgumentParser(description='Branch Filtering Production Audit')
    parser.add_argument('--auto', action='store_true',
                        help='Run automatically without user confirmation')
    parser.add_argument('--dry-run', action='store_true',
                        help='Show what would be done without making changes')

    args = parser.parse_args()

    print("🔍 Starting Branch Filtering Production Audit...")
    print("This script will:")
    print("  1. Discover all URL patterns")
    print("  2. Audit view files for branch filtering")
    print("  3. Fix missing branch filtering")
    print("  4. Verify model relationships")
    print("  5. Create supporting middleware and context processors")
    print("  6. Run comprehensive tests")
    print("  7. Generate detailed report")

    if args.auto:
        print("\n✅ Running in automatic mode...")
    else:
        try:
            # strip() so stray whitespace around "y" still counts as consent.
            proceed = input("\nProceed with audit? (y/n): ").strip().lower()
        except EOFError:
            print("\nNo input available. Use --auto flag for automatic execution.")
            print("Usage: python branch_filtering_production_script.py --auto")
            return 2
        except KeyboardInterrupt:
            print("\nAudit cancelled.")
            return 130
        if proceed != 'y':
            print("Audit cancelled.")
            return 0

    auditor = BranchFilteringAuditor()

    # Set dry-run mode if specified
    if args.dry_run:
        auditor.dry_run = True
        print("🔍 Running in DRY-RUN mode - no changes will be made")

    try:
        # Step 1: Discover URLs
        auditor.discover_all_urls()

        # Step 2: Verify model relationships
        auditor.verify_model_relationships()

        # Step 3: Audit view files
        auditor.audit_all_views()

        # Step 4: Create supporting files. These are always called: in dry-run
        # mode they only print what they would create.
        auditor.create_branch_filtering_middleware()
        auditor.create_template_context_processor()

        # Step 5: Run tests
        auditor.run_comprehensive_test()

        # Step 6: Generate report
        auditor.generate_report()

        print("\n🎉 Branch filtering audit completed successfully!")
        return 0

    except KeyboardInterrupt:
        print("\n\n⚠️  Audit interrupted by user")
        return 130
    except Exception as e:
        print(f"\n❌ Audit failed with error: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == '__main__':
    # Propagate main()'s status so CI / shell callers can detect failures.
    sys.exit(main())