AutoPercenty3/test_bria_background_remova...

357 lines
14 KiB
Python

"""
BriaAI RMBG 1.4 ONNX 모델 테스트 GUI
배경제거 결과를 원본과 나란히 보여주는 테스트 도구
"""
import tkinter as tk
from tkinter import ttk, messagebox, filedialog
import cv2
import numpy as np
from PIL import Image, ImageTk
import logging
import os
import time
import threading
from typing import Optional
class SimpleLogger:
"""간단한 로거 클래스"""
def __init__(self):
self.log_text = ""
def log(self, message: str, level=logging.INFO):
timestamp = time.strftime("%H:%M:%S")
level_name = {
logging.DEBUG: "DEBUG",
logging.INFO: "INFO",
logging.WARNING: "WARN",
logging.ERROR: "ERROR"
}.get(level, "INFO")
log_line = f"[{timestamp}] {level_name}: {message}"
self.log_text += log_line + "\n"
print(log_line)
class BriaTestGUI:
def __init__(self, root):
self.root = root
self.root.title("BriaAI RMBG 1.4 테스트")
self.root.geometry("1200x800")
# 모델 파일 경로
self.model_path = "src/modules/briaaiModel/BriaRMBG1.4_model_fp16.onnx"
if not os.path.exists(self.model_path):
# 대안 경로 시도
alt_paths = [
"src/modules/briaai_Model/BriaRMBG1.4_model_fp16.onnx",
"src/modules/BriaRMBG1.4_model_fp16.onnx"
]
for alt_path in alt_paths:
if os.path.exists(alt_path):
self.model_path = alt_path
break
self.logger = SimpleLogger()
self.bria_module = None
self.current_image = None
self.result_image = None
self.setup_ui()
self.init_bria_module()
def setup_ui(self):
"""UI 설정"""
# 메인 프레임
main_frame = ttk.Frame(self.root)
main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
# 상단 컨트롤 프레임
control_frame = ttk.Frame(main_frame)
control_frame.pack(fill=tk.X, pady=(0, 10))
# 이미지 선택 버튼
ttk.Button(control_frame, text="이미지 선택", command=self.select_image).pack(side=tk.LEFT, padx=(0, 10))
# 기본 테스트 이미지 로드 버튼
ttk.Button(control_frame, text="테스트 이미지 로드", command=self.load_test_image).pack(side=tk.LEFT, padx=(0, 10))
# Aggressiveness 설정
ttk.Label(control_frame, text="Aggressiveness:").pack(side=tk.LEFT, padx=(20, 5))
self.aggressiveness_var = tk.DoubleVar(value=0.5)
aggressiveness_scale = ttk.Scale(control_frame, from_=0.0, to=1.0,
orient=tk.HORIZONTAL, length=200,
variable=self.aggressiveness_var)
aggressiveness_scale.pack(side=tk.LEFT, padx=(0, 5))
self.aggressiveness_label = ttk.Label(control_frame, text="0.5")
self.aggressiveness_label.pack(side=tk.LEFT, padx=(0, 10))
# 값 업데이트 함수
def update_aggressiveness_label(*args):
value = self.aggressiveness_var.get()
self.aggressiveness_label.config(text=f"{value:.2f}")
self.aggressiveness_var.trace('w', update_aggressiveness_label)
# 배경제거 실행 버튼
ttk.Button(control_frame, text="배경제거 실행", command=self.run_background_removal).pack(side=tk.LEFT, padx=(10, 0))
# 이미지 표시 프레임
image_frame = ttk.Frame(main_frame)
image_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10))
# 원본 이미지 프레임
original_frame = ttk.LabelFrame(image_frame, text="원본 이미지")
original_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=(0, 5))
self.original_label = ttk.Label(original_frame, text="이미지를 선택하세요")
self.original_label.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
# 결과 이미지 프레임
result_frame = ttk.LabelFrame(image_frame, text="배경제거 결과")
result_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=(5, 0))
self.result_label = ttk.Label(result_frame, text="배경제거를 실행하세요")
self.result_label.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
# 로그 프레임
log_frame = ttk.LabelFrame(main_frame, text="로그")
log_frame.pack(fill=tk.X, pady=(10, 0))
# 로그 텍스트와 스크롤바
log_text_frame = ttk.Frame(log_frame)
log_text_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
self.log_text = tk.Text(log_text_frame, height=8, wrap=tk.WORD)
log_scrollbar = ttk.Scrollbar(log_text_frame, orient=tk.VERTICAL, command=self.log_text.yview)
self.log_text.config(yscrollcommand=log_scrollbar.set)
self.log_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
log_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
# 로그 지우기 버튼
ttk.Button(log_frame, text="로그 지우기", command=self.clear_log).pack(side=tk.RIGHT, padx=10, pady=(0, 10))
def init_bria_module(self):
"""BriaAI 모듈 초기화"""
try:
# 모듈 import
import sys
sys.path.append('src/modules')
from bria_background_removal_module import BriaBackgroundRemovalModule
self.log_message(f"BriaAI 모델 파일 경로: {self.model_path}")
if not os.path.exists(self.model_path):
self.log_message(f"❌ 모델 파일을 찾을 수 없습니다: {self.model_path}", logging.ERROR)
messagebox.showerror("오류", f"모델 파일을 찾을 수 없습니다:\n{self.model_path}")
return
# BriaAI 모듈 초기화
self.bria_module = BriaBackgroundRemovalModule(
logger=self.logger,
default_model="bria-rmbg-1.4",
gpu_manager=None, # GPU 매니저 없이 테스트
local_rembg_model_path=self.model_path
)
if self.bria_module.is_available():
self.log_message("✅ BriaAI 모듈 초기화 성공")
else:
error = self.bria_module.get_init_error()
self.log_message(f"❌ BriaAI 모듈 초기화 실패: {error}", logging.ERROR)
messagebox.showerror("오류", f"BriaAI 모듈 초기화 실패:\n{error}")
except Exception as e:
self.log_message(f"❌ 모듈 import 실패: {e}", logging.ERROR)
messagebox.showerror("오류", f"모듈 import 실패:\n{e}")
def log_message(self, message: str, level=logging.INFO):
"""로그 메시지 추가"""
timestamp = time.strftime("%H:%M:%S")
level_name = {
logging.DEBUG: "DEBUG",
logging.INFO: "INFO",
logging.WARNING: "WARN",
logging.ERROR: "ERROR"
}.get(level, "INFO")
log_line = f"[{timestamp}] {level_name}: {message}\n"
# GUI 텍스트 위젯에 추가 (메인 스레드에서)
self.root.after(0, lambda: self._add_log_to_gui(log_line))
print(f"[{timestamp}] {level_name}: {message}")
def _add_log_to_gui(self, log_line: str):
"""GUI 로그 텍스트에 추가 (메인 스레드용)"""
self.log_text.insert(tk.END, log_line)
self.log_text.see(tk.END)
def clear_log(self):
"""로그 지우기"""
self.log_text.delete(1.0, tk.END)
def select_image(self):
"""이미지 파일 선택"""
file_path = filedialog.askopenfilename(
title="이미지 선택",
filetypes=[
("이미지 파일", "*.jpg *.jpeg *.png *.bmp *.tiff"),
("모든 파일", "*.*")
]
)
if file_path:
self.load_image(file_path)
def load_test_image(self):
"""기본 테스트 이미지 로드"""
test_image_path = "test_image.jpg"
if os.path.exists(test_image_path):
self.load_image(test_image_path)
else:
messagebox.showerror("오류", "test_image.jpg 파일을 찾을 수 없습니다.")
def load_image(self, image_path: str):
"""이미지 로드 및 표시"""
try:
# OpenCV로 이미지 로드
self.current_image = cv2.imread(image_path)
if self.current_image is None:
messagebox.showerror("오류", f"이미지를 로드할 수 없습니다: {image_path}")
return
self.log_message(f"📷 이미지 로드: {image_path} ({self.current_image.shape})")
# PIL로 변환하여 GUI에 표시
display_img = self.resize_for_display(self.current_image)
pil_img = Image.fromarray(cv2.cvtColor(display_img, cv2.COLOR_BGR2RGB))
photo = ImageTk.PhotoImage(pil_img)
self.original_label.config(image=photo, text="")
self.original_label.image = photo # 참조 유지
# 결과 이미지 초기화
self.result_label.config(image="", text="배경제거를 실행하세요")
self.result_image = None
except Exception as e:
self.log_message(f"❌ 이미지 로드 오류: {e}", logging.ERROR)
messagebox.showerror("오류", f"이미지 로드 오류:\n{e}")
def resize_for_display(self, img: np.ndarray, max_size: int = 400) -> np.ndarray:
"""GUI 표시용으로 이미지 리사이즈"""
h, w = img.shape[:2]
if max(h, w) <= max_size:
return img
scale = max_size / max(h, w)
new_w = int(w * scale)
new_h = int(h * scale)
return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
def run_background_removal(self):
"""배경제거 실행 (백그라운드 스레드)"""
if self.current_image is None:
messagebox.showwarning("경고", "먼저 이미지를 선택하세요.")
return
if self.bria_module is None or not self.bria_module.is_available():
messagebox.showerror("오류", "BriaAI 모듈이 사용 불가능합니다.")
return
# 백그라운드 스레드에서 실행
threading.Thread(target=self._process_background_removal, daemon=True).start()
def _process_background_removal(self):
"""실제 배경제거 처리"""
try:
aggressiveness = self.aggressiveness_var.get()
self.log_message(f"🔄 배경제거 시작 (aggressiveness={aggressiveness:.2f})")
start_time = time.time()
# 임시 파일로 저장 (BriaAI 모듈이 파일 경로를 요구하므로)
import tempfile
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
tmp_path = tmp_file.name
cv2.imwrite(tmp_path, self.current_image)
try:
# 배경제거 실행
result_pil = self.bria_module.remove_background(
tmp_path,
model_name="bria-rmbg-1.4",
aggressiveness=aggressiveness,
force_cpu=True # 안정성을 위해 CPU 모드 사용
)
if result_pil is None:
self.log_message("❌ 배경제거 실패", logging.ERROR)
return
# PIL Image를 numpy로 변환
result_array = np.array(result_pil)
if result_pil.mode == 'RGBA':
# RGBA를 BGR로 변환 (알파 채널 제거 후 흰 배경 합성)
rgb_array = result_array[:, :, :3]
alpha_array = result_array[:, :, 3] / 255.0
# 흰 배경과 합성
white_bg = np.ones_like(rgb_array) * 255
result_bgr = (rgb_array * alpha_array[:, :, np.newaxis] +
white_bg * (1 - alpha_array[:, :, np.newaxis])).astype(np.uint8)
result_bgr = cv2.cvtColor(result_bgr, cv2.COLOR_RGB2BGR)
else:
result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)
self.result_image = result_bgr
# 처리 시간 계산
end_time = time.time()
processing_time = end_time - start_time
self.log_message(f"✅ 배경제거 완료 ({processing_time:.2f}초)")
# GUI 업데이트 (메인 스레드에서)
self.root.after(0, self._update_result_display)
finally:
# 임시 파일 삭제
try:
os.unlink(tmp_path)
except:
pass
except Exception as e:
self.log_message(f"❌ 배경제거 처리 중 오류: {e}", logging.ERROR)
def _update_result_display(self):
"""결과 이미지 GUI 업데이트 (메인 스레드용)"""
if self.result_image is not None:
try:
# 표시용으로 리사이즈
display_img = self.resize_for_display(self.result_image)
pil_img = Image.fromarray(cv2.cvtColor(display_img, cv2.COLOR_BGR2RGB))
photo = ImageTk.PhotoImage(pil_img)
self.result_label.config(image=photo, text="")
self.result_label.image = photo # 참조 유지
except Exception as e:
self.log_message(f"❌ 결과 표시 오류: {e}", logging.ERROR)
def main():
"""메인 함수"""
root = tk.Tk()
app = BriaTestGUI(root)
root.mainloop()
if __name__ == "__main__":
main()