16 lines
464 B
Python
16 lines
464 B
Python
# RETHINED 모델 export.py
|
|
import torch
|
|
from rethined import RethinedModel # 논문 제공 코드 임포트
|
|
|
|
model = RethinedModel().eval().to('cuda')
|
|
H, W = 512, 512 # 또는 원하는 해상도
|
|
dummy = torch.randn(1, 3, H, W, device='cuda')
|
|
|
|
torch.onnx.export(
|
|
model, dummy, "rethined.onnx",
|
|
opset_version=14,
|
|
input_names=['input'],
|
|
output_names=['output'],
|
|
dynamic_axes={'input':{2:'height',3:'width'}, 'output':{2:'height',3:'width'}}
|
|
)
|