import json
from typing import Any, Dict, List, Optional

import polars as pl

from app.core.reference_data import ref_data


class MMIAIInterface:
    """Interface for AI to interact with MMI data.

    Loads a parquet file into a polars DataFrame and exposes a small set of
    named tools, together with their JSON schema, that can be dispatched
    through :meth:`execute_tool`.
    """

    def __init__(self, parquet_path: str):
        """Store *parquet_path* and immediately try to load it.

        Args:
            parquet_path: Path of the parquet file to read.
        """
        self.parquet_path = parquet_path
        # Stays None when loading fails; every tool checks for that.
        self.df: Optional[pl.DataFrame] = None
        self._load_data()

    def _load_data(self) -> None:
        """Read the parquet file into ``self.df``.

        Loading is best-effort: any failure is printed and leaves
        ``self.df`` set to ``None`` instead of propagating.
        """
        try:
            self.df = pl.read_parquet(self.parquet_path)
        except Exception as e:
            print(f"[AI Interface] Load Error: {e}")
            self.df = None

    def get_tools_schema(self) -> List[Dict[str, Any]]:
        """Return the JSON schema for the tools available to the AI.

        Returns:
            One schema entry per tool handled by :meth:`execute_tool`.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": "get_error_stats",
                    "description": "Get statistics of errors (e.g., ATC status, Warnings) for the current log.",
                    "parameters": {
                        "type": "object",
                        "properties": {},
                        "required": [],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "search_station_event",
                    "description": "Find events happening at a specific station.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "station_name": {
                                "type": "string",
                                "description": "Name of the station (e.g., '범어사')",
                            }
                        },
                        "required": ["station_name"],
                    },
                },
            },
        ]

    def get_error_stats(self) -> Dict[str, Any]:
        """Tool: get error statistics for the Main system.

        Only rows with ``system_id == 1`` (Main) are considered. A frame is
        counted as non-active when its ``atc_byte`` differs from 64, the
        value used for ATC ACTIVE.

        Returns:
            ``{"error": "No data loaded"}`` when no DataFrame is available,
            otherwise a dict with ``total_frames``, ``non_active_frames``
            and ``active_ratio`` (0 when there are no frames).
        """
        if self.df is None:
            return {"error": "No data loaded"}

        main_df = self.df.filter(pl.col("system_id") == 1)
        total_rows = main_df.height

        # NOTE(review): rows whose atc_byte is null are not matched by the
        # `!= 64` predicate, so they are not counted as non-active and end
        # up contributing to the active ratio -- confirm this is intended.
        non_active_count = main_df.filter(pl.col("atc_byte") != 64).height

        return {
            "total_frames": total_rows,
            "non_active_frames": non_active_count,
            "active_ratio": (total_rows - non_active_count) / total_rows
            if total_rows > 0
            else 0,
        }

    def search_station_event(self, station_name: str) -> Dict[str, Any]:
        """Tool: summarise the rows recorded at a station.

        Assumes the DataFrame has a ``station_name`` column (added by
        FastParser) as well as ``trainspeed`` and ``timestamp`` columns.

        Args:
            station_name: Exact station name to match.

        Returns:
            ``{"error": "No data loaded"}`` when no DataFrame is available,
            a ``{"message": ...}`` dict when no row matches, otherwise a
            summary with the number of matching rows, the mean of
            ``trainspeed`` and the min/max ``timestamp`` rendered as strings.
        """
        if self.df is None:
            return {"error": "No data loaded"}

        station_df = self.df.filter(pl.col("station_name") == station_name)

        if station_df.is_empty():
            return {"message": f"No events found at {station_name}"}

        return {
            "station": station_name,
            "visit_count": station_df.height,
            "avg_speed": station_df["trainspeed"].mean(),
            "start_time": str(station_df["timestamp"].min()),
            "end_time": str(station_df["timestamp"].max()),
        }

    def execute_tool(
        self, tool_name: str, arguments: Optional[Dict[str, Any]] = None
    ) -> Any:
        """Dispatch a tool call by name.

        Args:
            tool_name: Name of the tool, as listed by
                :meth:`get_tools_schema`.
            arguments: Keyword arguments for the tool. ``None`` is treated
                as an empty mapping so that calls to parameterless tools,
                or calls with omitted arguments, do not fail.

        Returns:
            The tool's result, or ``{"error": "Unknown tool"}`` when
            *tool_name* is not recognised.
        """
        # Guard against callers that pass None instead of an empty dict.
        args = arguments or {}

        if tool_name == "get_error_stats":
            return self.get_error_stats()
        if tool_name == "search_station_event":
            return self.search_station_event(args.get("station_name", ""))
        return {"error": "Unknown tool"}