203 lines
8.7 KiB
Python
203 lines
8.7 KiB
Python
import os
|
|
import time
|
|
from app.core.config import settings
|
|
from app.core.exceptions import AudioFileNotFoundError, ModelNotFoundError, STTError
|
|
from app.models.stt import STTRequest, STTResponse
|
|
from app.services.audio_parser import AudioParser
|
|
from app.services.speaker_classifier import classify_speaker
|
|
import logging
|
|
|
|
logger = logging.getLogger("uvicorn.error")
|
|
|
|
try:
|
|
from faster_whisper import WhisperModel
|
|
except ImportError:
|
|
WhisperModel = None
|
|
|
|
|
|
|
|
class WhisperSTTService:
|
|
"""
|
|
faster-whisper를 이용하여 오디오를 텍스트로 변환하는 서비스 클래스.
|
|
CPU / 내장그래픽 최적화를 위해 int8 양자화 모델을 기반으로 동작.
|
|
"""
|
|
def __init__(self):
|
|
self.model_name = settings.WHISPER_MODEL_NAME
|
|
self._model = None
|
|
|
|
def _load_model(self):
|
|
"""지연 로딩(Lazy Loading) 방식으로 모델을 초기화함."""
|
|
if WhisperModel is None:
|
|
raise ModelNotFoundError(
|
|
"faster-whisper 모듈이 설치되어 있지 않습니다. "
|
|
"'pip install faster-whisper'를 확인하세요."
|
|
)
|
|
|
|
if self._model is None:
|
|
try:
|
|
# CPU 타겟, int8 양자화로 모델 객체 생성.
|
|
# WHISPER_MODEL_PATH가 설정되어 있으면 해당 폴더를 다운로드 루트로 사용 (퀔야 3기동 후 오프라인 가능)
|
|
from app.core.config import settings
|
|
download_root = settings.WHISPER_MODEL_PATH if settings.WHISPER_MODEL_PATH else None
|
|
self._model = WhisperModel(
|
|
self.model_name,
|
|
device="cpu",
|
|
compute_type="int8",
|
|
download_root=download_root
|
|
)
|
|
except Exception as e:
|
|
raise ModelNotFoundError(f"Whisper 모델 로드 실패 ({self.model_name}): {str(e)}")
|
|
|
|
def transcribe(self, request: STTRequest) -> STTResponse:
|
|
"""
|
|
[Pydantic 모델 기반 입출력]
|
|
오디오 파일을 STT 엔진에 전달하고 변환 결과 텍스트를 반환함.
|
|
"""
|
|
if not os.path.exists(request.audio_file_path):
|
|
raise AudioFileNotFoundError(request.audio_file_path)
|
|
|
|
import tracemalloc
|
|
tracemalloc.start()
|
|
|
|
load_start = time.time()
|
|
self._load_model()
|
|
load_time = time.time() - load_start
|
|
|
|
start_time = time.time()
|
|
processed_wav = None
|
|
|
|
try:
|
|
# 1. 오디오 포맷 전처리 (m4a, mp3 -> 16kHz mono wav 변환)
|
|
processed_wav = AudioParser.preprocess_audio(request.audio_file_path)
|
|
|
|
# 오디오 길이(duration) 측정
|
|
# pydub를 사용하여 길이를 구하거나 wave 파일 헤더를 읽음 (간단히 pydub 사용)
|
|
from pydub import AudioSegment
|
|
audio_segment = AudioSegment.from_file(processed_wav)
|
|
audio_duration_sec = len(audio_segment) / 1000.0
|
|
|
|
# 사전 데이터 모듈을 통한 철도/지하철 용어 프롬프트 주입
|
|
from app.core.dictionary import domain_dict
|
|
stations_prompt = domain_dict.get_prompt()
|
|
|
|
# 2. faster-whisper API 호출
|
|
segments, info = self._model.transcribe(
|
|
processed_wav,
|
|
beam_size=5,
|
|
initial_prompt=stations_prompt,
|
|
vad_filter=True, # VAD(단말기 감지)를 켜서 빈 구간 스킵 및 속도 향상
|
|
word_timestamps=False,
|
|
condition_on_previous_text=False # 무전 특성상 이전 맥락 환각 방지
|
|
)
|
|
|
|
# 3. generator를 순회하며 텍스트 추출 및 세그먼트 매핑
|
|
from app.models.stt import STTSegment
|
|
from datetime import timedelta
|
|
|
|
stt_segments = []
|
|
final_text_parts = []
|
|
|
|
# 화자 분리 추적 변수
|
|
current_speaker = "미상"
|
|
prev_end_sec = 0.0
|
|
|
|
for segment in segments:
|
|
raw_seg_text = segment.text.strip()
|
|
if not raw_seg_text:
|
|
continue
|
|
|
|
# 개별 세그먼트에 후처리 필터 적용 (RapidFuzz 교정기)
|
|
corrected_seg_text = domain_dict.post_process_correction(raw_seg_text, threshold=88.0).strip()
|
|
|
|
# 빈 문자열로 교정된 경우 스킵
|
|
if not corrected_seg_text:
|
|
continue
|
|
|
|
final_text_parts.append(corrected_seg_text)
|
|
|
|
# 절대 시간 연산 (타임존 방어: naive datetime → KST 강제 적용)
|
|
abs_start_time = None
|
|
abs_end_time = None
|
|
if request.base_datetime:
|
|
try:
|
|
from datetime import timezone
|
|
from zoneinfo import ZoneInfo
|
|
kst = ZoneInfo("Asia/Seoul")
|
|
|
|
base = request.base_datetime
|
|
# naive datetime이면 KST로 강제 localize
|
|
if base.tzinfo is None:
|
|
base = base.replace(tzinfo=kst)
|
|
|
|
abs_start_time = (base + timedelta(seconds=segment.start)).isoformat()
|
|
abs_end_time = (base + timedelta(seconds=segment.end)).isoformat()
|
|
except Exception:
|
|
pass # 안전을 위해 예외시 None으로 둠
|
|
|
|
# ── 화자 분류: speaker_classifier (휴리스틱 → LLM JSON) ──
|
|
context_chunk = " ".join(final_text_parts[-3:-1]) if len(final_text_parts) > 1 else ""
|
|
current_speaker = classify_speaker(corrected_seg_text, context_chunk)
|
|
logger.debug(f"[STT] 세그먼트 화자: '{corrected_seg_text[:30]}' → {current_speaker}")
|
|
|
|
prev_end_sec = segment.end
|
|
|
|
# [Chapter 8.0] 단일 세그먼트 Opus 압축 인코딩
|
|
import subprocess
|
|
import uuid
|
|
os.makedirs("data/audio", exist_ok=True)
|
|
audio_path = f"data/audio/seg_{int(time.time())}_{uuid.uuid4().hex[:6]}.ogg"
|
|
try:
|
|
subprocess.run([
|
|
"ffmpeg", "-y", "-i", processed_wav,
|
|
"-ss", str(segment.start), "-to", str(segment.end),
|
|
"-c:a", "libopus", "-b:a", "16k", audio_path
|
|
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
|
|
except Exception as e:
|
|
logger.warning(f"ffmpeg Opus 인코딩 실패: {e}")
|
|
audio_path = None
|
|
|
|
stt_segments.append(
|
|
STTSegment(
|
|
start_sec=round(segment.start, 2),
|
|
end_sec=round(segment.end, 2),
|
|
text=corrected_seg_text,
|
|
speaker=current_speaker,
|
|
absolute_start_time=abs_start_time,
|
|
absolute_end_time=abs_end_time,
|
|
audio_path=audio_path
|
|
)
|
|
)
|
|
|
|
processing_time = time.time() - start_time
|
|
|
|
# 메모리 프로파일링 완료
|
|
current_mem, peak_mem = tracemalloc.get_traced_memory()
|
|
tracemalloc.stop()
|
|
|
|
# 4. Pydantic Response 반환
|
|
from app.services.analyzer import check_urgency, extract_train_number
|
|
final_text = " ".join(final_text_parts)
|
|
urgency = check_urgency(final_text)
|
|
train_number = extract_train_number(final_text)
|
|
|
|
return STTResponse(
|
|
text=final_text,
|
|
language=info.language if info else request.language,
|
|
segments=stt_segments,
|
|
processing_time_sec=round(processing_time, 2),
|
|
load_time_sec=round(load_time, 2),
|
|
audio_duration_sec=round(audio_duration_sec, 2),
|
|
peak_memory_mb=round(peak_mem / 1024 / 1024, 2),
|
|
process_speed_x=round(audio_duration_sec / processing_time, 2) if processing_time > 0 else 0,
|
|
urgency=urgency,
|
|
train_number=train_number
|
|
)
|
|
|
|
except STTError:
|
|
raise
|
|
except Exception as e:
|
|
raise STTError(f"STT 변환 실패: {str(e)}")
|
|
finally:
|
|
if processed_wav and processed_wav != request.audio_file_path:
|
|
AudioParser.cleanup(processed_wav)
|