diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 4607ef65b7..264d3fbf0e 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1273,7 +1273,6 @@ def ckpt_export( config_file_, filepath_, ckpt_file_, - bundle_root_, net_id_, meta_file_, key_in_ckpt_, @@ -1285,7 +1284,6 @@ def ckpt_export( "config_file", filepath=None, ckpt_file=None, - bundle_root=os.getcwd(), net_id=None, meta_file=None, key_in_ckpt="", @@ -1293,18 +1291,23 @@ def ckpt_export( input_shape=None, converter_kwargs={}, ) + bundle_root = _args.get("bundle_root", os.getcwd()) parser = ConfigParser() - parser.read_config(f=config_file_) - meta_file_ = os.path.join(bundle_root_, "configs", "metadata.json") if meta_file_ is None else meta_file_ - filepath_ = os.path.join(bundle_root_, "models", "model.ts") if filepath_ is None else filepath_ - ckpt_file_ = os.path.join(bundle_root_, "models", "model.pt") if ckpt_file_ is None else ckpt_file_ - if not os.path.exists(ckpt_file_): - raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".') + meta_file_ = os.path.join(bundle_root, "configs", "metadata.json") if meta_file_ is None else meta_file_ if os.path.exists(meta_file_): parser.read_meta(f=meta_file_) + # the rest key-values in the _args are to override config content + for k, v in _args.items(): + parser[k] = v + + filepath_ = os.path.join(bundle_root, "models", "model.ts") if filepath_ is None else filepath_ + ckpt_file_ = os.path.join(bundle_root, "models", "model.pt") if ckpt_file_ is None else ckpt_file_ + if not os.path.exists(ckpt_file_): + raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".') + net_id_ = "network_def" if net_id_ is None else net_id_ try: parser.get_parsed_content(net_id_) @@ -1313,10 +1316,6 @@ def ckpt_export( f'Network definition "{net_id_}" cannot be found in "{config_file_}", specify name with argument "net_id".' ) from e - # the rest key-values in the _args are to override config content - for k, v in _args.items(): - parser[k] = v - # When export through torch.jit.trace without providing input_shape, will try to parse one from the parser. if (not input_shape_) and use_trace: input_shape_ = _get_fake_input_shape(parser=parser)