106 lines
3.9 KiB
Python
106 lines
3.9 KiB
Python
import os
|
|
import pytest
|
|
from app.models.stt import STTRequest, STTResponse
|
|
from app.core.exceptions import AudioFileNotFoundError, ModelNotFoundError
|
|
from app.services.stt_service import WhisperSTTService
|
|
|
|
class MockSegment:
|
|
def __init__(self, text):
|
|
self.text = text
|
|
|
|
class MockInfo:
|
|
def __init__(self, language):
|
|
self.language = language
|
|
|
|
class MockWhisperModel:
|
|
"""faster-whisper 엔진을 흉내내는 테스트용 Mock 클래스"""
|
|
def __init__(self, model_size_or_path, device="cpu", compute_type="int8"):
|
|
self.model_size_or_path = model_size_or_path
|
|
|
|
def transcribe(self, audio_path, beam_size=5):
|
|
# 의도된 테스트 출력 결과 (무전 기록)
|
|
segments = [MockSegment("열차 번호 1234, 기지 입고 확인했습니다. 수신 양호합니다.")]
|
|
info = MockInfo(language="ko")
|
|
return segments, info
|
|
|
|
def test_pydantic_schema_validation():
|
|
"""Pydantic 모델이 오염된 데이터를 차단하고 정상 생성되는지 테스트"""
|
|
req = STTRequest(audio_file_path="dummy.wav")
|
|
assert req.audio_file_path == "dummy.wav"
|
|
assert req.language == "ko" # Default 값 확인
|
|
|
|
def test_stt_transcription_success(monkeypatch, tmp_path):
|
|
"""STT 전체 프로세스가 정상적으로 처리되는지 확인하는 테스트"""
|
|
|
|
# 1. 테스트용 더미 오디오
|
|
dummy_audio = tmp_path / "test_audio.wav"
|
|
dummy_audio.touch()
|
|
|
|
# 2. 서비스 인스턴스 초기화
|
|
service = WhisperSTTService()
|
|
service.model_name = "test_base"
|
|
|
|
# 3. WhisperModel 및 AudioParser Mock 객체로 대체(Patch)
|
|
monkeypatch.setattr("app.services.stt_service.WhisperModel", MockWhisperModel)
|
|
|
|
# AudioParser는 변환 대신 그대로 리턴하게 모킹
|
|
class MockAudioParser:
|
|
@staticmethod
|
|
def preprocess_audio(file_path, sample_rate=16000):
|
|
return file_path
|
|
@staticmethod
|
|
def cleanup(file_path):
|
|
pass
|
|
|
|
monkeypatch.setattr("app.services.stt_service.AudioParser", MockAudioParser)
|
|
|
|
# 4. Request 생성 및 실행
|
|
request = STTRequest(audio_file_path=str(dummy_audio))
|
|
response = service.transcribe(request)
|
|
|
|
# 5. Response 검증 (Pydantic 모델인지, 텍스트가 정상인지 검사)
|
|
assert isinstance(response, STTResponse)
|
|
assert "열차 번호 1234" in response.text
|
|
assert response.language == "ko"
|
|
assert response.processing_time_sec is not None
|
|
|
|
def test_stt_audio_file_not_found():
|
|
"""존재하지 않는 오디오 파일 요청 시 커스텀 예외 발생 테스트"""
|
|
service = WhisperSTTService()
|
|
request = STTRequest(audio_file_path="invalid_path.wav")
|
|
|
|
with pytest.raises(AudioFileNotFoundError) as exc_info:
|
|
service.transcribe(request)
|
|
|
|
assert "오디오 파일을 찾을 수 없습니다" in str(exc_info.value)
|
|
|
|
def test_stt_model_file_not_found(monkeypatch, tmp_path):
|
|
"""Whisper 모델 로딩 실패 시 발생하는 예외 테스트"""
|
|
dummy_audio = tmp_path / "test_audio.wav"
|
|
dummy_audio.touch()
|
|
|
|
# AudioParser 모킹
|
|
class MockAudioParser:
|
|
@staticmethod
|
|
def preprocess_audio(file_path, sample_rate=16000):
|
|
return file_path
|
|
@staticmethod
|
|
def cleanup(file_path):
|
|
pass
|
|
|
|
monkeypatch.setattr("app.services.stt_service.AudioParser", MockAudioParser)
|
|
|
|
# 모델 로딩을 강제로 실패하게 하는 모킹
|
|
def mock_init(*args, **kwargs):
|
|
raise ValueError("Simulated model loading error")
|
|
|
|
monkeypatch.setattr("app.services.stt_service.WhisperModel", mock_init)
|
|
|
|
service = WhisperSTTService()
|
|
service.model_name = "invalid_model_path_or_name"
|
|
|
|
request = STTRequest(audio_file_path=str(dummy_audio))
|
|
|
|
with pytest.raises(ModelNotFoundError) as exc_info:
|
|
service.transcribe(request)
|