157 lines
5.0 KiB
Python
157 lines
5.0 KiB
Python
"""
|
|
Google Gemini Provider 구현
|
|
"""
|
|
|
|
import asyncio
|
|
from typing import Optional, List, Dict, Any
|
|
|
|
from ..base import BaseAIProvider, AIProviderType, AIMessage, AIResponse
|
|
|
|
|
|
class GeminiProvider(BaseAIProvider):
    """Google Gemini API provider.

    Wraps the synchronous ``google.generativeai`` SDK behind the
    ``BaseAIProvider`` interface. The SDK is imported lazily in
    ``initialize()`` so this module can be imported without it installed.
    """

    AVAILABLE_MODELS = [
        "gemini-2.0-flash-exp",
        "gemini-1.5-pro",
        "gemini-1.5-flash",
        "gemini-1.5-flash-8b",
        "gemini-1.0-pro",
    ]

    def __init__(self, api_key: str, model: Optional[str] = None):
        """Store the credentials via the base class; the SDK is not loaded yet."""
        super().__init__(api_key, model)
        # Set to the configured ``google.generativeai`` module by initialize().
        self._client = None

    @property
    def provider_type(self) -> AIProviderType:
        """Return the enum member identifying this provider."""
        return AIProviderType.GEMINI

    @property
    def provider_name(self) -> str:
        """Return the human-readable provider name."""
        return "Google Gemini"

    @property
    def default_model(self) -> str:
        """Return the name of the default model for this provider."""
        return "gemini-1.5-flash"

    @property
    def available_models(self) -> List[str]:
        """Return a copy of the supported model names.

        A copy is returned so callers cannot mutate the class-level list.
        """
        return list(self.AVAILABLE_MODELS)

    def initialize(self) -> bool:
        """Initialize the Gemini client.

        Returns:
            True when the SDK was imported and configured with the API key;
            False when the API key fails validation, the SDK is not
            installed, or configuration raised. Failures are reported on
            stdout rather than raised.
        """
        if not self.validate_api_key():
            return False

        try:
            # Imported here so a missing SDK is a soft failure, not an
            # import-time error for the whole module.
            import google.generativeai as genai

            genai.configure(api_key=self._api_key)
            self._client = genai
            self._is_initialized = True
            return True
        except ImportError:
            print("Google Generative AI 라이브러리가 설치되지 않았습니다. pip install google-generativeai 를 실행하세요.")
            return False
        except Exception as e:
            print(f"Gemini 초기화 실패: {e}")
            return False

    def _convert_messages_to_gemini(self, messages: List[AIMessage]) -> tuple:
        """Convert ``AIMessage`` objects into the pieces a Gemini chat needs.

        The last ``user`` message becomes the prompt that is sent separately
        with ``send_message``. Every ``user``/``assistant`` turn that comes
        before it becomes the chat history, with ``assistant`` mapped to the
        ``model`` role. Turns after the last user message (assistant replies
        without a follow-up question) are left out, because the prompt is
        re-sent and they would otherwise precede the question they answer.

        If there is no user message at all, every assistant turn is kept as
        history and the prompt is ``None``.

        When several ``system`` messages are present, the last one wins.

        Args:
            messages: Conversation in chronological order.

        Returns:
            ``(system_instruction, history, last_user_message)``
        """
        messages = list(messages)

        system_instruction = None
        last_user_index = None
        # Single pass: remember the latest system text and where the final
        # user turn sits.
        for index, msg in enumerate(messages):
            if msg.role == "system":
                system_instruction = msg.content
            elif msg.role == "user":
                last_user_index = index

        if last_user_index is None:
            earlier_turns = messages
            last_user_message = None
        else:
            earlier_turns = messages[:last_user_index]
            last_user_message = messages[last_user_index].content

        gemini_roles = {"user": "user", "assistant": "model"}
        history = [
            {"role": gemini_roles[msg.role], "parts": [msg.content]}
            for msg in earlier_turns
            if msg.role in gemini_roles
        ]

        return system_instruction, history, last_user_message

    async def chat(
        self,
        messages: List[AIMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AIResponse:
        """Send a chat request without blocking the event loop.

        The Gemini SDK call used here is synchronous, so ``chat_sync`` is
        run in the loop's default executor.

        Args:
            messages: Conversation in chronological order.
            temperature: Sampling temperature passed to the model.
            max_tokens: Optional cap on generated tokens.
            **kwargs: Forwarded unchanged to ``chat_sync``.

        Returns:
            The ``AIResponse`` produced by ``chat_sync``.
        """
        # get_running_loop() is the supported way to get the loop from
        # inside a coroutine.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.chat_sync(messages, temperature, max_tokens, **kwargs),
        )

    def chat_sync(
        self,
        messages: List[AIMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AIResponse:
        """Send a chat request synchronously.

        Args:
            messages: Conversation in chronological order.
            temperature: Sampling temperature passed to the model.
            max_tokens: Optional cap on generated tokens; omitted from the
                generation config when falsy.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            An ``AIResponse`` holding the reply text, model name, provider
            name, token usage (when the SDK reports it) and the raw response.

        Raises:
            RuntimeError: If the provider is not initialized and
                ``initialize()`` fails.
        """
        if not self._is_initialized:
            if not self.initialize():
                raise RuntimeError("Gemini 초기화 실패")

        system_instruction, history, last_user_message = self._convert_messages_to_gemini(messages)

        # Generation settings for the model.
        generation_config = {
            "temperature": temperature,
        }
        if max_tokens:
            generation_config["max_output_tokens"] = max_tokens

        # Create the model instance.
        model = self._client.GenerativeModel(
            model_name=self._model,
            generation_config=generation_config,
            system_instruction=system_instruction
        )

        # Start the conversation, seeding it with earlier turns when present.
        if history:
            chat = model.start_chat(history=history)
        else:
            chat = model.start_chat()

        # Send the final user message as the prompt.
        response = chat.send_message(last_user_message or "")

        # Try to extract token usage; not every response carries it.
        usage = None
        if hasattr(response, 'usage_metadata'):
            usage = {
                "prompt_tokens": getattr(response.usage_metadata, 'prompt_token_count', 0),
                "completion_tokens": getattr(response.usage_metadata, 'candidates_token_count', 0),
                "total_tokens": getattr(response.usage_metadata, 'total_token_count', 0),
            }

        return AIResponse(
            content=response.text,
            model=self._model,
            provider=self.provider_name,
            usage=usage,
            raw_response=response
        )
|
|
|