142 lines
4.3 KiB
Python
142 lines
4.3 KiB
Python
"""
OpenAI Provider 구현 (OpenAI provider implementation for the Chat Completions API).
"""
|
|
|
|
import asyncio
import logging
from typing import Optional, List, Dict, Any

from ..base import BaseAIProvider, AIProviderType, AIMessage, AIResponse
|
|
|
|
|
|
class OpenAIProvider(BaseAIProvider):
    """OpenAI API provider backed by the official ``openai`` SDK.

    Holds one synchronous (``OpenAI``) and one asynchronous (``AsyncOpenAI``)
    client. Both are created by :meth:`initialize`, which ``chat`` /
    ``chat_sync`` call lazily on first use.
    """

    AVAILABLE_MODELS = [
        "gpt-4o",
        "gpt-4o-mini",
        "gpt-4-turbo",
        "gpt-4",
        "gpt-3.5-turbo",
        "o1-preview",
        "o1-mini",
    ]

    # Model-name prefixes of OpenAI reasoning models. Per the OpenAI API docs
    # these models accept only the default ``temperature`` and take the output
    # limit as ``max_completion_tokens`` instead of the legacy ``max_tokens``,
    # so ``_build_params`` builds their requests differently.
    REASONING_MODEL_PREFIXES = ("o1", "o3", "o4")

    def __init__(self, api_key: str, model: Optional[str] = None):
        super().__init__(api_key, model)
        # SDK clients; None until initialize() succeeds.
        self._client = None
        self._async_client = None

    @property
    def provider_type(self) -> AIProviderType:
        return AIProviderType.OPENAI

    @property
    def provider_name(self) -> str:
        return "OpenAI"

    @property
    def default_model(self) -> str:
        return "gpt-4o-mini"

    @property
    def available_models(self) -> List[str]:
        # Return a copy so callers cannot mutate the class-level list.
        return self.AVAILABLE_MODELS.copy()

    def initialize(self) -> bool:
        """Create the OpenAI sync and async clients.

        Returns:
            True on success. False if ``validate_api_key()`` fails, the
            ``openai`` package is not installed, or client construction
            raises. Failures are reported through the module logger.
        """
        logger = logging.getLogger(__name__)

        if not self.validate_api_key():
            return False

        try:
            from openai import OpenAI, AsyncOpenAI
        except ImportError:
            logger.error("OpenAI 라이브러리가 설치되지 않았습니다. pip install openai 를 실행하세요.")
            return False

        try:
            # Build both clients before assigning either, so a failure in the
            # second one cannot leave the provider half-initialized.
            client = OpenAI(api_key=self._api_key)
            async_client = AsyncOpenAI(api_key=self._api_key)
        except Exception as e:
            # logger.exception also records the traceback for diagnosis.
            logger.exception("OpenAI 초기화 실패: %s", e)
            return False

        self._client = client
        self._async_client = async_client
        self._is_initialized = True
        return True

    def _ensure_initialized(self) -> None:
        """Initialize lazily; raise RuntimeError if initialization fails."""
        if not self._is_initialized and not self.initialize():
            raise RuntimeError("OpenAI 초기화 실패")

    def _is_reasoning_model(self) -> bool:
        """Return True if the configured model is an OpenAI reasoning model."""
        return (self._model or "").startswith(self.REASONING_MODEL_PREFIXES)

    def _convert_messages(self, messages: List[AIMessage]) -> List[Dict[str, str]]:
        """Convert AIMessage objects to OpenAI ``{"role", "content"}`` dicts."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    def _build_params(
        self,
        messages: List[AIMessage],
        temperature: float,
        max_tokens: Optional[int],
        extra: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Build keyword arguments for ``chat.completions.create``.

        A falsy ``max_tokens`` (None or 0) means "no explicit limit" and is
        omitted. Entries in ``extra`` are applied last, so callers can
        override any generated parameter.
        """
        params: Dict[str, Any] = {
            "model": self._model,
            "messages": self._convert_messages(messages),
        }
        if self._is_reasoning_model():
            # Reasoning models reject a non-default temperature and the legacy
            # max_tokens parameter, so temperature is not sent and the limit
            # goes in max_completion_tokens.
            # NOTE(review): per OpenAI docs o1-preview / o1-mini also reject
            # "system" role messages; not handled here. Confirm whether
            # callers send system prompts to these models.
            if max_tokens:
                params["max_completion_tokens"] = max_tokens
        else:
            params["temperature"] = temperature
            if max_tokens:
                params["max_tokens"] = max_tokens
        params.update(extra)
        return params

    def _to_response(self, response: Any) -> AIResponse:
        """Map an SDK chat-completion response to an AIResponse."""
        usage = response.usage
        return AIResponse(
            content=response.choices[0].message.content,
            model=response.model,
            provider=self.provider_name,
            usage={
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            } if usage else None,
            raw_response=response,
        )

    async def chat(
        self,
        messages: List[AIMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AIResponse:
        """Send an asynchronous chat-completion request.

        Args:
            messages: Conversation to send, in order.
            temperature: Sampling temperature; not sent for reasoning models.
            max_tokens: Output token limit; None/0 means no explicit limit.
            **kwargs: Extra ``chat.completions.create`` parameters; these
                override any generated parameter of the same name.

        Raises:
            RuntimeError: If the provider cannot be initialized.
        """
        self._ensure_initialized()
        params = self._build_params(messages, temperature, max_tokens, kwargs)
        response = await self._async_client.chat.completions.create(**params)
        return self._to_response(response)

    def chat_sync(
        self,
        messages: List[AIMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AIResponse:
        """Send a blocking chat-completion request.

        Takes the same arguments and raises the same errors as :meth:`chat`,
        but uses the synchronous client.
        """
        self._ensure_initialized()
        params = self._build_params(messages, temperature, max_tokens, kwargs)
        response = self._client.chat.completions.create(**params)
        return self._to_response(response)
|
|
|