# make_ppocr_onnx_dynamic.py import argparse import onnx from onnx import helper, shape_inference def _set_dim(dim, value=None, param=None): # value(정수) 또는 param(문자열 심볼) 한쪽만 설정 dim.ClearField('dim_value') dim.ClearField('dim_param') if value is not None: dim.dim_value = int(value) elif param is not None: dim.dim_param = str(param) def _shape_str(vi): tp = vi.type.tensor_type if not tp.HasField("shape"): return "scalar" dims = [] for d in tp.shape.dim: if d.HasField("dim_value"): dims.append(str(d.dim_value)) elif d.HasField("dim_param"): dims.append(d.dim_param) else: dims.append("?") return "[" + ",".join(dims) + "]" def patch_ppocr_input(model_path, out_path, model_type): """ model_type: - det / cls : NCHW에서 N,H,W를 동적으로 - rec : NCHW에서 N, W 동적 + H는 고정(원래 값이 있으면 유지, 없으면 48로) """ m = onnx.load(model_path) if len(m.graph.input) == 0: raise RuntimeError("No graph input found in ONNX model.") # 보통 입력은 하나(x). 복수여도 첫 번째만 패치(일반적인 PPOCR는 1개) vi = m.graph.input[0] tp = vi.type.tensor_type if not tp.HasField("shape") or len(tp.shape.dim) != 4: raise RuntimeError(f"Expect 4D NCHW input, got: {_shape_str(vi)}") n, c, h, w = tp.shape.dim # N,C,H,W # C(채널)는 3 고정이 일반적이라 놔두고, 나머지만 처리 # det/cls: [N=?, C=3, H=?, W=?] if model_type in ("det", "cls"): _set_dim(n, param="N") # H가 값으로 박혀 있든 없든 동적 처리 _set_dim(h, param="H") _set_dim(w, param="W") # rec: [N=?, C=3, H=고정, W=?] (중국어 계열은 보통 H=48, 영문 모바일계열은 32) elif model_type == "rec": _set_dim(n, param="N") # H 유지(이미 값이 있으면 유지), 없으면 48로 설정 if not h.HasField("dim_value") and not h.HasField("dim_param"): _set_dim(h, value=48) # 기본값 48 # W는 동적 _set_dim(w, param="W") else: raise ValueError("model_type must be one of: det / rec / cls") # 모델 체크 + (가능하면) Shape Inference 수행 onnx.checker.check_model(m) try: m = shape_inference.infer_shapes(m) except Exception as e: # 동적 치수가 많으면 실패해도 무방. 패치는 이미 적용됨. print(f"[warn] onnx shape_inference skipped: {e}") onnx.save(m, out_path) print(f"[ok] saved: {out_path}") print(f" input 0 after patch: {_shape_str(m.graph.input[0])}") def main(): ap = argparse.ArgumentParser() ap.add_argument("--in", dest="inp", required=True, help="input onnx path") ap.add_argument("--out", dest="out", required=True, help="output onnx path") ap.add_argument("--type", dest="typ", required=True, choices=["det","rec","cls"], help="ppocr model type") args = ap.parse_args() patch_ppocr_input(args.inp, args.out, args.typ) if __name__ == "__main__": main()