217 lines
8.7 KiB
Python
217 lines
8.7 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
BriaAI 배경제거 모듈 테스트 스크립트
|
|
DirectML vs CPU 모드 비교 테스트
|
|
"""
|
|
import os
|
|
import sys
|
|
import cv2
|
|
import numpy as np
|
|
from datetime import datetime
|
|
|
|
# 현재 스크립트 디렉토리 및 프로젝트 루트 설정
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
src_dir = os.path.dirname(current_dir)
|
|
project_root = os.path.dirname(src_dir)
|
|
|
|
# Python 경로에 추가
|
|
if project_root not in sys.path:
|
|
sys.path.insert(0, project_root)
|
|
|
|
# 모듈 임포트
|
|
try:
|
|
# 직접 모듈 경로에서 임포트
|
|
sys.path.insert(0, src_dir)
|
|
from src.modules.gpu_utils import GPUManager
|
|
from src.modules.request_inpaint import Request_AI_Server
|
|
print("✅ 모듈 임포트 성공")
|
|
except ImportError as e:
|
|
print(f"❌ 모듈 임포트 실패: {e}")
|
|
print(f"디버그 정보: src_dir={src_dir}, project_root={project_root}")
|
|
print(f"Python path: {sys.path[:3]}...")
|
|
sys.exit(1)
|
|
|
|
|
|
class SimpleLogger:
|
|
"""간단한 로거 클래스"""
|
|
def __init__(self, name):
|
|
self.name = name
|
|
|
|
def log(self, msg, level=None, exc_info=False):
|
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
level_str = f"[{level.name if hasattr(level, 'name') else str(level)}]" if level else ""
|
|
print(f"[{timestamp}] {level_str} {msg}")
|
|
|
|
|
|
def main():
|
|
print("=== BriaAI 배경제거 모듈 테스트 ===")
|
|
|
|
# 1. 테스트 이미지 확인
|
|
test_image_path = os.path.join(current_dir, "1.jpg")
|
|
if not os.path.exists(test_image_path):
|
|
print(f"❌ 테스트 이미지를 찾을 수 없습니다: {test_image_path}")
|
|
print("📝 같은 폴더에 '1.jpg' 파일을 복사해주세요.")
|
|
return False
|
|
|
|
print(f"✅ 테스트 이미지 발견: {test_image_path}")
|
|
|
|
# 2. 이미지 로드 및 확인
|
|
image_data = cv2.imread(test_image_path)
|
|
if image_data is None:
|
|
print(f"❌ 이미지 로드 실패: {test_image_path}")
|
|
return False
|
|
|
|
height, width = image_data.shape[:2]
|
|
print(f"✅ 이미지 로드 성공: {width}x{height}")
|
|
|
|
# 3. 로거 및 GPU 관리자 초기화
|
|
logger = SimpleLogger("BriaAITest")
|
|
|
|
try:
|
|
gpu_manager = GPUManager(logger=logger)
|
|
# DirectML GPU 모드로 강제 설정 (테스트용)
|
|
gpu_manager.can_use_cuda = True
|
|
print(f"✅ GPU 관리자 초기화 성공: CUDA 사용 가능 = {gpu_manager.can_use_cuda} (DirectML 강제 활성화)")
|
|
except Exception as e:
|
|
print(f"⚠️ GPU 관리자 초기화 실패, None으로 설정: {e}")
|
|
gpu_manager = None
|
|
|
|
# 4. BriaAI 모델 경로 설정 (사용자가 설정해야 함)
|
|
# TODO: 실제 BriaAI ONNX 모델 경로로 변경하세요
|
|
bria_model_path = os.path.join(src_dir, "modules", "bria_models", "BriaRMBG1.4_model_fp16.onnx")
|
|
if not os.path.exists(bria_model_path):
|
|
print(f"⚠️ BriaAI 모델을 찾을 수 없습니다: {bria_model_path}")
|
|
print("📝 BriaAI RMBG 1.4 ONNX 모델을 다운로드하고 경로를 설정해주세요.")
|
|
# 대체 경로들 시도
|
|
alternative_paths = [
|
|
os.path.join(src_dir, "modules", "rembg_models", "BriaRMBG1.4_model_fp16.onnx"),
|
|
os.path.join(current_dir, "BriaRMBG1.4_model_fp16.onnx"),
|
|
os.path.join(project_root, "models", "BriaRMBG1.4_model_fp16.onnx")
|
|
]
|
|
|
|
bria_model_found = False
|
|
for alt_path in alternative_paths:
|
|
if os.path.exists(alt_path):
|
|
bria_model_path = alt_path
|
|
bria_model_found = True
|
|
print(f"✅ 대체 BriaAI 모델 발견: {bria_model_path}")
|
|
break
|
|
|
|
if not bria_model_found:
|
|
print("❌ BriaAI 모델을 찾을 수 없어 테스트를 중단합니다.")
|
|
print("💡 다음 중 한 위치에 BriaAI RMBG 1.4 ONNX 모델을 배치해주세요:")
|
|
for path in alternative_paths:
|
|
print(f" - {path}")
|
|
return False
|
|
else:
|
|
print(f"✅ BriaAI 모델 발견: {bria_model_path}")
|
|
|
|
# 5. Request_AI_Server 초기화
|
|
try:
|
|
request_ai_server = Request_AI_Server(
|
|
logger=logger,
|
|
inpaint_server_url="http://192.168.0.150:8008", # 더미 URL
|
|
rembg_server_url=None, # 로컬 모드 사용
|
|
gpu_manager=gpu_manager,
|
|
local_rembg_model_path=bria_model_path # BriaAI 모델 경로
|
|
)
|
|
print("✅ Request_AI_Server 초기화 성공")
|
|
except Exception as e:
|
|
print(f"❌ Request_AI_Server 초기화 실패: {e}")
|
|
return False
|
|
|
|
# 6. BriaAI 배경제거 모듈 테스트
|
|
print("\n--- BriaAI 배경제거 모듈 테스트 시작 ---")
|
|
|
|
test_cases = [
|
|
{"model_name": "bria-rmbg-1.4", "object_ratio": 0.75, "force_cpu": False, "desc": "DirectML_GPU_BriaAI"},
|
|
{"model_name": "bria-rmbg-1.4", "object_ratio": 0.75, "force_cpu": True, "desc": "CPU_BriaAI"},
|
|
{"model_name": "bria-rmbg-aggressive", "object_ratio": 0.8, "force_cpu": True, "desc": "Aggressive_BriaAI"},
|
|
{"model_name": "bria-rmbg-gentle", "object_ratio": 0.6, "force_cpu": True, "desc": "Gentle_BriaAI"},
|
|
]
|
|
|
|
results = []
|
|
for i, test_case in enumerate(test_cases):
|
|
print(f"\n테스트 케이스 {i+1}: {test_case}")
|
|
|
|
try:
|
|
start_time = datetime.now()
|
|
|
|
result_img = request_ai_server._use_backup_rembg(
|
|
image_data=image_data,
|
|
model_name=test_case["model_name"],
|
|
object_ratio=test_case["object_ratio"],
|
|
debug_save=True,
|
|
debug_prefix=f"bria_test_{i+1}_{test_case['desc'].replace(' ', '_')}",
|
|
force_cpu=test_case["force_cpu"]
|
|
)
|
|
|
|
end_time = datetime.now()
|
|
processing_time = (end_time - start_time).total_seconds()
|
|
|
|
if result_img is not None:
|
|
# 결과 이미지 저장 (영어 파일명)
|
|
safe_desc = test_case['desc'] # 이미 영어로 설정됨
|
|
output_filename = f"bria_rembg_result_{i+1}_{safe_desc}_ratio{int(test_case['object_ratio']*100)}.png"
|
|
output_path = os.path.join(current_dir, output_filename)
|
|
|
|
success = cv2.imwrite(output_path, result_img)
|
|
if success:
|
|
result_height, result_width = result_img.shape[:2]
|
|
print(f" ✅ 성공! 결과 이미지 크기: {result_width}x{result_height}")
|
|
print(f" 📁 저장 경로: {output_path}")
|
|
print(f" ⏱️ 처리 시간: {processing_time:.2f}초")
|
|
|
|
results.append({
|
|
'case': i+1,
|
|
'success': True,
|
|
'size': f"{result_width}x{result_height}",
|
|
'time': processing_time,
|
|
'path': output_path
|
|
})
|
|
else:
|
|
print(f" ❌ 결과 이미지 저장 실패: {output_path}")
|
|
results.append({'case': i+1, 'success': False, 'error': '이미지 저장 실패'})
|
|
else:
|
|
print(f" ❌ 함수 실행 실패: 결과 이미지가 None")
|
|
results.append({'case': i+1, 'success': False, 'error': '결과 이미지 None'})
|
|
|
|
except Exception as e:
|
|
print(f" ❌ 예외 발생: {e}")
|
|
results.append({'case': i+1, 'success': False, 'error': str(e)})
|
|
|
|
# 7. 테스트 결과 요약
|
|
print(f"\n=== 테스트 결과 요약 ===")
|
|
total_tests = len(test_cases)
|
|
successful_tests = sum(1 for r in results if r['success'])
|
|
failed_tests = total_tests - successful_tests
|
|
|
|
print(f"총 테스트: {total_tests}개")
|
|
print(f"성공: {successful_tests}개")
|
|
print(f"실패: {failed_tests}개")
|
|
|
|
for i, result in enumerate(results):
|
|
if result['success']:
|
|
print(f" ✅ 케이스 {result['case']}: {result['size']}, {result['time']:.2f}초")
|
|
else:
|
|
print(f" ❌ 케이스 {result['case']}: {result['error']}")
|
|
|
|
print(f"\n🎉 테스트 완료! 결과 이미지를 확인해보세요.")
|
|
print(f"📁 결과 파일 위치: {current_dir}")
|
|
|
|
return True
|
|
|
|
|
|
if __name__ == "__main__":
|
|
try:
|
|
success = main()
|
|
sys.exit(0 if success else 1)
|
|
except KeyboardInterrupt:
|
|
print("\n⚠️ 사용자에 의해 테스트가 중단되었습니다.")
|
|
sys.exit(1)
|
|
except Exception as e:
|
|
print(f"❌ 예상치 못한 오류 발생: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
sys.exit(1)
|