AI_MMI_Analyser/app/ai/providers/gemini_provider.py

157 lines
5.0 KiB
Python

"""
Google Gemini Provider 구현
"""
import asyncio
from typing import Optional, List, Dict, Any
from ..base import BaseAIProvider, AIProviderType, AIMessage, AIResponse
class GeminiProvider(BaseAIProvider):
    """AI provider backed by the Google Gemini API (``google-generativeai`` SDK)."""

    # Model identifiers offered by this provider; exposed via ``available_models``.
    AVAILABLE_MODELS = [
        "gemini-2.0-flash-exp",
        "gemini-1.5-pro",
        "gemini-1.5-flash",
        "gemini-1.5-flash-8b",
        "gemini-1.0-pro",
    ]

    def __init__(self, api_key: str, model: Optional[str] = None):
        """Create the provider.

        Args:
            api_key: Gemini API key, passed through to ``BaseAIProvider``.
            model: Model name; handling of ``None`` is left to ``BaseAIProvider``
                (presumably it falls back to ``default_model`` -- verify in base).
        """
        super().__init__(api_key, model)
        # The SDK module is bound lazily by ``initialize``.
        self._client = None

    @property
    def provider_type(self) -> AIProviderType:
        """Return the ``AIProviderType.GEMINI`` tag."""
        return AIProviderType.GEMINI

    @property
    def provider_name(self) -> str:
        """Human-readable provider name, also stored on each ``AIResponse``."""
        return "Google Gemini"

    @property
    def default_model(self) -> str:
        """Model name this provider reports as its default."""
        return "gemini-1.5-flash"

    @property
    def available_models(self) -> List[str]:
        """Return a copy of ``AVAILABLE_MODELS`` so callers cannot mutate the class list."""
        return self.AVAILABLE_MODELS.copy()

    def initialize(self) -> bool:
        """Configure the Gemini SDK with the API key.

        Returns:
            True on success. False if ``validate_api_key()`` rejects the key,
            the ``google-generativeai`` package is missing, or configuration
            raises (the error is printed, not re-raised).
        """
        if not self.validate_api_key():
            return False
        try:
            # Imported lazily so the SDK is only required when Gemini is actually used.
            import google.generativeai as genai
            genai.configure(api_key=self._api_key)
            self._client = genai
            self._is_initialized = True
            return True
        except ImportError:
            print("Google Generative AI 라이브러리가 설치되지 않았습니다. pip install google-generativeai 를 실행하세요.")
            return False
        except Exception as e:
            print(f"Gemini 초기화 실패: {e}")
            return False

    def _convert_messages_to_gemini(self, messages: List[AIMessage]) -> tuple:
        """Split ``messages`` into the pieces the Gemini chat API expects.

        - ``system`` messages become the system instruction (the last one wins).
        - The final ``user`` message is returned on its own so it can be sent
          with ``send_message``.
        - ``user`` / ``assistant`` messages that precede the final user message
          become the chat history, with ``assistant`` mapped to Gemini's
          ``model`` role.

        Messages after the final user message (e.g. a trailing assistant turn)
        are not placed in the history, so the final user message is never sent
        twice. When there is no user message at all, every user/assistant
        message goes into the history. Messages with other roles are ignored.

        Returns:
            (system_instruction, history, last_user_message); the first and
            last items are ``None`` when no such message exists.
        """
        system_instruction = None
        last_user_index = -1  # index of the final user turn; -1 when none exists
        for index, msg in enumerate(messages):
            if msg.role == "system":
                system_instruction = msg.content
            elif msg.role == "user":
                last_user_index = index

        if last_user_index >= 0:
            last_user_message = messages[last_user_index].content
            preceding = messages[:last_user_index]
        else:
            last_user_message = None
            preceding = messages

        gemini_roles = {"user": "user", "assistant": "model"}
        history = [
            {"role": gemini_roles[msg.role], "parts": [msg.content]}
            for msg in preceding
            if msg.role in gemini_roles
        ]
        return system_instruction, history, last_user_message

    async def chat(
        self,
        messages: List[AIMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AIResponse:
        """Asynchronous chat request; delegates to ``chat_sync`` in an executor.

        Arguments and return value are the same as ``chat_sync``.
        """
        # The SDK calls made by chat_sync are blocking, so run them in the
        # loop's default executor instead of on the event-loop thread.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.chat_sync(messages, temperature, max_tokens, **kwargs),
        )

    def chat_sync(
        self,
        messages: List[AIMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AIResponse:
        """Send ``messages`` to Gemini and block until the reply arrives.

        Args:
            messages: Conversation to send (see ``_convert_messages_to_gemini``).
            temperature: Sampling temperature, forwarded in ``generation_config``.
            max_tokens: If truthy, forwarded as ``max_output_tokens``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            An ``AIResponse`` with the reply text, model name, provider name,
            token usage (when the response has ``usage_metadata``), and the raw
            SDK response.

        Raises:
            RuntimeError: If the client is not initialized and ``initialize()``
                fails.
        """
        if not self._is_initialized and not self.initialize():
            raise RuntimeError("Gemini 초기화 실패")

        system_instruction, history, last_user_message = self._convert_messages_to_gemini(messages)

        generation_config = {
            "temperature": temperature,
        }
        if max_tokens:
            generation_config["max_output_tokens"] = max_tokens

        # A model instance is built per request because the system instruction
        # and generation config can differ between calls.
        model = self._client.GenerativeModel(
            model_name=self._model,
            generation_config=generation_config,
            system_instruction=system_instruction
        )

        if history:
            chat = model.start_chat(history=history)
        else:
            chat = model.start_chat()

        response = chat.send_message(last_user_message or "")

        # Token usage is optional on the SDK response; missing fields default to 0.
        usage = None
        if hasattr(response, 'usage_metadata'):
            usage = {
                "prompt_tokens": getattr(response.usage_metadata, 'prompt_token_count', 0),
                "completion_tokens": getattr(response.usage_metadata, 'candidates_token_count', 0),
                "total_tokens": getattr(response.usage_metadata, 'total_token_count', 0),
            }

        # NOTE(review): in google-generativeai, ``response.text`` is documented to
        # raise ValueError when the reply has no text part (e.g. blocked by safety
        # filters); that exception propagates to the caller -- confirm desired.
        return AIResponse(
            content=response.text,
            model=self._model,
            provider=self.provider_name,
            usage=usage,
            raw_response=response
        )