# AI_MMI_Analyser/app/ai/interface.py
#
# Original listing metadata: 116 lines, 4.0 KiB, Python

import polars as pl
import json
from typing import List, Dict, Any, Optional
from app.core.reference_data import ref_data
class MMIAIInterface:
    """
    Interface for AI to interact with MMI Data.

    Loads a parquet log into a polars DataFrame and exposes a small set of
    "tools" (a JSON schema describing them plus a dispatcher) that an AI
    model can call to query that data.
    """

    def __init__(self, parquet_path: str):
        """
        Remember the parquet location and load it immediately.

        Args:
            parquet_path: Path of the parquet file to read.
        """
        self.parquet_path = parquet_path
        # Stays None when loading fails; every tool checks for that.
        self.df: Optional[pl.DataFrame] = None
        self._load_data()

    def _load_data(self):
        """
        Read the parquet file into ``self.df``.

        Best-effort: any failure is printed and leaves ``self.df`` as None
        instead of raising, so the tools can report "No data loaded".
        """
        try:
            self.df = pl.read_parquet(self.parquet_path)
        except Exception as e:
            print(f"[AI Interface] Load Error: {e}")
            self.df = None

    def get_tools_schema(self) -> List[Dict[str, Any]]:
        """
        Returns the JSON schema for tools available to the AI.

        Each entry describes one function by name, description and
        parameters; the names match those handled by ``execute_tool``.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": "get_error_stats",
                    "description": "Get statistics of errors (e.g., ATC status, Warnings) for the current log.",
                    "parameters": {
                        "type": "object",
                        "properties": {},
                        "required": []
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "search_station_event",
                    "description": "Find events happening at a specific station.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "station_name": {"type": "string", "description": "Name of the station (e.g., '범어사')"}
                        },
                        "required": ["station_name"]
                    }
                }
            }
        ]

    def get_error_stats(self) -> Dict[str, Any]:
        """
        Tool: Get Error Statistics

        Returns:
            ``{"error": "No data loaded"}`` when no data is available;
            otherwise a dict with ``total_frames``, ``non_active_frames``
            and ``active_ratio`` (0 when there are no frames).
        """
        if self.df is None:
            return {"error": "No data loaded"}

        # Only rows with system_id == 1 are counted (the original author's
        # note calls this the "Main" system).
        main_df = self.df.filter(pl.col("system_id") == 1)
        total_rows = main_df.height

        # Per the original author's note, atc_byte == 64 means ATC ACTIVE;
        # everything else is counted as non-active.
        non_active_count = main_df.filter(pl.col("atc_byte") != 64).height

        return {
            "total_frames": total_rows,
            "non_active_frames": non_active_count,
            "active_ratio": (total_rows - non_active_count) / total_rows if total_rows > 0 else 0
        }

    def search_station_event(self, station_name: str) -> Dict[str, Any]:
        """
        Tool: Search events at a station

        Args:
            station_name: Value compared for exact equality against the
                ``station_name`` column.

        Returns:
            ``{"error": "No data loaded"}`` when no data is available, a
            ``{"message": ...}`` dict when no row matches, otherwise a
            summary with the row count, mean ``trainspeed`` and the
            min/max ``timestamp`` (as strings).
        """
        if self.df is None:
            return {"error": "No data loaded"}

        # Assumes a 'station_name' column exists (the original note says it
        # is added by FastParser) -- TODO confirm against the parser.
        station_df = self.df.filter(pl.col("station_name") == station_name)
        if station_df.is_empty():
            return {"message": f"No events found at {station_name}"}

        # Summary of the matching rows.
        return {
            "station": station_name,
            "visit_count": station_df.height,
            "avg_speed": station_df["trainspeed"].mean(),
            "start_time": str(station_df["timestamp"].min()),
            "end_time": str(station_df["timestamp"].max())
        }

    def execute_tool(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None) -> Any:
        """
        Dispatcher for tool execution.

        Args:
            tool_name: Name of the tool, as listed by ``get_tools_schema``.
            arguments: Keyword arguments for the tool. ``None`` is treated
                as an empty dict so argument-less calls do not crash.

        Returns:
            The tool's result dict, or ``{"error": "Unknown tool"}`` when
            ``tool_name`` is not recognised.
        """
        # A caller may pass None when the model supplies no arguments.
        if arguments is None:
            arguments = {}

        if tool_name == "get_error_stats":
            return self.get_error_stats()
        elif tool_name == "search_station_event":
            return self.search_station_event(arguments.get("station_name", ""))
        else:
            return {"error": "Unknown tool"}