# AI_MMI_Analyser/app/ai/providers/xai_provider.py
"""
xAI (Grok) Provider 구현
xAI API는 OpenAI 호환 형식을 사용합니다.
"""
import asyncio
import logging
from typing import Optional, List, Dict, Any

from ..base import BaseAIProvider, AIProviderType, AIMessage, AIResponse
class XAIProvider(BaseAIProvider):
    """AI provider backed by the xAI (Grok) API.

    xAI exposes an OpenAI-compatible API, so this provider drives it through
    the ``openai`` client library pointed at :attr:`BASE_URL`.
    """

    # Root URL passed to the OpenAI clients as ``base_url``.
    BASE_URL = "https://api.x.ai/v1"
    # Models reported by ``available_models``.
    # NOTE(review): hard-coded list; confirm against xAI's current model catalogue.
    AVAILABLE_MODELS = [
        "grok-beta",
        "grok-2-1212",
        "grok-2-vision-1212",
    ]

    # Library code logs instead of printing so the host app controls output.
    _logger = logging.getLogger(__name__)

    def __init__(self, api_key: str, model: Optional[str] = None):
        """Store credentials; the API clients are created later by :meth:`initialize`.

        Args:
            api_key: xAI API key, forwarded to ``BaseAIProvider.__init__``.
            model: Model name, forwarded to ``BaseAIProvider.__init__``
                (presumably the base falls back to ``default_model`` when
                ``None`` -- confirm in the base class).
        """
        super().__init__(api_key, model)
        self._client = None
        self._async_client = None

    @property
    def provider_type(self) -> AIProviderType:
        """Enum value identifying this provider as xAI."""
        return AIProviderType.XAI

    @property
    def provider_name(self) -> str:
        """Human-readable provider name; also stamped on every returned AIResponse."""
        return "xAI (Grok)"

    @property
    def default_model(self) -> str:
        """Name of the default model for this provider."""
        return "grok-beta"

    @property
    def available_models(self) -> List[str]:
        """Return a copy of AVAILABLE_MODELS so callers cannot mutate the class list."""
        return self.AVAILABLE_MODELS.copy()

    def initialize(self) -> bool:
        """Create the synchronous and asynchronous OpenAI-compatible clients.

        Returns:
            True on success. False if the API key fails validation, the
            ``openai`` package is not installed, or client construction
            raises; failures are logged rather than propagated.
        """
        if not self.validate_api_key():
            return False
        try:
            # Imported lazily so this module can be loaded without ``openai``.
            from openai import OpenAI, AsyncOpenAI
        except ImportError:
            self._logger.error(
                "OpenAI 라이브러리가 설치되지 않았습니다. pip install openai 를 실행하세요."
            )
            return False
        try:
            self._client = OpenAI(api_key=self._api_key, base_url=self.BASE_URL)
            self._async_client = AsyncOpenAI(
                api_key=self._api_key, base_url=self.BASE_URL
            )
        except Exception as e:
            # Deliberate best-effort boundary: report (with traceback) and signal failure.
            self._logger.exception("xAI 초기화 실패: %s", e)
            return False
        self._is_initialized = True
        return True

    def _convert_messages(self, messages: List[AIMessage]) -> List[Dict[str, str]]:
        """Convert AIMessage objects into OpenAI-style ``{"role", "content"}`` dicts."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    def _ensure_xai_ready(self) -> None:
        """Initialize the clients on first use.

        Raises:
            RuntimeError: if :meth:`initialize` reports failure.
        """
        if not self._is_initialized and not self.initialize():
            raise RuntimeError("xAI 초기화 실패")

    def _build_xai_request(
        self,
        messages: List[AIMessage],
        temperature: float,
        max_tokens: Optional[int],
        extra: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Assemble keyword arguments for ``chat.completions.create``."""
        params: Dict[str, Any] = {
            "model": self._model,
            "messages": self._convert_messages(messages),
            "temperature": temperature,
        }
        # A falsy max_tokens (None or 0) is left out of the request entirely.
        if max_tokens:
            params["max_tokens"] = max_tokens
        # Caller kwargs are applied last, so they may override any key above.
        params.update(extra)
        return params

    def _to_ai_response(self, response: Any) -> AIResponse:
        """Wrap an OpenAI-style chat completion in an AIResponse.

        Only the first choice is used; ``usage`` is None when the API
        response carries no usage block.
        """
        usage = response.usage
        return AIResponse(
            content=response.choices[0].message.content,
            model=response.model,
            provider=self.provider_name,
            usage={
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            } if usage else None,
            raw_response=response,
        )

    async def chat(
        self,
        messages: List[AIMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AIResponse:
        """Send a chat completion request asynchronously.

        Args:
            messages: Conversation history to send.
            temperature: Sampling temperature.
            max_tokens: Optional completion length cap; omitted when falsy.
            **kwargs: Extra ``chat.completions.create`` arguments; these
                override the values built from the parameters above.

        Returns:
            The first choice of the completion wrapped in an AIResponse.

        Raises:
            RuntimeError: if the client cannot be initialized.
        """
        self._ensure_xai_ready()
        params = self._build_xai_request(messages, temperature, max_tokens, kwargs)
        response = await self._async_client.chat.completions.create(**params)
        return self._to_ai_response(response)

    def chat_sync(
        self,
        messages: List[AIMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AIResponse:
        """Send a chat completion request synchronously (blocking).

        Same arguments, return value and exceptions as :meth:`chat`.
        """
        self._ensure_xai_ready()
        params = self._build_xai_request(messages, temperature, max_tokens, kwargs)
        response = self._client.chat.completions.create(**params)
        return self._to_ai_response(response)