88 lines
3.1 KiB
Python
88 lines
3.1 KiB
Python
# make_ppocr_onnx_dynamic.py
|
|
import argparse
|
|
import onnx
|
|
from onnx import helper, shape_inference
|
|
|
|
def _set_dim(dim, value=None, param=None):
|
|
# value(정수) 또는 param(문자열 심볼) 한쪽만 설정
|
|
dim.ClearField('dim_value')
|
|
dim.ClearField('dim_param')
|
|
if value is not None:
|
|
dim.dim_value = int(value)
|
|
elif param is not None:
|
|
dim.dim_param = str(param)
|
|
|
|
def _shape_str(vi):
|
|
tp = vi.type.tensor_type
|
|
if not tp.HasField("shape"):
|
|
return "scalar"
|
|
dims = []
|
|
for d in tp.shape.dim:
|
|
if d.HasField("dim_value"):
|
|
dims.append(str(d.dim_value))
|
|
elif d.HasField("dim_param"):
|
|
dims.append(d.dim_param)
|
|
else:
|
|
dims.append("?")
|
|
return "[" + ",".join(dims) + "]"
|
|
|
|
def patch_ppocr_input(model_path, out_path, model_type):
|
|
"""
|
|
model_type:
|
|
- det / cls : NCHW에서 N,H,W를 동적으로
|
|
- rec : NCHW에서 N, W 동적 + H는 고정(원래 값이 있으면 유지, 없으면 48로)
|
|
"""
|
|
m = onnx.load(model_path)
|
|
if len(m.graph.input) == 0:
|
|
raise RuntimeError("No graph input found in ONNX model.")
|
|
|
|
# 보통 입력은 하나(x). 복수여도 첫 번째만 패치(일반적인 PPOCR는 1개)
|
|
vi = m.graph.input[0]
|
|
tp = vi.type.tensor_type
|
|
if not tp.HasField("shape") or len(tp.shape.dim) != 4:
|
|
raise RuntimeError(f"Expect 4D NCHW input, got: {_shape_str(vi)}")
|
|
|
|
n, c, h, w = tp.shape.dim # N,C,H,W
|
|
|
|
# C(채널)는 3 고정이 일반적이라 놔두고, 나머지만 처리
|
|
# det/cls: [N=?, C=3, H=?, W=?]
|
|
if model_type in ("det", "cls"):
|
|
_set_dim(n, param="N")
|
|
# H가 값으로 박혀 있든 없든 동적 처리
|
|
_set_dim(h, param="H")
|
|
_set_dim(w, param="W")
|
|
|
|
# rec: [N=?, C=3, H=고정, W=?] (중국어 계열은 보통 H=48, 영문 모바일계열은 32)
|
|
elif model_type == "rec":
|
|
_set_dim(n, param="N")
|
|
# H 유지(이미 값이 있으면 유지), 없으면 48로 설정
|
|
if not h.HasField("dim_value") and not h.HasField("dim_param"):
|
|
_set_dim(h, value=48) # 기본값 48
|
|
# W는 동적
|
|
_set_dim(w, param="W")
|
|
else:
|
|
raise ValueError("model_type must be one of: det / rec / cls")
|
|
|
|
# 모델 체크 + (가능하면) Shape Inference 수행
|
|
onnx.checker.check_model(m)
|
|
try:
|
|
m = shape_inference.infer_shapes(m)
|
|
except Exception as e:
|
|
# 동적 치수가 많으면 실패해도 무방. 패치는 이미 적용됨.
|
|
print(f"[warn] onnx shape_inference skipped: {e}")
|
|
|
|
onnx.save(m, out_path)
|
|
print(f"[ok] saved: {out_path}")
|
|
print(f" input 0 after patch: {_shape_str(m.graph.input[0])}")
|
|
|
|
def main():
|
|
ap = argparse.ArgumentParser()
|
|
ap.add_argument("--in", dest="inp", required=True, help="input onnx path")
|
|
ap.add_argument("--out", dest="out", required=True, help="output onnx path")
|
|
ap.add_argument("--type", dest="typ", required=True, choices=["det","rec","cls"], help="ppocr model type")
|
|
args = ap.parse_args()
|
|
patch_ppocr_input(args.inp, args.out, args.typ)
|
|
|
|
if __name__ == "__main__":
|
|
main()
|