# Copyright 2025 DataRobot, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import yaml
from typing import Any, Dict, Generator, Optional, Union
from urllib.parse import urljoin, urlparse

from crewai import LLM, Agent, Crew, Task, Process
from crewai_event_listener import CrewAIEventListener
from openai.types.chat import CompletionCreateParams
from ragas import MultiTurnSample
from ragas.messages import AIMessage, HumanMessage, ToolMessage
from crewai.tools import tool
import requests
from datetime import datetime, timedelta


# =============================================================================
# CONFIGURATION: Texas city locations and ERCOT hub mappings
# =============================================================================

# Coordinates for each location key; these are passed as the latitude/longitude
# of the weather lookup. Insertion order matches the table below.
# NOTE(review): the "west_texas" point looks like El Paso — confirm it is the
# intended representative location for that key.
TEXAS_LOCATIONS = {
    city: {"lat": latitude, "lon": longitude}
    for city, latitude, longitude in (
        ("austin", 30.2672, -97.7431),
        ("dallas", 32.7767, -96.7970),
        ("houston", 29.7604, -95.3698),
        ("san_antonio", 29.4241, -98.4936),
        ("west_texas", 31.7619, -106.4850),
        ("corpus_christi", 27.8006, -97.3964),
        ("midland", 31.9973, -102.0779),
    )
}

# Hub name -> location keys (from TEXAS_LOCATIONS) whose weather is fetched and
# averaged for that hub. A location may belong to more than one hub.
HUB_LOCATIONS = {
    hub: list(cities)
    for hub, cities in (
        ("HB_HOUSTON", ("houston", "corpus_christi")),
        ("HB_NORTH", ("dallas", "austin")),
        ("HB_SOUTH", ("san_antonio", "corpus_christi")),
        ("HB_WEST", ("west_texas", "midland")),
        ("HB_BUSAVG", ("houston", "dallas", "austin", "san_antonio")),
    )
}


# =============================================================================
# TOOL 1: Weather Data Tool
# =============================================================================

@tool
def fetch_texas_weather(timestamp: str, hub_name: str = "HB_HOUSTON") -> dict[str, Any]:
    """
    Fetch weather data for Texas locations using Open-Meteo API (free, no key required).

    Looks up the locations mapped to the given ERCOT hub and returns the
    temperature, wind speed and cloud cover reported for the hour of `timestamp`.

    Args:
        timestamp: ISO format timestamp (a trailing "Z" is accepted). Timestamps
            that carry a UTC offset are converted to UTC before the lookup;
            timestamps without an offset are treated as America/Chicago local time.
        hub_name: ERCOT hub name (HB_HOUSTON, HB_NORTH, HB_SOUTH, HB_WEST, HB_BUSAVG).
            Unknown hub names fall back to the Houston location.

    Returns:
        Dictionary with "timestamp", "hub_name", per-location readings under
        "locations" and averages under "summary". An "errors" entry lists
        locations whose data could not be retrieved. If the call fails as a
        whole, a dictionary with an "error" message is returned instead.
    """
    from datetime import timezone

    try:
        # Parse timestamp
        if timestamp.endswith('Z'):
            timestamp = timestamp[:-1] + '+00:00'
        dt = datetime.fromisoformat(timestamp)

        # The hourly arrays are indexed by hour-of-day in the timezone requested
        # from the API, so the date/hour we look up must be expressed in that
        # same timezone. Offset-aware timestamps are normalized to UTC and the
        # API is queried in GMT; naive timestamps keep the Central-time lookup.
        if dt.tzinfo is not None:
            dt = dt.astimezone(timezone.utc)
            api_timezone = "GMT"
        else:
            api_timezone = "America/Chicago"
        target_date = dt.strftime("%Y-%m-%d")
        target_hour = dt.hour

        # Get locations for this hub
        locations = HUB_LOCATIONS.get(hub_name, ["houston"])

        weather_data = {
            "timestamp": timestamp,
            "hub_name": hub_name,
            "locations": {},
            "summary": {}
        }

        temps, winds, clouds = [], [], []
        failures = {}

        # Fetch weather for each location
        url = "https://archive-api.open-meteo.com/v1/archive"
        for location in locations:
            coords = TEXAS_LOCATIONS.get(location)
            if not coords:
                continue

            params = {
                "latitude": coords["lat"],
                "longitude": coords["lon"],
                "start_date": target_date,
                "end_date": target_date,
                "hourly": "temperature_2m,wind_speed_10m,cloud_cover",
                "timezone": api_timezone,
                "temperature_unit": "fahrenheit",
                "wind_speed_unit": "mph"
            }

            response = requests.get(url, params=params, timeout=10)
            if response.status_code != 200:
                # Record the failure instead of silently dropping the location.
                failures[location] = f"HTTP {response.status_code}"
                continue

            hourly = response.json().get("hourly", {})
            if target_hour >= len(hourly.get("time", [])):
                failures[location] = "no hourly data returned for the requested hour"
                continue

            temp_f = hourly["temperature_2m"][target_hour]
            wind_mph = hourly["wind_speed_10m"][target_hour]
            cloud = hourly["cloud_cover"][target_hour]

            weather_data["locations"][location] = {
                "temperature_f": temp_f,
                "wind_speed_mph": wind_mph,
                "cloud_cover": cloud
            }

            # A reading can be None (JSON null) when the API has no value for
            # that hour; keep such readings out of the aggregates so that
            # sum()/max() do not raise and discard every other location.
            if temp_f is not None:
                temps.append(temp_f)
            if wind_mph is not None:
                winds.append(wind_mph)
            if cloud is not None:
                clouds.append(cloud)

        # Calculate summary statistics from whichever readings are available
        summary = {}
        if temps:
            summary["avg_temperature_f"] = sum(temps) / len(temps)
            summary["max_temperature_f"] = max(temps)
        if winds:
            summary["avg_wind_speed_mph"] = sum(winds) / len(winds)
        if clouds:
            summary["avg_cloud_cover_percent"] = sum(clouds) / len(clouds)
        weather_data["summary"] = summary

        if failures:
            weather_data["errors"] = failures

        return weather_data

    except Exception as e:
        # Tool boundary: report the failure to the agent rather than raising.
        return {"error": str(e), "timestamp": timestamp, "hub_name": hub_name}


# =============================================================================
# TOOL 2: News Data Tool (Tavily)
# =============================================================================

@tool
def fetch_energy_news(query: str = "ERCOT Texas electricity", target_date: Optional[str] = None) -> dict[str, Any]:
    """
    Fetch energy-related news articles using Tavily API, filtered by date range.

    The Tavily API key is read from the TAVILY_API_KEY environment variable.

    Args:
        query: Search query for news articles
        target_date: ISO format date/timestamp to search around (e.g., "2025-10-22T15:00:00").
            Values without a UTC offset are treated as UTC. When omitted, recent
            news is searched instead.

    Returns:
        Dictionary containing news articles from around the target date. On
        failure the dictionary has "error" and "error_type" entries and an
        empty "articles" list.
    """
    from datetime import timezone
    # Import here to avoid module-level import errors if package not installed
    from tavily import TavilyClient

    # The key must come from the environment; credentials are never embedded
    # in source as a fallback.
    api_key = os.getenv("TAVILY_API_KEY")

    # Debug logging
    print(f"[fetch_energy_news] API key present: {bool(api_key)}, length: {len(api_key) if api_key else 0}", flush=True)
    print(f"[fetch_energy_news] Query: {query}, Target date: {target_date}", flush=True)

    if not api_key:
        print("[fetch_energy_news] ERROR: TAVILY_API_KEY environment variable is not set", flush=True)
        return {
            "error": "TAVILY_API_KEY environment variable is not set",
            "error_type": "MissingAPIKey",
            "query": query,
            "articles": []
        }

    try:
        # Initialize Tavily client
        tavily_client = TavilyClient(api_key=api_key)
        print(f"[fetch_energy_news] Tavily client initialized", flush=True)

        # Parse target date and set date range (7 days before to 1 day after for better coverage)
        if target_date:
            # Handle ISO format with time (e.g., "2025-10-22T15:00:00+00:00")
            if 'T' in target_date:
                target_dt = datetime.fromisoformat(target_date.replace('Z', '+00:00'))
            else:
                target_dt = datetime.strptime(target_date, "%Y-%m-%d")

            # Date-only values and timestamps without an offset parse as naive
            # datetimes; make them timezone-aware so they can be subtracted
            # from the aware "now" below without raising TypeError.
            if target_dt.tzinfo is None:
                target_dt = target_dt.replace(tzinfo=timezone.utc)

            # Set date range: 7 days before to 1 day after target date (wider range for historical data)
            start_dt = target_dt - timedelta(days=7)
            end_dt = target_dt + timedelta(days=1)

            start_date = start_dt.strftime("%Y-%m-%d")
            end_date = end_dt.strftime("%Y-%m-%d")
        else:
            # Default to last 30 days if no date specified
            end_date = datetime.now().strftime("%Y-%m-%d")
            start_date = (datetime.now() - timedelta(days=30)).strftime("%Y-%m-%d")

        print(f"[fetch_energy_news] Tavily search: query='{query}', dates={start_date} to {end_date}", flush=True)

        # Search for news with appropriate date filtering.
        # The date-filter parameters are tried from most to least precise,
        # falling back whenever a search call raises.
        if target_date:
            days_ago = (datetime.now(timezone.utc) - target_dt).days
            print(f"[fetch_energy_news] Target date is {days_ago} days ago from today", flush=True)

            # Try using start_date and end_date parameters (if supported by API)
            try:
                print(f"[fetch_energy_news] Attempting search with start_date={start_date}, end_date={end_date}", flush=True)
                response = tavily_client.search(
                    query=query,
                    topic="news",
                    start_date=start_date,
                    end_date=end_date,
                    max_results=20,
                    include_domains=["reuters.com", "bloomberg.com", "wsj.com", "ft.com", "cnbc.com",
                                     "forbes.com", "oilprice.com", "naturalgasintel.com", "eia.gov"]
                )
                print(f"[fetch_energy_news] ✅ Successfully used start_date/end_date parameters", flush=True)
            except Exception as date_error:
                print(f"[fetch_energy_news] ⚠️ start_date/end_date failed: {date_error}", flush=True)

                # Fallback 1: Try with time_range parameter
                try:
                    if days_ago <= 1:
                        time_range = "day"
                    elif days_ago <= 7:
                        time_range = "week"
                    elif days_ago <= 30:
                        time_range = "month"
                    else:
                        time_range = "year"

                    print(f"[fetch_energy_news] Trying with time_range={time_range}", flush=True)
                    response = tavily_client.search(
                        query=query,
                        topic="news",
                        time_range=time_range,
                        max_results=20
                    )
                    print(f"[fetch_energy_news] ✅ Successfully used time_range parameter", flush=True)
                except Exception as time_range_error:
                    print(f"[fetch_energy_news] ⚠️ time_range failed: {time_range_error}", flush=True)

                    # Fallback 2: Use days parameter for recent dates, or no filter for old dates
                    if 0 <= days_ago <= 30:
                        lookback_days = min(days_ago + 7, 30)
                        print(f"[fetch_energy_news] Using days parameter: {lookback_days}", flush=True)
                        response = tavily_client.search(
                            query=query,
                            topic="news",
                            days=lookback_days,
                            max_results=20
                        )
                    else:
                        print(f"[fetch_energy_news] Using no date filter (historical date)", flush=True)
                        response = tavily_client.search(
                            query=query + f" {start_date} {end_date}",
                            topic="news",
                            max_results=20
                        )
        else:
            # No target date specified, search recent news
            print(f"[fetch_energy_news] No target date, using time_range=week", flush=True)
            response = tavily_client.search(
                query=query,
                topic="news",
                time_range="week",
                max_results=20
            )

        # Convert Tavily results to standard article format
        articles = []
        for result in response.get("results", []):
            article_url = result.get("url", "")
            # Use the URL's host as the source name; URLs without a scheme or
            # host yield "Unknown" instead of raising IndexError.
            source_name = (urlparse(article_url).netloc if article_url else "") or "Unknown"
            articles.append({
                "title": result.get("title", ""),
                "description": result.get("content", ""),
                "url": article_url,
                "source": {"name": source_name},
                "publishedAt": result.get("published_date", ""),
                "score": result.get("score", 0)
            })

        print(f"[fetch_energy_news] Found {len(articles)} articles", flush=True)

        # Log article dates for debugging
        if articles:
            article_dates = [a.get("publishedAt", "Unknown") for a in articles[:5]]
            print(f"[fetch_energy_news] Sample article dates: {article_dates}", flush=True)

        return {
            "query": query,
            "start_date": start_date,
            "end_date": end_date,
            "total_results": len(articles),
            "articles": articles,
            "note": "News API returns most relevant recent articles. Historical news archives may not be available for dates more than 30 days old."
        }

    except Exception as e:
        # Tool boundary: log the failure and report it to the agent rather than raising.
        print(f"[fetch_energy_news] ERROR: {type(e).__name__}: {str(e)}", flush=True)
        import traceback
        traceback.print_exc()
        return {
            "error": str(e),
            "error_type": type(e).__name__,
            "query": query,
            "articles": []
        }


class MyAgent:
    """ERCOT Forecast Error Analysis Agent using CrewAI.

    This agent analyzes forecast errors in the ERCOT (Electric Reliability Council of Texas)
    electricity market by examining weather data, renewable generation, and energy market news.
    It utilizes DataRobot's LLM Gateway for language model interactions and provides
    comprehensive narrative analysis of price forecast errors.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model: Optional[str] = None,
        verbose: Optional[Union[bool, str]] = True,
        timeout: Optional[int] = 90,
        **kwargs: Any,
    ):
        """Initializes the MyAgent class with API key, base URL, model, and verbosity settings.

        Args:
            api_key: Optional[str]: API key for authentication with DataRobot services.
                Defaults to None, in which case it will use the DATAROBOT_API_TOKEN environment variable.
            api_base: Optional[str]: Base URL for the DataRobot API.
                Defaults to None, in which case it will use the DATAROBOT_ENDPOINT environment variable,
                falling back to "https://api.datarobot.com".
            model: Optional[str]: The LLM model to use. It is stored on the instance;
                the `llm` property currently uses fixed model names.
                Defaults to None.
            verbose: Optional[Union[bool, str]]: Whether to enable verbose logging.
                Accepts boolean or string values ("true"/"false"). None falls back to
                the default (True); any other value is converted with bool().
            timeout: Optional[int]: How long to wait for the agent to respond.
                Defaults to 90 seconds.
            **kwargs: Any: Additional keyword arguments passed to the agent.
                Contains any parameters received in the CompletionCreateParams.
                They are accepted but not used here.

        Raises:
            Exception: Whatever `_load_agents_config` raises (e.g. FileNotFoundError)
                is re-raised after diagnostic output is printed.
        """
        self.api_key = api_key or os.environ.get("DATAROBOT_API_TOKEN")
        self.api_base = (
            api_base
            or os.environ.get("DATAROBOT_ENDPOINT")
            or "https://api.datarobot.com"
        )
        self.model = model
        self.timeout = timeout
        # Always assign self.verbose: it is read later when building the agent
        # and the crew, so leaving it unset would surface as an AttributeError.
        if isinstance(verbose, str):
            self.verbose = verbose.lower() == "true"
        elif verbose is None:
            self.verbose = True
        else:
            self.verbose = bool(verbose)
        self.event_listener = CrewAIEventListener()

        # Load agent configuration from YAML file
        try:
            self.agents_config = self._load_agents_config()
        except Exception as e:
            # dir() inside a method only lists local names, so the module's
            # __file__ has to be looked up in the module globals instead.
            module_file = globals().get("__file__")
            print(f"Warning: Could not load agents.yaml: {e}")
            print(f"Current working directory: {os.getcwd()}")
            print(f"__file__ location: {os.path.abspath(module_file) if module_file else 'N/A'}")
            raise

    def _load_agents_config(self) -> Dict:
        """Load agent configuration from agents.yaml file.

        Several candidate locations are tried in order and the first existing
        file is parsed with `yaml.safe_load`.

        Returns:
            The parsed content of agents.yaml.

        Raises:
            FileNotFoundError: If none of the candidate paths exists.
        """
        # Try multiple path strategies for robustness
        import inspect

        possible_paths = []

        # Strategy 1: Use __file__ if available
        module_file = globals().get("__file__")
        if module_file:
            possible_paths.append(
                os.path.join(os.path.dirname(os.path.abspath(module_file)), "agents.yaml")
            )

        # Strategy 2: Use inspect to get current file location
        try:
            frame = inspect.currentframe()
            current_file = inspect.getfile(frame)
            possible_paths.append(
                os.path.join(os.path.dirname(os.path.abspath(current_file)), "agents.yaml")
            )
        except Exception:
            # Best effort only: the remaining strategies still apply.
            pass

        # Strategy 3: Relative to current working directory
        possible_paths.extend([
            os.path.join(os.getcwd(), "agents.yaml"),
            os.path.join(os.getcwd(), "custom_model", "agents.yaml"),
            "agents.yaml",
        ])

        # Drop duplicate candidates while keeping the search order.
        possible_paths = list(dict.fromkeys(possible_paths))

        # Try each path
        for config_path in possible_paths:
            if os.path.exists(config_path):
                print(f"Loading agents.yaml from: {config_path}")
                with open(config_path, 'r', encoding="utf-8") as file:
                    return yaml.safe_load(file)

        # If file not found, raise error with helpful message
        raise FileNotFoundError(
            f"agents.yaml not found. Current dir: {os.getcwd()}, Tried paths: {possible_paths}"
        )

    def invoke(
        self, completion_create_params: CompletionCreateParams
    ) -> Union[
        Generator[tuple[str, Any | None, dict[str, int]], None, None],
        tuple[str, Any | None, dict[str, int]],
    ]:
        """Run the agent with the provided completion parameters.

        [THIS METHOD IS REQUIRED FOR THE AGENT TO WORK WITH DRUM SERVER]

        The content of the first "user" message is used as the crew inputs. A
        string is parsed as JSON first, then as a Python literal; if neither
        works it is passed on unchanged.

        Args:
            completion_create_params: The completion request parameters including input data and settings.
        Returns:
            Union[
                Generator[tuple[str, Any | None, dict[str, int]], None, None],
                tuple[str, Any | None, dict[str, int]],
            ]: For streaming requests, returns a generator yielding tuples of (response_text, pipeline_interactions, usage_metrics).
               For non-streaming requests, returns a single tuple of (response_text, pipeline_interactions, usage_metrics).
               This implementation always returns the single tuple.
        """
        import ast
        import json

        # Retrieve the starting user prompt from the CompletionCreateParams
        user_messages = [
            msg
            for msg in completion_create_params["messages"]
            # You can use other roles as needed (e.g. "system", "assistant")
            if msg.get("role") == "user"
        ]
        user_prompt: Any = user_messages[0] if user_messages else {}
        user_prompt_content = user_prompt.get("content", {})

        # Print commands may need flush=True to ensure they are displayed in real-time.
        print("Running agent with user prompt:", user_prompt_content, flush=True)

        # Parse the input - it might be a JSON string or already a dict
        if isinstance(user_prompt_content, str):
            try:
                # Try to parse as JSON
                inputs = json.loads(user_prompt_content)
            except json.JSONDecodeError:
                # If not JSON, try to read it as a Python dict literal.
                # SECURITY: the prompt is caller-supplied, so it must never be
                # passed to eval(); ast.literal_eval only accepts literals and
                # cannot execute code.
                try:
                    inputs = ast.literal_eval(user_prompt_content)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    # If all else fails, use as-is (will likely error in CrewAI)
                    inputs = user_prompt_content
        else:
            inputs = user_prompt_content

        print(f"Parsed inputs type: {type(inputs)}, value: {inputs}", flush=True)

        # Create and invoke the CrewAI Agentic Workflow with the inputs
        crewai_agentic_workflow = Crew(
            agents=[self.agent_forecast_analyst],
            tasks=[self.task_error_analysis],
            verbose=self.verbose,
            process=Process.sequential,
        )
        crew_output = crewai_agentic_workflow.kickoff(
            inputs=inputs
        )

        # Extract the final agent response as the synchronous response
        response_text = str(crew_output.raw)

        # Create a list of events from the event listener
        events = self.event_listener.messages
        if len(events) > 0:
            # Make sure the final response is the last recorded message.
            last_message = events[-1].content
            if last_message != response_text:
                events.append(AIMessage(content=response_text))
        else:
            events = None

        pipeline_interactions = self.create_pipeline_interactions_from_events(events)

        usage_metrics = {
            "completion_tokens": crew_output.token_usage.completion_tokens,
            "prompt_tokens": crew_output.token_usage.prompt_tokens,
            "total_tokens": crew_output.token_usage.total_tokens,
        }

        return response_text, pipeline_interactions, usage_metrics

    @property
    def llm(self) -> LLM:
        """Returns a CrewAI LLM instance configured to use DataRobot's LLM Gateway or a specific deployment.

        When the LLM_DATAROBOT_DEPLOYMENT_ID environment variable is set, the
        LLM points at that deployment's URL; otherwise it points at the API
        base with any trailing "/api/v2" removed. A new LLM is built on every
        access.

        For help configuring different LLM backends see:
        https://github.com/datarobot-community/datarobot-agent-templates/blob/main/docs/developing-agents-llm-providers.md
        """
        parsed_base = urlparse(self.api_base)
        deployment_id = os.environ.get("LLM_DATAROBOT_DEPLOYMENT_ID")
        if deployment_id:
            path = parsed_base.path
            if "/api/v2/deployments" not in path and "api/v2/genai" not in path:
                # Ensure the API base ends with /api/v2/ for deployments
                if not path.endswith(("/api/v2/", "/api/v2")):
                    path = urljoin(path + "/", "api/v2/")
                if not path.endswith("/"):
                    path += "/"
                parsed_base = parsed_base._replace(path=path)
                deployment_url = urljoin(
                    parsed_base.geturl(),
                    f"deployments/{deployment_id}/",
                )
            else:
                # If user specifies a likely deployment URL then leave it alone
                deployment_url = parsed_base.geturl()
            return LLM(
                model="openai/gpt-4o-mini",
                api_base=deployment_url,
                api_key=self.api_key,
                timeout=self.timeout,
            )

        # Ensure the API base does not end with /api/v2/ for LLM Gateway
        # Remove only '/api/v2' or '/api/v2/' from the path portion, if present
        path = parsed_base.path
        if path.endswith(("api/v2/", "api/v2")):
            path = re.sub(r"/api/v2/?$", "/", path)
        parsed_base = parsed_base._replace(path=path)
        return LLM(
            model="datarobot/azure/gpt-4o-mini",
            api_base=parsed_base.geturl(),
            api_key=self.api_key,
            timeout=self.timeout,
        )

    @property
    def agent_forecast_analyst(self) -> Agent:
        """Energy Market Analyst agent that interprets ERCOT forecast errors.

        Role, goal and backstory come from the "forecast_analyst" entry of the
        loaded agents configuration. A new Agent is built on every access.
        """
        agent_config = self.agents_config["agents"]["forecast_analyst"]
        return Agent(
            role=agent_config["role"],
            goal=agent_config["goal"],
            backstory=agent_config["backstory"],
            allow_delegation=False,
            verbose=self.verbose,
            llm=self.llm,
            tools=[fetch_texas_weather, fetch_energy_news],
        )

    @property
    def task_error_analysis(self) -> Task:
        """Task for analyzing forecast errors with context from weather and news data.

        Description and expected output come from the "error_analysis" entry of
        the loaded agents configuration. A new Task is built on every access.
        """
        task_config = self.agents_config["tasks"]["error_analysis"]
        return Task(
            description=task_config["description"],
            expected_output=task_config["expected_output"],
            agent=self.agent_forecast_analyst,
        )

    @staticmethod
    def create_pipeline_interactions_from_events(
        events: list[Union[HumanMessage, AIMessage, ToolMessage]] | None,
    ) -> MultiTurnSample | None:
        """Convert a list of events into a MultiTurnSample.

        Creates the pipeline interactions for moderations and evaluation
        (e.g. Task Adherence, Agent Goal Accuracy, Tool Call Accuracy)

        Args:
            events: Recorded messages, or None/empty when nothing was recorded.

        Returns:
            A MultiTurnSample wrapping the events, or None when there are none.
        """
        if not events:
            return None
        return MultiTurnSample(user_input=events)
