116 lines
4.0 KiB
Python
116 lines
4.0 KiB
Python
import polars as pl
|
|
import json
|
|
from typing import List, Dict, Any, Optional
|
|
from app.core.reference_data import ref_data
|
|
|
|
class MMIAIInterface:
    """
    Interface for AI to interact with MMI Data.

    Loads a parquet log into a polars DataFrame and exposes a small set of
    function-calling style "tools" that can be listed with
    ``get_tools_schema`` and invoked by name through ``execute_tool``.

    Every tool returns a plain dict. Failures (no data loaded, missing
    columns, malformed arguments, unknown tool) are reported as
    ``{"error": "<message>"}`` instead of being raised, so the result can be
    handed straight back to the model.
    """

    # Value of the ``atc_byte`` column that means "ATC ACTIVE".
    ATC_ACTIVE_BYTE = 64

    # ``system_id`` of the Main system; statistics are computed on it only.
    MAIN_SYSTEM_ID = 1

    def __init__(self, parquet_path: str):
        """
        Remember ``parquet_path`` and try to load it immediately.

        A failed load does not raise; ``self.df`` is left as ``None`` and
        each tool then answers with a "No data loaded" error.
        """
        self.parquet_path = parquet_path
        self.df: Optional[pl.DataFrame] = None
        self._load_data()

    def _load_data(self):
        """Read the parquet file into ``self.df`` (``None`` on any failure)."""
        try:
            self.df = pl.read_parquet(self.parquet_path)
        except Exception as e:
            # Deliberately best-effort: the interface must stay usable (and
            # report "No data loaded") even when the file is missing/corrupt.
            print(f"[AI Interface] Load Error: {e}")
            self.df = None

    def _require_columns(self, *names: str) -> Optional[Dict[str, Any]]:
        """
        Return an error dict naming the columns absent from ``self.df``.

        Returns ``None`` when every requested column is present. Callers
        must have checked that ``self.df`` is not ``None`` beforehand.
        """
        available = self.df.columns
        missing = [name for name in names if name not in available]
        if missing:
            return {"error": f"Missing columns: {', '.join(missing)}"}
        return None

    @staticmethod
    def _parse_arguments(arguments: Any) -> Dict[str, Any]:
        """
        Normalize tool-call arguments to a dict.

        Accepts a dict, ``None`` / a blank string (treated as no arguments),
        or a JSON-encoded object, since function-calling APIs commonly hand
        the arguments over as a raw JSON string.

        Raises:
            ValueError: if the text is not valid JSON or does not decode to
                a JSON object.
        """
        if isinstance(arguments, (str, bytes, bytearray)):
            # json.JSONDecodeError is a ValueError, so it propagates as-is.
            arguments = json.loads(arguments) if arguments.strip() else None
        if arguments is None:
            return {}
        if not isinstance(arguments, dict):
            raise ValueError("arguments must be a JSON object")
        return arguments

    def get_tools_schema(self) -> List[Dict[str, Any]]:
        """
        Returns the JSON schema for tools available to the AI.

        The tool names listed here are exactly the names understood by
        ``execute_tool``.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": "get_error_stats",
                    "description": "Get statistics of errors (e.g., ATC status, Warnings) for the current log.",
                    "parameters": {
                        "type": "object",
                        "properties": {},
                        "required": []
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "search_station_event",
                    "description": "Find events happening at a specific station.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "station_name": {"type": "string", "description": "Name of the station (e.g., '범어사')"}
                        },
                        "required": ["station_name"]
                    }
                }
            }
        ]

    def get_error_stats(self) -> Dict[str, Any]:
        """
        Tool: Get Error Statistics

        Counts, for the Main system only, how many frames have an
        ``atc_byte`` other than ATC ACTIVE.

        Returns:
            ``{"total_frames", "non_active_frames", "active_ratio"}`` on
            success (``active_ratio`` is ``0`` when there are no frames), or
            ``{"error": ...}`` when no data is loaded or a required column
            is missing.
        """
        if self.df is None:
            return {"error": "No data loaded"}

        problem = self._require_columns("system_id", "atc_byte")
        if problem is not None:
            return problem

        # Statistics are reported for the Main system only.
        main_df = self.df.filter(pl.col("system_id") == self.MAIN_SYSTEM_ID)
        total_rows = main_df.height

        # Count frames whose ATC status is anything other than ACTIVE.
        # NOTE(review): rows with a null atc_byte are not matched by this
        # filter and therefore count as active -- confirm this is intended.
        non_active_count = main_df.filter(
            pl.col("atc_byte") != self.ATC_ACTIVE_BYTE
        ).height

        return {
            "total_frames": total_rows,
            "non_active_frames": non_active_count,
            "active_ratio": (total_rows - non_active_count) / total_rows if total_rows > 0 else 0
        }

    def search_station_event(self, station_name: str) -> Dict[str, Any]:
        """
        Tool: Search events at a station

        Selects the rows whose ``station_name`` equals the given name
        exactly and summarizes them.

        Args:
            station_name: Station to look up (exact match).

        Returns:
            ``{"station", "visit_count", "avg_speed", "start_time",
            "end_time"}`` when rows match (``visit_count`` is the number of
            matching rows), ``{"message": ...}`` when none do, or
            ``{"error": ...}`` when no data is loaded or a required column
            is missing.
        """
        if self.df is None:
            return {"error": "No data loaded"}

        # The 'station_name' column is expected to have been added upstream
        # (by FastParser, per the original author's note); verify instead of
        # letting polars raise.
        problem = self._require_columns("station_name", "trainspeed", "timestamp")
        if problem is not None:
            return problem

        station_df = self.df.filter(pl.col("station_name") == station_name)

        if station_df.is_empty():
            return {"message": f"No events found at {station_name}"}

        # Summary of the matching rows: row count, mean speed and time range.
        return {
            "station": station_name,
            "visit_count": station_df.height,
            "avg_speed": station_df["trainspeed"].mean(),
            "start_time": str(station_df["timestamp"].min()),
            "end_time": str(station_df["timestamp"].max())
        }

    def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """
        Dispatcher for tool execution.

        Args:
            tool_name: One of the names advertised by ``get_tools_schema``.
            arguments: Tool arguments. A dict is the normal form; ``None``,
                a blank string, or a JSON-encoded object are also accepted.

        Returns:
            The selected tool's result dict, or ``{"error": ...}`` when the
            tool is unknown or its arguments cannot be interpreted.
        """
        if tool_name == "get_error_stats":
            # Takes no parameters, so whatever was passed is ignored.
            return self.get_error_stats()

        if tool_name == "search_station_event":
            try:
                parsed = self._parse_arguments(arguments)
            except ValueError as exc:
                return {"error": f"Invalid arguments for {tool_name}: {exc}"}
            return self.search_station_event(parsed.get("station_name", ""))

        return {"error": "Unknown tool"}
|