# iop_tensorrt.py import torch from torch2trt import torch2trt, TRTModule # --- 1. 모델 정의 및 로드 (nn.Module 형태) --- # MI‑GAN 예시: 원본 모델 클래스 import 필요 from model_zoo.migan_inference import MIGAN # 실제 경로로 수정하세요 model = MIGAN().eval().cuda() # 필요 시 weight 로드 model.load_state_dict(torch.load("migan_traced_weights.pth")) # weights 파일 위치 # --- 2. 더미 입력 (4채널: RGB + 마스크) --- H, W = 512, 512 # 이미지 크기에 맞게 조정 dummy = torch.randn((1, 4, H, W)).cuda().half() # FP16 입력 가능 # --- 3. TensorRT 변환 --- model_trt = torch2trt( model, [dummy], fp16_mode=True, max_batch_size=1, default_device_type=None, # DLA 사용 시 trt.DeviceType.DLA 지정 가능 gpu_fallback=True ) # --- 4. TRTModule wrapping 및 저장 --- trt_mod = TRTModule() trt_mod.load_state_dict(model_trt.state_dict()) torch.save(trt_mod.state_dict(), "migan_trt.pth") print("✅ TensorRT 변환 완료 — migan_trt.pth 저장됨")