"""
Chart Generation Service for PDF Reports
Provides chart generation capabilities for embedding in PDF reports
"""
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
import numpy as np
from typing import Dict, List, Any, Optional
import logging

# Module-level logger used by ChartGenerationService to report chart-generation
# warnings and failures (failures are logged and turned into a None return).
logger = logging.getLogger(__name__)


class ChartGenerationService:
    """
    Service for generating charts that can be embedded in PDF reports.

    Public methods return a PNG image in a ``BytesIO`` buffer rewound to
    offset 0, or ``None`` when the chart cannot be generated (the failure is
    logged, not raised).
    """

    # Bar-chart x labels longer than this many characters are rotated.
    _BAR_LABEL_ROTATE_LEN = 8
    # Line charts with more than this many points get rotated x labels.
    _LINE_LABEL_ROTATE_COUNT = 6
    # A multi-chart figure holds at most 3 rows x 2 columns of panels.
    _MAX_SUBPLOTS = 6

    def __init__(self):
        """Initialize the chart generation service.

        Note: ``plt.style.use`` and ``sns.set_palette`` change matplotlib's
        global rcParams, not just settings for this instance.
        """
        # Set style for professional appearance. matplotlib >= 3.6 names the
        # style 'seaborn-v0_8'; older releases call it 'seaborn'. Styling is
        # cosmetic, so a missing style must not make the service unusable.
        for style_name in ('seaborn-v0_8', 'seaborn'):
            try:
                plt.style.use(style_name)
                break
            except OSError:
                continue
        else:
            logger.warning("No seaborn matplotlib style found; using matplotlib defaults")
        sns.set_palette("husl")

        # Chart configuration: default figure size in inches and output DPI.
        self.figure_size = (10, 6)
        self.dpi = 300

        # Color scheme matching PDF service
        self.colors = {
            'primary': '#2c3e50',
            'secondary': '#3498db',
            'success': '#27ae60',
            'warning': '#f39c12',
            'danger': '#e74c3c',
            'light_gray': '#f8f9fc',
            'dark_gray': '#5a5c69'
        }

    def generate_chart_for_pdf(self, chart_type: str, data: Dict[str, Any],
                               figure_size: Optional[tuple] = None) -> Optional[BytesIO]:
        """
        Generate chart for PDF embedding

        Args:
            chart_type: Type of chart ('bar', 'pie', 'line', 'scatter')
            data: Chart data. 'bar', 'pie' and 'line' read 'labels' and
                'values'; 'scatter' reads 'x_values', 'y_values', 'x_label'
                and 'y_label'. Every type accepts an optional 'title'.
            figure_size: Optional (width, height) in inches; defaults to
                ``self.figure_size``.

        Returns:
            BytesIO: PNG image buffer, or None if the chart type is
            unsupported or generation fails (the error is logged).
        """
        generators = {
            'bar': self._generate_bar_chart,
            'pie': self._generate_pie_chart,
            'line': self._generate_line_chart,
            'scatter': self._generate_scatter_chart,
        }
        try:
            generator = generators.get(chart_type)
            if generator is None:
                logger.warning(f"Unsupported chart type: {chart_type}")
                return None
            return generator(data, figure_size)
        except Exception as e:
            # Best-effort: a broken chart must not abort the whole report.
            logger.exception(f"Error generating {chart_type} chart: {e}")
            return None

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _resolve_figure_size(self, figure_size: Optional[tuple]) -> tuple:
        """Return ``figure_size``, falling back to the instance default when None."""
        return self.figure_size if figure_size is None else figure_size

    def _render_png(self, fig) -> BytesIO:
        """Lay out ``fig`` and return it as a PNG buffer rewound to offset 0.

        The caller remains responsible for closing ``fig``.
        """
        fig.tight_layout()
        buffer = BytesIO()
        fig.savefig(buffer, format='png', dpi=self.dpi, bbox_inches='tight')
        buffer.seek(0)
        return buffer

    @staticmethod
    def _style_axes(ax) -> None:
        """Apply the light grid and open top/right frame used by cartesian charts."""
        ax.grid(True, alpha=0.3)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    @staticmethod
    def _rotate_x_labels(ax) -> None:
        """Tilt the x tick labels of ``ax`` by 45 degrees, right-aligned."""
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    @staticmethod
    def _needs_label_rotation(labels: List[Any], max_len: int) -> bool:
        """Return True if any label's text is longer than ``max_len`` characters.

        Labels are converted with ``str`` so numeric labels (e.g. years) work.
        """
        return any(len(str(label)) > max_len for label in labels)

    @staticmethod
    def _positive_slices(labels: List[Any], values: List[Any]) -> tuple:
        """Return (labels, values) restricted to strictly positive values.

        Pie wedges cannot be negative (matplotlib raises) and zero/None
        entries would only add empty wedges, so they are dropped.
        """
        kept = [(label, value) for label, value in zip(labels, values)
                if value is not None and value > 0]
        return [label for label, _ in kept], [value for _, value in kept]

    def _pie_colors(self, count: int) -> List[Any]:
        """Return ``count`` wedge colours.

        Up to four wedges use the brand colours. With more, matplotlib would
        cycle these four and different wedges would share a colour, so the
        'husl' palette (the one set in ``__init__``) is used instead.
        """
        brand = [self.colors['secondary'], self.colors['success'],
                 self.colors['warning'], self.colors['danger']]
        if count <= len(brand):
            return brand[:count]
        return sns.color_palette("husl", count)

    @staticmethod
    def _subplot_layout(num_charts: int) -> tuple:
        """Return the (rows, cols) grid for ``num_charts`` panels, capped at 3x2."""
        if num_charts <= 1:
            return 1, 1
        if num_charts == 2:
            return 1, 2
        if num_charts <= 4:
            return 2, 2
        return 3, 2

    # ------------------------------------------------------------------
    # Single-chart generators (each closes its figure even on failure)
    # ------------------------------------------------------------------

    def _generate_bar_chart(self, data: Dict[str, Any],
                            figure_size: Optional[tuple] = None) -> BytesIO:
        """Generate a bar chart with each bar labelled with its value."""
        fig, ax = plt.subplots(figsize=self._resolve_figure_size(figure_size))
        try:
            labels = data.get('labels', [])
            values = data.get('values', [])
            title = data.get('title', 'Bar Chart')

            bars = ax.bar(labels, values, color=self.colors['secondary'], alpha=0.8)

            ax.set_title(title, fontsize=16, fontweight='bold', color=self.colors['primary'])
            ax.set_xlabel('Categories', fontsize=12)
            ax.set_ylabel('Values', fontsize=12)

            # Value labels sit just outside the bar end: above positive bars,
            # below negative ones.
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height,
                        f'{height:.1f}',
                        ha='center', va='bottom' if height >= 0 else 'top',
                        fontsize=10)

            if self._needs_label_rotation(labels, self._BAR_LABEL_ROTATE_LEN):
                self._rotate_x_labels(ax)

            self._style_axes(ax)
            return self._render_png(fig)
        finally:
            plt.close(fig)

    def _generate_pie_chart(self, data: Dict[str, Any],
                            figure_size: Optional[tuple] = None) -> BytesIO:
        """Generate a pie chart of the positive entries in data['values'].

        Zero, negative and None values are dropped; if nothing remains, a
        'No data available' placeholder is drawn instead.
        """
        fig, ax = plt.subplots(figsize=self._resolve_figure_size(figure_size))
        try:
            title = data.get('title', 'Pie Chart')
            slice_labels, slice_values = self._positive_slices(
                data.get('labels', []), data.get('values', []))

            if slice_values:
                _, _, autotexts = ax.pie(
                    slice_values, labels=slice_labels, autopct='%1.1f%%',
                    startangle=90, colors=self._pie_colors(len(slice_values)))
                # White bold percentages stay readable on the saturated wedges.
                for autotext in autotexts:
                    autotext.set_color('white')
                    autotext.set_fontweight('bold')
            else:
                ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
                        transform=ax.transAxes, fontsize=14)

            ax.set_title(title, fontsize=16, fontweight='bold', color=self.colors['primary'])
            return self._render_png(fig)
        finally:
            plt.close(fig)

    def _generate_line_chart(self, data: Dict[str, Any],
                             figure_size: Optional[tuple] = None) -> BytesIO:
        """Generate a line chart with each point annotated with its value."""
        fig, ax = plt.subplots(figsize=self._resolve_figure_size(figure_size))
        try:
            labels = data.get('labels', [])
            values = data.get('values', [])
            title = data.get('title', 'Line Chart')

            ax.plot(labels, values, marker='o', linewidth=2, markersize=6,
                    color=self.colors['secondary'])

            ax.set_title(title, fontsize=16, fontweight='bold', color=self.colors['primary'])
            ax.set_xlabel('Time Period', fontsize=12)
            ax.set_ylabel('Values', fontsize=12)

            # Annotate at the label itself rather than its list index, so the
            # text lands on the plotted point for numeric/date labels as well
            # as for categorical string labels.
            for label, value in zip(labels, values):
                ax.annotate(f'{value:.1f}', (label, value), textcoords="offset points",
                            xytext=(0, 10), ha='center', fontsize=9)

            self._style_axes(ax)

            if len(labels) > self._LINE_LABEL_ROTATE_COUNT:
                self._rotate_x_labels(ax)

            return self._render_png(fig)
        finally:
            plt.close(fig)

    def _generate_scatter_chart(self, data: Dict[str, Any],
                                figure_size: Optional[tuple] = None) -> BytesIO:
        """Generate a scatter chart from data['x_values'] and data['y_values']."""
        fig, ax = plt.subplots(figsize=self._resolve_figure_size(figure_size))
        try:
            ax.scatter(data.get('x_values', []), data.get('y_values', []),
                       alpha=0.7, s=60, color=self.colors['secondary'])

            ax.set_title(data.get('title', 'Scatter Chart'), fontsize=16,
                         fontweight='bold', color=self.colors['primary'])
            ax.set_xlabel(data.get('x_label', 'X Values'), fontsize=12)
            ax.set_ylabel(data.get('y_label', 'Y Values'), fontsize=12)

            self._style_axes(ax)
            return self._render_png(fig)
        finally:
            plt.close(fig)

    def generate_dashboard_chart(self, chart_type: str, data: Dict[str, Any],
                                 size: tuple = (8, 5)) -> Optional[BytesIO]:
        """
        Generate chart for dashboard display

        Args:
            chart_type: Type of chart (same values as ``generate_chart_for_pdf``)
            data: Chart data (same keys as ``generate_chart_for_pdf``)
            size: Figure size (width, height) in inches

        Returns:
            BytesIO: PNG image buffer, or None on failure
        """
        # The size is passed per call rather than by temporarily swapping
        # self.figure_size, so the instance's default is never modified.
        return self.generate_chart_for_pdf(chart_type, data, figure_size=size)

    # ------------------------------------------------------------------
    # Multi-chart figures
    # ------------------------------------------------------------------

    def create_multi_chart_figure(self, charts_config: List[Dict[str, Any]]) -> Optional[BytesIO]:
        """
        Create figure with multiple charts

        Args:
            charts_config: List of chart configurations. Each entry is a dict
                with optional 'type' ('bar', 'pie', 'line' or 'scatter';
                default 'bar') and 'data' (default {}).

        Returns:
            BytesIO: PNG buffer of a 15x10 inch grid of up to
            ``_MAX_SUBPLOTS`` panels, or None if ``charts_config`` is empty
            or rendering fails (the error is logged). Charts beyond the limit
            are skipped with a warning.
        """
        try:
            if not charts_config:
                return None

            num_charts = len(charts_config)
            if num_charts > self._MAX_SUBPLOTS:
                logger.warning(
                    f"Multi-chart figure holds at most {self._MAX_SUBPLOTS} charts; "
                    f"skipping {num_charts - self._MAX_SUBPLOTS}")

            rows, cols = self._subplot_layout(num_charts)
            # squeeze=False always yields a 2-D array, so one flatten()
            # handles the 1x1, 1x2 and grid layouts alike.
            fig, axes = plt.subplots(rows, cols, figsize=(15, 10), squeeze=False)
            try:
                flat_axes = axes.flatten()
                # zip stops at the shorter sequence, so charts that do not
                # fit in the grid are dropped.
                for ax, chart_config in zip(flat_axes, charts_config):
                    self._create_subplot_chart(ax, chart_config.get('type', 'bar'),
                                               chart_config.get('data', {}))

                # Hide panels left empty when the grid is larger than needed.
                for ax in flat_axes[num_charts:]:
                    ax.set_visible(False)

                return self._render_png(fig)
            finally:
                plt.close(fig)

        except Exception as e:
            logger.exception(f"Error creating multi-chart figure: {e}")
            return None

    def _create_subplot_chart(self, ax, chart_type: str, data: Dict[str, Any]):
        """Draw a compact chart of ``chart_type`` into the existing axes ``ax``.

        Unsupported types leave the panel empty (apart from title and grid)
        and log a warning.
        """
        labels = data.get('labels', [])
        values = data.get('values', [])
        title = data.get('title', 'Chart')

        if chart_type == 'bar':
            ax.bar(labels, values, color=self.colors['secondary'], alpha=0.8)
            ax.set_ylabel('Values')
        elif chart_type == 'pie':
            # Same filtering as the full-size pie: negative wedges would make
            # matplotlib raise and abort the whole multi-chart figure.
            slice_labels, slice_values = self._positive_slices(labels, values)
            if slice_values:
                ax.pie(slice_values, labels=slice_labels, autopct='%1.1f%%', startangle=90)
            else:
                ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
        elif chart_type == 'line':
            ax.plot(labels, values, marker='o', color=self.colors['secondary'])
            ax.set_ylabel('Values')
        elif chart_type == 'scatter':
            ax.scatter(data.get('x_values', []), data.get('y_values', []),
                       alpha=0.7, color=self.colors['secondary'])
            ax.set_xlabel(data.get('x_label', 'X Values'))
            ax.set_ylabel(data.get('y_label', 'Y Values'))
        else:
            logger.warning(f"Unsupported subplot chart type: {chart_type}")

        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.grid(True, alpha=0.3)