357 lines
14 KiB
Python
357 lines
14 KiB
Python
"""
|
|
BriaAI RMBG 1.4 ONNX 모델 테스트 GUI
|
|
배경제거 결과를 원본과 나란히 보여주는 테스트 도구
|
|
"""
|
|
import tkinter as tk
|
|
from tkinter import ttk, messagebox, filedialog
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image, ImageTk
|
|
import logging
|
|
import os
|
|
import time
|
|
import threading
|
|
from typing import Optional
|
|
|
|
|
|
class SimpleLogger:
|
|
"""간단한 로거 클래스"""
|
|
def __init__(self):
|
|
self.log_text = ""
|
|
|
|
def log(self, message: str, level=logging.INFO):
|
|
timestamp = time.strftime("%H:%M:%S")
|
|
level_name = {
|
|
logging.DEBUG: "DEBUG",
|
|
logging.INFO: "INFO",
|
|
logging.WARNING: "WARN",
|
|
logging.ERROR: "ERROR"
|
|
}.get(level, "INFO")
|
|
|
|
log_line = f"[{timestamp}] {level_name}: {message}"
|
|
self.log_text += log_line + "\n"
|
|
print(log_line)
|
|
|
|
|
|
class BriaTestGUI:
|
|
def __init__(self, root):
|
|
self.root = root
|
|
self.root.title("BriaAI RMBG 1.4 테스트")
|
|
self.root.geometry("1200x800")
|
|
|
|
# 모델 파일 경로
|
|
self.model_path = "src/modules/briaaiModel/BriaRMBG1.4_model_fp16.onnx"
|
|
if not os.path.exists(self.model_path):
|
|
# 대안 경로 시도
|
|
alt_paths = [
|
|
"src/modules/briaai_Model/BriaRMBG1.4_model_fp16.onnx",
|
|
"src/modules/BriaRMBG1.4_model_fp16.onnx"
|
|
]
|
|
for alt_path in alt_paths:
|
|
if os.path.exists(alt_path):
|
|
self.model_path = alt_path
|
|
break
|
|
|
|
self.logger = SimpleLogger()
|
|
self.bria_module = None
|
|
self.current_image = None
|
|
self.result_image = None
|
|
|
|
self.setup_ui()
|
|
self.init_bria_module()
|
|
|
|
def setup_ui(self):
|
|
"""UI 설정"""
|
|
# 메인 프레임
|
|
main_frame = ttk.Frame(self.root)
|
|
main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
|
|
|
# 상단 컨트롤 프레임
|
|
control_frame = ttk.Frame(main_frame)
|
|
control_frame.pack(fill=tk.X, pady=(0, 10))
|
|
|
|
# 이미지 선택 버튼
|
|
ttk.Button(control_frame, text="이미지 선택", command=self.select_image).pack(side=tk.LEFT, padx=(0, 10))
|
|
|
|
# 기본 테스트 이미지 로드 버튼
|
|
ttk.Button(control_frame, text="테스트 이미지 로드", command=self.load_test_image).pack(side=tk.LEFT, padx=(0, 10))
|
|
|
|
# Aggressiveness 설정
|
|
ttk.Label(control_frame, text="Aggressiveness:").pack(side=tk.LEFT, padx=(20, 5))
|
|
self.aggressiveness_var = tk.DoubleVar(value=0.5)
|
|
aggressiveness_scale = ttk.Scale(control_frame, from_=0.0, to=1.0,
|
|
orient=tk.HORIZONTAL, length=200,
|
|
variable=self.aggressiveness_var)
|
|
aggressiveness_scale.pack(side=tk.LEFT, padx=(0, 5))
|
|
|
|
self.aggressiveness_label = ttk.Label(control_frame, text="0.5")
|
|
self.aggressiveness_label.pack(side=tk.LEFT, padx=(0, 10))
|
|
|
|
# 값 업데이트 함수
|
|
def update_aggressiveness_label(*args):
|
|
value = self.aggressiveness_var.get()
|
|
self.aggressiveness_label.config(text=f"{value:.2f}")
|
|
|
|
self.aggressiveness_var.trace('w', update_aggressiveness_label)
|
|
|
|
# 배경제거 실행 버튼
|
|
ttk.Button(control_frame, text="배경제거 실행", command=self.run_background_removal).pack(side=tk.LEFT, padx=(10, 0))
|
|
|
|
# 이미지 표시 프레임
|
|
image_frame = ttk.Frame(main_frame)
|
|
image_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10))
|
|
|
|
# 원본 이미지 프레임
|
|
original_frame = ttk.LabelFrame(image_frame, text="원본 이미지")
|
|
original_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=(0, 5))
|
|
|
|
self.original_label = ttk.Label(original_frame, text="이미지를 선택하세요")
|
|
self.original_label.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
|
|
|
# 결과 이미지 프레임
|
|
result_frame = ttk.LabelFrame(image_frame, text="배경제거 결과")
|
|
result_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=(5, 0))
|
|
|
|
self.result_label = ttk.Label(result_frame, text="배경제거를 실행하세요")
|
|
self.result_label.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
|
|
|
# 로그 프레임
|
|
log_frame = ttk.LabelFrame(main_frame, text="로그")
|
|
log_frame.pack(fill=tk.X, pady=(10, 0))
|
|
|
|
# 로그 텍스트와 스크롤바
|
|
log_text_frame = ttk.Frame(log_frame)
|
|
log_text_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
|
|
|
self.log_text = tk.Text(log_text_frame, height=8, wrap=tk.WORD)
|
|
log_scrollbar = ttk.Scrollbar(log_text_frame, orient=tk.VERTICAL, command=self.log_text.yview)
|
|
self.log_text.config(yscrollcommand=log_scrollbar.set)
|
|
|
|
self.log_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
|
|
log_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
|
|
|
|
# 로그 지우기 버튼
|
|
ttk.Button(log_frame, text="로그 지우기", command=self.clear_log).pack(side=tk.RIGHT, padx=10, pady=(0, 10))
|
|
|
|
def init_bria_module(self):
|
|
"""BriaAI 모듈 초기화"""
|
|
try:
|
|
# 모듈 import
|
|
import sys
|
|
sys.path.append('src/modules')
|
|
from bria_background_removal_module import BriaBackgroundRemovalModule
|
|
|
|
self.log_message(f"BriaAI 모델 파일 경로: {self.model_path}")
|
|
|
|
if not os.path.exists(self.model_path):
|
|
self.log_message(f"❌ 모델 파일을 찾을 수 없습니다: {self.model_path}", logging.ERROR)
|
|
messagebox.showerror("오류", f"모델 파일을 찾을 수 없습니다:\n{self.model_path}")
|
|
return
|
|
|
|
# BriaAI 모듈 초기화
|
|
self.bria_module = BriaBackgroundRemovalModule(
|
|
logger=self.logger,
|
|
default_model="bria-rmbg-1.4",
|
|
gpu_manager=None, # GPU 매니저 없이 테스트
|
|
local_rembg_model_path=self.model_path
|
|
)
|
|
|
|
if self.bria_module.is_available():
|
|
self.log_message("✅ BriaAI 모듈 초기화 성공")
|
|
else:
|
|
error = self.bria_module.get_init_error()
|
|
self.log_message(f"❌ BriaAI 모듈 초기화 실패: {error}", logging.ERROR)
|
|
messagebox.showerror("오류", f"BriaAI 모듈 초기화 실패:\n{error}")
|
|
|
|
except Exception as e:
|
|
self.log_message(f"❌ 모듈 import 실패: {e}", logging.ERROR)
|
|
messagebox.showerror("오류", f"모듈 import 실패:\n{e}")
|
|
|
|
def log_message(self, message: str, level=logging.INFO):
|
|
"""로그 메시지 추가"""
|
|
timestamp = time.strftime("%H:%M:%S")
|
|
level_name = {
|
|
logging.DEBUG: "DEBUG",
|
|
logging.INFO: "INFO",
|
|
logging.WARNING: "WARN",
|
|
logging.ERROR: "ERROR"
|
|
}.get(level, "INFO")
|
|
|
|
log_line = f"[{timestamp}] {level_name}: {message}\n"
|
|
|
|
# GUI 텍스트 위젯에 추가 (메인 스레드에서)
|
|
self.root.after(0, lambda: self._add_log_to_gui(log_line))
|
|
print(f"[{timestamp}] {level_name}: {message}")
|
|
|
|
def _add_log_to_gui(self, log_line: str):
|
|
"""GUI 로그 텍스트에 추가 (메인 스레드용)"""
|
|
self.log_text.insert(tk.END, log_line)
|
|
self.log_text.see(tk.END)
|
|
|
|
def clear_log(self):
|
|
"""로그 지우기"""
|
|
self.log_text.delete(1.0, tk.END)
|
|
|
|
def select_image(self):
|
|
"""이미지 파일 선택"""
|
|
file_path = filedialog.askopenfilename(
|
|
title="이미지 선택",
|
|
filetypes=[
|
|
("이미지 파일", "*.jpg *.jpeg *.png *.bmp *.tiff"),
|
|
("모든 파일", "*.*")
|
|
]
|
|
)
|
|
|
|
if file_path:
|
|
self.load_image(file_path)
|
|
|
|
def load_test_image(self):
|
|
"""기본 테스트 이미지 로드"""
|
|
test_image_path = "test_image.jpg"
|
|
if os.path.exists(test_image_path):
|
|
self.load_image(test_image_path)
|
|
else:
|
|
messagebox.showerror("오류", "test_image.jpg 파일을 찾을 수 없습니다.")
|
|
|
|
def load_image(self, image_path: str):
|
|
"""이미지 로드 및 표시"""
|
|
try:
|
|
# OpenCV로 이미지 로드
|
|
self.current_image = cv2.imread(image_path)
|
|
if self.current_image is None:
|
|
messagebox.showerror("오류", f"이미지를 로드할 수 없습니다: {image_path}")
|
|
return
|
|
|
|
self.log_message(f"📷 이미지 로드: {image_path} ({self.current_image.shape})")
|
|
|
|
# PIL로 변환하여 GUI에 표시
|
|
display_img = self.resize_for_display(self.current_image)
|
|
pil_img = Image.fromarray(cv2.cvtColor(display_img, cv2.COLOR_BGR2RGB))
|
|
photo = ImageTk.PhotoImage(pil_img)
|
|
|
|
self.original_label.config(image=photo, text="")
|
|
self.original_label.image = photo # 참조 유지
|
|
|
|
# 결과 이미지 초기화
|
|
self.result_label.config(image="", text="배경제거를 실행하세요")
|
|
self.result_image = None
|
|
|
|
except Exception as e:
|
|
self.log_message(f"❌ 이미지 로드 오류: {e}", logging.ERROR)
|
|
messagebox.showerror("오류", f"이미지 로드 오류:\n{e}")
|
|
|
|
def resize_for_display(self, img: np.ndarray, max_size: int = 400) -> np.ndarray:
|
|
"""GUI 표시용으로 이미지 리사이즈"""
|
|
h, w = img.shape[:2]
|
|
if max(h, w) <= max_size:
|
|
return img
|
|
|
|
scale = max_size / max(h, w)
|
|
new_w = int(w * scale)
|
|
new_h = int(h * scale)
|
|
|
|
return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
|
|
|
def run_background_removal(self):
|
|
"""배경제거 실행 (백그라운드 스레드)"""
|
|
if self.current_image is None:
|
|
messagebox.showwarning("경고", "먼저 이미지를 선택하세요.")
|
|
return
|
|
|
|
if self.bria_module is None or not self.bria_module.is_available():
|
|
messagebox.showerror("오류", "BriaAI 모듈이 사용 불가능합니다.")
|
|
return
|
|
|
|
# 백그라운드 스레드에서 실행
|
|
threading.Thread(target=self._process_background_removal, daemon=True).start()
|
|
|
|
def _process_background_removal(self):
|
|
"""실제 배경제거 처리"""
|
|
try:
|
|
aggressiveness = self.aggressiveness_var.get()
|
|
self.log_message(f"🔄 배경제거 시작 (aggressiveness={aggressiveness:.2f})")
|
|
|
|
start_time = time.time()
|
|
|
|
# 임시 파일로 저장 (BriaAI 모듈이 파일 경로를 요구하므로)
|
|
import tempfile
|
|
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
|
|
tmp_path = tmp_file.name
|
|
cv2.imwrite(tmp_path, self.current_image)
|
|
|
|
try:
|
|
# 배경제거 실행
|
|
result_pil = self.bria_module.remove_background(
|
|
tmp_path,
|
|
model_name="bria-rmbg-1.4",
|
|
aggressiveness=aggressiveness,
|
|
force_cpu=True # 안정성을 위해 CPU 모드 사용
|
|
)
|
|
|
|
if result_pil is None:
|
|
self.log_message("❌ 배경제거 실패", logging.ERROR)
|
|
return
|
|
|
|
# PIL Image를 numpy로 변환
|
|
result_array = np.array(result_pil)
|
|
|
|
if result_pil.mode == 'RGBA':
|
|
# RGBA를 BGR로 변환 (알파 채널 제거 후 흰 배경 합성)
|
|
rgb_array = result_array[:, :, :3]
|
|
alpha_array = result_array[:, :, 3] / 255.0
|
|
|
|
# 흰 배경과 합성
|
|
white_bg = np.ones_like(rgb_array) * 255
|
|
result_bgr = (rgb_array * alpha_array[:, :, np.newaxis] +
|
|
white_bg * (1 - alpha_array[:, :, np.newaxis])).astype(np.uint8)
|
|
result_bgr = cv2.cvtColor(result_bgr, cv2.COLOR_RGB2BGR)
|
|
else:
|
|
result_bgr = cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR)
|
|
|
|
self.result_image = result_bgr
|
|
|
|
# 처리 시간 계산
|
|
end_time = time.time()
|
|
processing_time = end_time - start_time
|
|
|
|
self.log_message(f"✅ 배경제거 완료 ({processing_time:.2f}초)")
|
|
|
|
# GUI 업데이트 (메인 스레드에서)
|
|
self.root.after(0, self._update_result_display)
|
|
|
|
finally:
|
|
# 임시 파일 삭제
|
|
try:
|
|
os.unlink(tmp_path)
|
|
except:
|
|
pass
|
|
|
|
except Exception as e:
|
|
self.log_message(f"❌ 배경제거 처리 중 오류: {e}", logging.ERROR)
|
|
|
|
def _update_result_display(self):
|
|
"""결과 이미지 GUI 업데이트 (메인 스레드용)"""
|
|
if self.result_image is not None:
|
|
try:
|
|
# 표시용으로 리사이즈
|
|
display_img = self.resize_for_display(self.result_image)
|
|
pil_img = Image.fromarray(cv2.cvtColor(display_img, cv2.COLOR_BGR2RGB))
|
|
photo = ImageTk.PhotoImage(pil_img)
|
|
|
|
self.result_label.config(image=photo, text="")
|
|
self.result_label.image = photo # 참조 유지
|
|
|
|
except Exception as e:
|
|
self.log_message(f"❌ 결과 표시 오류: {e}", logging.ERROR)
|
|
|
|
|
|
def main():
|
|
"""메인 함수"""
|
|
root = tk.Tk()
|
|
app = BriaTestGUI(root)
|
|
root.mainloop()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|