import argparse, os, sys import onnx from onnx import shape_inference, checker import onnxoptimizer # onnx-simplifier try: from onnxsim import simplify as onnx_simplify except Exception: onnx_simplify = None # FP16 변환 try: from onnxconverter_common import float16 except Exception as e: print("[WARN] onnxconverter-common 미설치 혹은 구버전: FP16 변환은 생략됩니다.", e) float16 = None # (선택) ORT 동적 양자화 try: from onnxruntime.quantization import quantize_dynamic, QuantType except Exception: quantize_dynamic, QuantType = None, None def load_model(path): m = onnx.load(path) checker.check_model(m) return m def save_model(m, path): checker.check_model(m) onnx.save(m, path) print(f"[OK] saved -> {path}") def run_shape_infer(m): try: return shape_inference.infer_shapes(m) except Exception as e: print("[WARN] onnx.shape_inference 실패 — 동적 shape가 많아도 추론엔 문제없음:", e) return m def run_simplify(m, dynamic_input_shape=True): if onnx_simplify is None: print("[SKIP] onnx-simplifier(onnxsim) 미사용") return m try: sm, ok = onnx_simplify(m) if not ok: print("[WARN] onnxsim 검증 실패 — 원본 그래프 유지") return m return sm except Exception as e: print("[WARN] onnxsim 에러 — 원본 그래프 유지:", e) return m def run_optimize(m): # 보수적인 기본 패스 세트 passes = [ "eliminate_deadend", "eliminate_identity", "eliminate_nop_transpose", "eliminate_nop_pad", "eliminate_unused_initializer", "fuse_consecutive_transposes", "fuse_transpose_into_gemm", "fuse_add_bias_into_conv", "fuse_bn_into_conv", "fuse_pad_into_conv", "fuse_matmul_add_bias_into_gemm", ] avail = set(onnxoptimizer.get_available_passes()) use = [p for p in passes if p in avail] try: return onnxoptimizer.optimize(m, use) except Exception as e: print("[WARN] onnxoptimizer 실패 — 원본 그래프 유지:", e) return m def run_fp16(m, keep_io_fp32=True): if float16 is None: print("[SKIP] FP16 변환 모듈 없음") return m try: # 일부 op은 FP16 미지원일 수 있어 자동으로 통과됨 return float16.convert_float_to_float16( m, keep_io_types=keep_io_fp32, # 입출력은 FP32 유지(파이프라인 호환) # op_block_list=set([]), # 문제가 생기면 여기에 차단 op 추가 ) except Exception as e: print("[WARN] FP16 변환 실패 — 원본 그래프 유지:", e) return m def run_dyn_quant(in_path, out_path): if quantize_dynamic is None: print("[SKIP] ORT 동적 양자화 모듈 없음") return try: # MatMul/Gemm 위주 — Conv는 동적양자화 지원 제한 quantize_dynamic( model_input=in_path, model_output=out_path, per_channel=False, weight_type=QuantType.QInt8, optimize_model=False, op_types_to_quantize=["MatMul","Gemm"] ) print(f"[OK] 동적 양자화 saved -> {out_path}") except Exception as e: print("[WARN] 동적 양자화 실패:", e) def pipeline_one(in_path, out_dir, do_fp16=True, do_quant_dynamic=False, tag=""): os.makedirs(out_dir, exist_ok=True) base = os.path.splitext(os.path.basename(in_path))[0] print(f"\n=== [{tag}] {base} ===") m = load_model(in_path) # 1) Shape inference m = run_shape_infer(m) save_model(m, os.path.join(out_dir, f"{base}.shaped.onnx")) # 2) Simplify (동적 입력 허용) m = run_simplify(m, dynamic_input_shape=True) save_model(m, os.path.join(out_dir, f"{base}.simp.onnx")) # 3) Graph optimize m = run_optimize(m) save_model(m, os.path.join(out_dir, f"{base}.opt.onnx")) # 4) FP16 (선택) — CPU에선 속도이득은 제한적, 메모리/모델크기 이득 if do_fp16: mf16 = run_fp16(m, keep_io_fp32=True) save_model(mf16, os.path.join(out_dir, f"{base}.fp16.onnx")) # (선택) FP16 + 동적양자화를 같이 쓰진 않는 게 일반적 # 5) (선택) ORT 동적 양자화 (MatMul/Gemm 위주) if do_quant_dynamic: run_dyn_quant( os.path.join(out_dir, f"{base}.opt.onnx"), os.path.join(out_dir, f"{base}.dq.onnx"), ) def main(): ap = argparse.ArgumentParser() ap.add_argument("--det", help="det onnx path") ap.add_argument("--rec", help="rec onnx path") ap.add_argument("--cls", help="cls onnx path") ap.add_argument("--outdir", required=True, help="output dir") ap.add_argument("--fp16", action="store_true", help="export fp16 model") ap.add_argument("--dq", action="store_true", help="onnxruntime dynamic quantization") args = ap.parse_args() if not any([args.det, args.rec, args.cls]): print("하나 이상 지정: --det / --rec / --cls") sys.exit(1) if args.det: pipeline_one(args.det, args.outdir, do_fp16=args.fp16, do_quant_dynamic=args.dq, tag="det") if args.rec: pipeline_one(args.rec, args.outdir, do_fp16=args.fp16, do_quant_dynamic=args.dq, tag="rec") if args.cls: pipeline_one(args.cls, args.outdir, do_fp16=args.fp16, do_quant_dynamic=args.dq, tag="cls") if __name__ == "__main__": main()