TransWorker/re/export.py

16 lines
464 B
Python

# RETHINED 모델 export.py
import torch
from rethined import RethinedModel # 논문 제공 코드 임포트
model = RethinedModel().eval().to('cuda')
H, W = 512, 512 # 또는 원하는 해상도
dummy = torch.randn(1, 3, H, W, device='cuda')
torch.onnx.export(
model, dummy, "rethined.onnx",
opset_version=14,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input':{2:'height',3:'width'}, 'output':{2:'height',3:'width'}}
)