"""
xAI (Grok) provider implementation.

The xAI API uses an OpenAI-compatible request/response format.
"""

import asyncio
import logging
from typing import Optional, List, Dict, Any

from ..base import BaseAIProvider, AIProviderType, AIMessage, AIResponse


class XAIProvider(BaseAIProvider):
    """AI provider for the xAI (Grok) API.

    xAI exposes an OpenAI-compatible API, so this provider reuses the
    ``openai`` SDK with ``base_url`` pointed at :attr:`BASE_URL`. The sync and
    async clients are created by :meth:`initialize`, which ``chat`` /
    ``chat_sync`` call on first use if it has not run successfully yet.
    """

    # OpenAI-compatible endpoint of the xAI API.
    BASE_URL = "https://api.x.ai/v1"

    # Models advertised by this provider; `available_models` returns a copy.
    AVAILABLE_MODELS = [
        "grok-beta",
        "grok-2-1212",
        "grok-2-vision-1212",
    ]

    def __init__(self, api_key: str, model: Optional[str] = None):
        """Store credentials; network clients are built later by `initialize`.

        Args:
            api_key: xAI API key, forwarded to the base class.
            model: Model name, forwarded to the base class as-is (presumably
                the base falls back to `default_model` when None — verify).
        """
        super().__init__(api_key, model)
        # Both clients stay None until `initialize` succeeds.
        self._client = None
        self._async_client = None

    @property
    def provider_type(self) -> AIProviderType:
        """Enum tag identifying this provider."""
        return AIProviderType.XAI

    @property
    def provider_name(self) -> str:
        """Human-readable provider name, also stamped on every `AIResponse`."""
        return "xAI (Grok)"

    @property
    def default_model(self) -> str:
        """Model identifier used as this provider's default."""
        return "grok-beta"

    @property
    def available_models(self) -> List[str]:
        """Return a copy of `AVAILABLE_MODELS` so callers cannot mutate it."""
        return self.AVAILABLE_MODELS.copy()

    def initialize(self) -> bool:
        """Create the sync and async OpenAI clients pointed at the xAI endpoint.

        Returns:
            True on success. False when the API key fails validation, the
            ``openai`` package is not installed, or client construction raises.
            Failures are logged rather than raised.
        """
        logger = logging.getLogger(__name__)

        if not self.validate_api_key():
            return False

        try:
            from openai import OpenAI, AsyncOpenAI
        except ImportError:
            logger.error("OpenAI 라이브러리가 설치되지 않았습니다. pip install openai 를 실행하세요.")
            return False

        try:
            client = OpenAI(api_key=self._api_key, base_url=self.BASE_URL)
            async_client = AsyncOpenAI(api_key=self._api_key, base_url=self.BASE_URL)
        except Exception as e:  # best-effort init: report and signal failure via return value
            logger.exception("xAI 초기화 실패: %s", e)
            return False

        # Commit both clients together so a failure above never leaves a
        # half-initialized provider behind.
        self._client = client
        self._async_client = async_client
        self._is_initialized = True
        return True

    def _convert_messages(self, messages: List[AIMessage]) -> List[Dict[str, str]]:
        """Convert `AIMessage` objects into OpenAI-style ``{"role", "content"}`` dicts."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    def _ensure_initialized(self) -> None:
        """Lazily run `initialize`, raising if it fails."""
        if not self._is_initialized and not self.initialize():
            raise RuntimeError("xAI 초기화 실패")

    def _build_params(
        self,
        messages: List[AIMessage],
        temperature: float,
        max_tokens: Optional[int],
        extra: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Assemble keyword arguments for ``chat.completions.create``."""
        params: Dict[str, Any] = {
            "model": self._model,
            "messages": self._convert_messages(messages),
            "temperature": temperature,
        }
        # `is not None` so an explicit 0 is forwarded instead of silently dropped.
        if max_tokens is not None:
            params["max_tokens"] = max_tokens
        # Caller-supplied options take precedence, including over the keys above.
        params.update(extra)
        return params

    def _to_response(self, response: Any) -> AIResponse:
        """Map an OpenAI-SDK chat completion onto the provider-neutral `AIResponse`."""
        usage = response.usage
        return AIResponse(
            content=response.choices[0].message.content,
            model=response.model,
            provider=self.provider_name,
            usage=(
                {
                    "prompt_tokens": usage.prompt_tokens,
                    "completion_tokens": usage.completion_tokens,
                    "total_tokens": usage.total_tokens,
                }
                if usage
                else None
            ),
            raw_response=response,
        )

    async def chat(
        self,
        messages: List[AIMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> AIResponse:
        """Send a chat completion request asynchronously.

        Args:
            messages: Conversation to send.
            temperature: Sampling temperature.
            max_tokens: Completion token limit; omitted from the request when None.
            **kwargs: Extra parameters forwarded to ``chat.completions.create``;
                they override the defaults built here.

        Returns:
            The first choice's content, model, and token usage as an `AIResponse`.

        Raises:
            RuntimeError: If the clients could not be initialized.
        """
        self._ensure_initialized()
        params = self._build_params(messages, temperature, max_tokens, kwargs)
        response = await self._async_client.chat.completions.create(**params)
        return self._to_response(response)

    def chat_sync(
        self,
        messages: List[AIMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> AIResponse:
        """Send a chat completion request synchronously.

        Same parameters, return value, and errors as `chat`, but uses the
        blocking client.
        """
        self._ensure_initialized()
        params = self._build_params(messages, temperature, max_tokens, kwargs)
        response = self._client.chat.completions.create(**params)
        return self._to_response(response)