HUTAMS_AUDIO/app/services/stt_service.py

195 lines
8.3 KiB
Python

import os
import time
from app.core.config import settings
from app.core.exceptions import AudioFileNotFoundError, ModelNotFoundError, STTError
from app.models.stt import STTRequest, STTResponse
from app.services.audio_parser import AudioParser
from app.services.speaker_classifier import classify_speaker
import logging
logger = logging.getLogger("uvicorn.error")
try:
from faster_whisper import WhisperModel
except ImportError:
WhisperModel = None
class WhisperSTTService:
"""
faster-whisper를 이용하여 오디오를 텍스트로 변환하는 서비스 클래스.
CPU / 내장그래픽 최적화를 위해 int8 양자화 모델을 기반으로 동작.
"""
def __init__(self):
self.model_name = settings.WHISPER_MODEL_NAME
self._model = None
def _load_model(self):
"""지연 로딩(Lazy Loading) 방식으로 모델을 초기화함."""
if WhisperModel is None:
raise ModelNotFoundError(
"faster-whisper 모듈이 설치되어 있지 않습니다. "
"'pip install faster-whisper'를 확인하세요."
)
if self._model is None:
try:
# CPU 타겟, int8 양자화로 모델 객체 생성. (최초 시도시 모델 자동 다운로드됨)
self._model = WhisperModel(self.model_name, device="cpu", compute_type="int8")
except Exception as e:
raise ModelNotFoundError(f"Whisper 모델 로드 실패 ({self.model_name}): {str(e)}")
def transcribe(self, request: STTRequest) -> STTResponse:
"""
[Pydantic 모델 기반 입출력]
오디오 파일을 STT 엔진에 전달하고 변환 결과 텍스트를 반환함.
"""
if not os.path.exists(request.audio_file_path):
raise AudioFileNotFoundError(request.audio_file_path)
import tracemalloc
tracemalloc.start()
load_start = time.time()
self._load_model()
load_time = time.time() - load_start
start_time = time.time()
processed_wav = None
try:
# 1. 오디오 포맷 전처리 (m4a, mp3 -> 16kHz mono wav 변환)
processed_wav = AudioParser.preprocess_audio(request.audio_file_path)
# 오디오 길이(duration) 측정
# pydub를 사용하여 길이를 구하거나 wave 파일 헤더를 읽음 (간단히 pydub 사용)
from pydub import AudioSegment
audio_segment = AudioSegment.from_file(processed_wav)
audio_duration_sec = len(audio_segment) / 1000.0
# 사전 데이터 모듈을 통한 철도/지하철 용어 프롬프트 주입
from app.core.dictionary import domain_dict
stations_prompt = domain_dict.get_prompt()
# 2. faster-whisper API 호출
segments, info = self._model.transcribe(
processed_wav,
beam_size=5,
initial_prompt=stations_prompt,
vad_filter=True, # VAD(단말기 감지)를 켜서 빈 구간 스킵 및 속도 향상
word_timestamps=False,
condition_on_previous_text=False # 무전 특성상 이전 맥락 환각 방지
)
# 3. generator를 순회하며 텍스트 추출 및 세그먼트 매핑
from app.models.stt import STTSegment
from datetime import timedelta
stt_segments = []
final_text_parts = []
# 화자 분리 추적 변수
current_speaker = "미상"
prev_end_sec = 0.0
for segment in segments:
raw_seg_text = segment.text.strip()
if not raw_seg_text:
continue
# 개별 세그먼트에 후처리 필터 적용 (RapidFuzz 교정기)
corrected_seg_text = domain_dict.post_process_correction(raw_seg_text, threshold=88.0).strip()
# 빈 문자열로 교정된 경우 스킵
if not corrected_seg_text:
continue
final_text_parts.append(corrected_seg_text)
# 절대 시간 연산 (타임존 방어: naive datetime → KST 강제 적용)
abs_start_time = None
abs_end_time = None
if request.base_datetime:
try:
from datetime import timezone
from zoneinfo import ZoneInfo
kst = ZoneInfo("Asia/Seoul")
base = request.base_datetime
# naive datetime이면 KST로 강제 localize
if base.tzinfo is None:
base = base.replace(tzinfo=kst)
abs_start_time = (base + timedelta(seconds=segment.start)).isoformat()
abs_end_time = (base + timedelta(seconds=segment.end)).isoformat()
except Exception:
pass # 안전을 위해 예외시 None으로 둠
# ── 화자 분류: speaker_classifier (휴리스틱 → LLM JSON) ──
context_chunk = " ".join(final_text_parts[-3:-1]) if len(final_text_parts) > 1 else ""
current_speaker = classify_speaker(corrected_seg_text, context_chunk)
logger.debug(f"[STT] 세그먼트 화자: '{corrected_seg_text[:30]}'{current_speaker}")
prev_end_sec = segment.end
# [Chapter 8.0] 단일 세그먼트 Opus 압축 인코딩
import subprocess
import uuid
os.makedirs("data/audio", exist_ok=True)
audio_path = f"data/audio/seg_{int(time.time())}_{uuid.uuid4().hex[:6]}.ogg"
try:
subprocess.run([
"ffmpeg", "-y", "-i", processed_wav,
"-ss", str(segment.start), "-to", str(segment.end),
"-c:a", "libopus", "-b:a", "16k", audio_path
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
except Exception as e:
logger.warning(f"ffmpeg Opus 인코딩 실패: {e}")
audio_path = None
stt_segments.append(
STTSegment(
start_sec=round(segment.start, 2),
end_sec=round(segment.end, 2),
text=corrected_seg_text,
speaker=current_speaker,
absolute_start_time=abs_start_time,
absolute_end_time=abs_end_time,
audio_path=audio_path
)
)
processing_time = time.time() - start_time
# 메모리 프로파일링 완료
current_mem, peak_mem = tracemalloc.get_traced_memory()
tracemalloc.stop()
# 4. Pydantic Response 반환
from app.services.analyzer import check_urgency, extract_train_number
final_text = " ".join(final_text_parts)
urgency = check_urgency(final_text)
train_number = extract_train_number(final_text)
return STTResponse(
text=final_text,
language=info.language if info else request.language,
segments=stt_segments,
processing_time_sec=round(processing_time, 2),
load_time_sec=round(load_time, 2),
audio_duration_sec=round(audio_duration_sec, 2),
peak_memory_mb=round(peak_mem / 1024 / 1024, 2),
process_speed_x=round(audio_duration_sec / processing_time, 2) if processing_time > 0 else 0,
urgency=urgency,
train_number=train_number
)
except STTError:
raise
except Exception as e:
raise STTError(f"STT 변환 실패: {str(e)}")
finally:
if processed_wav and processed_wav != request.audio_file_path:
AudioParser.cleanup(processed_wav)