168 lines
5.4 KiB
Python
168 lines
5.4 KiB
Python
import argparse, os, sys
|
|
import onnx
|
|
from onnx import shape_inference, checker
|
|
import onnxoptimizer
|
|
|
|
# onnx-simplifier
|
|
try:
|
|
from onnxsim import simplify as onnx_simplify
|
|
except Exception:
|
|
onnx_simplify = None
|
|
|
|
# FP16 변환
|
|
try:
|
|
from onnxconverter_common import float16
|
|
except Exception as e:
|
|
print("[WARN] onnxconverter-common 미설치 혹은 구버전: FP16 변환은 생략됩니다.", e)
|
|
float16 = None
|
|
|
|
# (선택) ORT 동적 양자화
|
|
try:
|
|
from onnxruntime.quantization import quantize_dynamic, QuantType
|
|
except Exception:
|
|
quantize_dynamic, QuantType = None, None
|
|
|
|
|
|
def load_model(path):
|
|
m = onnx.load(path)
|
|
checker.check_model(m)
|
|
return m
|
|
|
|
def save_model(m, path):
|
|
checker.check_model(m)
|
|
onnx.save(m, path)
|
|
print(f"[OK] saved -> {path}")
|
|
|
|
def run_shape_infer(m):
|
|
try:
|
|
return shape_inference.infer_shapes(m)
|
|
except Exception as e:
|
|
print("[WARN] onnx.shape_inference 실패 — 동적 shape가 많아도 추론엔 문제없음:", e)
|
|
return m
|
|
|
|
def run_simplify(m, dynamic_input_shape=True):
|
|
if onnx_simplify is None:
|
|
print("[SKIP] onnx-simplifier(onnxsim) 미사용")
|
|
return m
|
|
try:
|
|
sm, ok = onnx_simplify(m)
|
|
if not ok:
|
|
print("[WARN] onnxsim 검증 실패 — 원본 그래프 유지")
|
|
return m
|
|
return sm
|
|
except Exception as e:
|
|
print("[WARN] onnxsim 에러 — 원본 그래프 유지:", e)
|
|
return m
|
|
|
|
def run_optimize(m):
|
|
# 보수적인 기본 패스 세트
|
|
passes = [
|
|
"eliminate_deadend",
|
|
"eliminate_identity",
|
|
"eliminate_nop_transpose",
|
|
"eliminate_nop_pad",
|
|
"eliminate_unused_initializer",
|
|
"fuse_consecutive_transposes",
|
|
"fuse_transpose_into_gemm",
|
|
"fuse_add_bias_into_conv",
|
|
"fuse_bn_into_conv",
|
|
"fuse_pad_into_conv",
|
|
"fuse_matmul_add_bias_into_gemm",
|
|
]
|
|
avail = set(onnxoptimizer.get_available_passes())
|
|
use = [p for p in passes if p in avail]
|
|
try:
|
|
return onnxoptimizer.optimize(m, use)
|
|
except Exception as e:
|
|
print("[WARN] onnxoptimizer 실패 — 원본 그래프 유지:", e)
|
|
return m
|
|
|
|
def run_fp16(m, keep_io_fp32=True):
|
|
if float16 is None:
|
|
print("[SKIP] FP16 변환 모듈 없음")
|
|
return m
|
|
try:
|
|
# 일부 op은 FP16 미지원일 수 있어 자동으로 통과됨
|
|
return float16.convert_float_to_float16(
|
|
m,
|
|
keep_io_types=keep_io_fp32, # 입출력은 FP32 유지(파이프라인 호환)
|
|
# op_block_list=set([]), # 문제가 생기면 여기에 차단 op 추가
|
|
)
|
|
except Exception as e:
|
|
print("[WARN] FP16 변환 실패 — 원본 그래프 유지:", e)
|
|
return m
|
|
|
|
def run_dyn_quant(in_path, out_path):
|
|
if quantize_dynamic is None:
|
|
print("[SKIP] ORT 동적 양자화 모듈 없음")
|
|
return
|
|
try:
|
|
# MatMul/Gemm 위주 — Conv는 동적양자화 지원 제한
|
|
quantize_dynamic(
|
|
model_input=in_path,
|
|
model_output=out_path,
|
|
per_channel=False,
|
|
weight_type=QuantType.QInt8,
|
|
optimize_model=False,
|
|
op_types_to_quantize=["MatMul","Gemm"]
|
|
)
|
|
print(f"[OK] 동적 양자화 saved -> {out_path}")
|
|
except Exception as e:
|
|
print("[WARN] 동적 양자화 실패:", e)
|
|
|
|
def pipeline_one(in_path, out_dir, do_fp16=True, do_quant_dynamic=False, tag=""):
|
|
os.makedirs(out_dir, exist_ok=True)
|
|
base = os.path.splitext(os.path.basename(in_path))[0]
|
|
|
|
print(f"\n=== [{tag}] {base} ===")
|
|
m = load_model(in_path)
|
|
|
|
# 1) Shape inference
|
|
m = run_shape_infer(m)
|
|
save_model(m, os.path.join(out_dir, f"{base}.shaped.onnx"))
|
|
|
|
# 2) Simplify (동적 입력 허용)
|
|
m = run_simplify(m, dynamic_input_shape=True)
|
|
save_model(m, os.path.join(out_dir, f"{base}.simp.onnx"))
|
|
|
|
# 3) Graph optimize
|
|
m = run_optimize(m)
|
|
save_model(m, os.path.join(out_dir, f"{base}.opt.onnx"))
|
|
|
|
# 4) FP16 (선택) — CPU에선 속도이득은 제한적, 메모리/모델크기 이득
|
|
if do_fp16:
|
|
mf16 = run_fp16(m, keep_io_fp32=True)
|
|
save_model(mf16, os.path.join(out_dir, f"{base}.fp16.onnx"))
|
|
# (선택) FP16 + 동적양자화를 같이 쓰진 않는 게 일반적
|
|
|
|
# 5) (선택) ORT 동적 양자화 (MatMul/Gemm 위주)
|
|
if do_quant_dynamic:
|
|
run_dyn_quant(
|
|
os.path.join(out_dir, f"{base}.opt.onnx"),
|
|
os.path.join(out_dir, f"{base}.dq.onnx"),
|
|
)
|
|
|
|
def main():
|
|
ap = argparse.ArgumentParser()
|
|
ap.add_argument("--det", help="det onnx path")
|
|
ap.add_argument("--rec", help="rec onnx path")
|
|
ap.add_argument("--cls", help="cls onnx path")
|
|
ap.add_argument("--outdir", required=True, help="output dir")
|
|
ap.add_argument("--fp16", action="store_true", help="export fp16 model")
|
|
ap.add_argument("--dq", action="store_true", help="onnxruntime dynamic quantization")
|
|
args = ap.parse_args()
|
|
|
|
if not any([args.det, args.rec, args.cls]):
|
|
print("하나 이상 지정: --det / --rec / --cls")
|
|
sys.exit(1)
|
|
|
|
if args.det:
|
|
pipeline_one(args.det, args.outdir, do_fp16=args.fp16, do_quant_dynamic=args.dq, tag="det")
|
|
if args.rec:
|
|
pipeline_one(args.rec, args.outdir, do_fp16=args.fp16, do_quant_dynamic=args.dq, tag="rec")
|
|
if args.cls:
|
|
pipeline_one(args.cls, args.outdir, do_fp16=args.fp16, do_quant_dynamic=args.dq, tag="cls")
|
|
|
|
if __name__ == "__main__":
|
|
main()
|