Commit
added model convert: accelerate checkpoint -> pytorch model
hyunkoome committed Dec 17, 2024
1 parent fdeb042 commit 12512ca
Showing 7 changed files with 164 additions and 1 deletion.
13 changes: 13 additions & 0 deletions configs/inference_test_hkkim_config.yaml
@@ -0,0 +1,13 @@
defaults:
  - config
  - _self_

log_root_prefix: ./magicdrive-log/inference_test_demo

runner:
  validation_batch_size: 10

resume_on_exists: false
show_box: true

fix_seed_for_every_generation: false
27 changes: 27 additions & 0 deletions configs/model_convert_hkkim_config.yaml
@@ -0,0 +1,27 @@
defaults:
  - _self_
  - model: SDv1.5mv_rawbox
  - dataset: Nuscenes_cache
  - accelerator: default
  - runner: default

task_id: "0.0.0"
log_root_prefix: ./magicdrive-log/model_convert
projname: ${model.name}
hydra:
  run:
    dir: ${log_root_prefix}/${projname}_${now:%Y-%m-%d}_${now:%H-%M}_${task_id}
  output_subdir: hydra

try_run: false
debug: false
log_root: ???
init_method: env://
seed: 42
fix_seed_within_batch: false

resume_from_checkpoint: null
resume_reset_scheduler: false
validation_only: false
# num_gpus: 1
# num_workers: 4
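
For reference, the hydra run.dir interpolation above builds a timestamped output directory at launch time. A minimal sketch of the equivalent resolution in plain Python (an illustration, not part of this commit; projname stands in for ${model.name}):

# Sketch (illustration only, not from this commit): emulating how hydra
# resolves the run.dir interpolation above; ${now:...} formats the launch
# timestamp with the given strftime pattern.
from datetime import datetime

task_id = "0.0.0"
log_root_prefix = "./magicdrive-log/model_convert"
projname = "SDv1.5mv-rawbox"  # ${model.name} in the real config

now = datetime.now()
run_dir = f"{log_root_prefix}/{projname}_{now:%Y-%m-%d}_{now:%H-%M}_{task_id}"
print(run_dir)
# e.g. ./magicdrive-log/model_convert/SDv1.5mv-rawbox_2024-12-17_23-16_0.0.0

Checkpoint directories elsewhere in this commit, such as SDv1.5mv-rawbox_2024-12-17_23-16_224x400, follow this same {projname}_{date}_{time}_{task_id} pattern.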
1 change: 1 addition & 0 deletions magicdrive/misc/test_utils.py
@@ -304,6 +304,7 @@ def run_one_batch(cfg, pipe, val_input, weight_dtype, global_generator=None,
     camera_param = val_input["camera_param"].to(weight_dtype)
 
     # 3-dim list: B, Times, views
+    print(val_input['captions'])
     gen_imgs_list = run_one_batch_pipe_func(
         cfg, pipe, val_input['pixel_values'], val_input['captions'],
         val_input['bev_map_with_aux'], camera_param, val_input['kwargs'],
7 changes: 7 additions & 0 deletions scripts/inference_test_hkkim.sh
@@ -0,0 +1,7 @@
#!/bin/bash

#python tools/testhkkim.py resume_from_checkpoint=./pretrained/SDv1.5mv-rawbox_2023-09-07_18-39_224x400

#python tools/testhkkim.py resume_from_checkpoint=/home/hyunkoo/DATA/ssd8tb/Journal/MagicDrive/pretrained/SDv1.5mv-rawbox_2023-09-07_18-39_224x400

python tools/inference_test_hkkim.py resume_from_checkpoint=/home/hyunkoo/DATA/ssd8tb/Journal/MagicDrive/magicdrive-log/model_convert/SDv1.5mv-rawbox_2024-12-17_23-16_224x400
4 changes: 4 additions & 0 deletions scripts/save_pytorch_model_from_accelerate_checkpoint.sh
@@ -0,0 +1,4 @@
accelerate launch --config_file /home/hyunkoo/DATA/ssd8tb/Journal/MagicDrive/configs/accelerator/accelerate_config_1gpu.yaml \
tools/save_pytorch_model_from_accelerate_checkpoint.py \
resume_from_checkpoint=/home/hyunkoo/DATA/ssd8tb/Journal/MagicDrive/magicdrive-log/SDv1.5mv-rawbox_2024-12-13_21-38_224x400/checkpoint-160000 \
+exp=224x400 runner=2gpus
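
After the conversion run finishes, a quick sanity check (an assumed helper, not part of this commit) is to list each output directory; save_pretrained should have written a config.json plus a weights file into both:

# Sanity-check sketch (assumption, not from this commit): list the files
# written by save_pretrained for each converted module. The run directory
# below is a hypothetical placeholder.
import os

run_dir = "./magicdrive-log/model_convert/<run_dir>"  # hypothetical path
for sub in ("controlnet", "unet"):
    d = os.path.join(run_dir, sub)
    print(sub, "->", sorted(os.listdir(d)))  # expect config.json + weights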
2 changes: 1 addition & 1 deletion tools/testhkkim.py → tools/inference_test_hkkim.py
@@ -38,7 +38,7 @@ def output_func(x): return concat_6_views(x)
 # def output_func(x): return img_concat_h(*x[:3])
 
 
-@hydra.main(version_base=None, config_path="../configs", config_name="test_config")
+@hydra.main(version_base=None, config_path="../configs", config_name="inference_test_hkkim_config")
 def main(cfg: DictConfig):
     if cfg.debug:
         import debugpy
111 changes: 111 additions & 0 deletions tools/save_pytorch_model_from_accelerate_checkpoint.py
@@ -0,0 +1,111 @@
import os
import sys
import logging
from dotenv import load_dotenv
load_dotenv('/home/hyunkoo/DATA/ssd8tb/Journal/MagicDrive/.env')

print(os.environ['HF_HOME'])

import warnings
from shapely.errors import ShapelyDeprecationWarning
warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)

import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
import torch

from mmdet3d.datasets import build_dataset
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed

sys.path.append(".")  # adjust this path if needed
import magicdrive.dataset.pipeline
from magicdrive.misc.common import load_module


def set_logger(global_rank, logdir):
    if global_rank == 0:
        return
    logging.info(f"reset logger for {global_rank}")
    root = logging.getLogger()
    root.handlers.clear()
    root.setLevel(logging.DEBUG)
    formatter = logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s] - %(message)s")
    file_path = os.path.join(logdir, f"train.{global_rank}.log")
    handler = logging.FileHandler(file_path)
    handler.setFormatter(formatter)
    root.addHandler(handler)


@hydra.main(version_base=None, config_path="../configs", config_name="model_convert_hkkim_config")
def main(cfg: DictConfig):
    # Same environment setup as the existing training entry point
    logging.getLogger().setLevel(logging.DEBUG)
    for handler in logging.getLogger().handlers:
        if isinstance(handler, logging.FileHandler) or cfg.try_run:
            handler.setLevel(logging.DEBUG)
        else:
            handler.setLevel(logging.INFO)
    logging.getLogger("shapely.geos").setLevel(logging.WARN)
    logging.getLogger("asyncio").setLevel(logging.INFO)
    logging.getLogger("accelerate.tracking").setLevel(logging.INFO)
    logging.getLogger("numba").setLevel(logging.WARN)
    logging.getLogger("PIL").setLevel(logging.WARN)
    logging.getLogger("matplotlib").setLevel(logging.WARN)

    setattr(cfg, "log_root", HydraConfig.get().runtime.output_dir)

    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.accelerator.gradient_accumulation_steps,
        mixed_precision=cfg.accelerator.mixed_precision,
        log_with=cfg.accelerator.report_to,
        project_dir=cfg.log_root,
        kwargs_handlers=[ddp_kwargs],
    )
    set_logger(accelerator.process_index, cfg.log_root)
    set_seed(cfg.seed)

    # The datasets can be omitted if unused (but keep them if runner
    # initialization requires them)
    train_dataset = build_dataset(
        OmegaConf.to_container(cfg.dataset.data.train, resolve=True)
    )
    val_dataset = build_dataset(
        OmegaConf.to_container(cfg.dataset.data.val, resolve=True)
    )

    # Initialize the runner
    runner_cls = load_module(cfg.model.runner_module)
    runner = runner_cls(cfg, accelerator, train_dataset, val_dataset)
    runner.set_optimizer_scheduler()
    runner.prepare_device()

    # Load the already-trained checkpoint here; the checkpoint path is
    # expected in cfg.resume_from_checkpoint
    if not cfg.resume_from_checkpoint:
        raise ValueError("Please specify a resume_from_checkpoint path.")
    load_path = cfg.resume_from_checkpoint
    accelerator.load_state(load_path)

    # Extract the controlnet with unwrap_model
    controlnet = accelerator.unwrap_model(runner.controlnet)

    # Save the model (the original code takes this directory from
    # cfg.model.controlnet_dir)
    save_dir = os.path.join(cfg.log_root, "controlnet")
    os.makedirs(save_dir, exist_ok=True)
    controlnet.save_pretrained(save_dir)
    logging.info(f"Model saved to: {save_dir}")

    # Extract the unet with unwrap_model and save it the same way
    unet = accelerator.unwrap_model(runner.unet)

    save_dir = os.path.join(cfg.log_root, "unet")
    os.makedirs(save_dir, exist_ok=True)
    unet.save_pretrained(save_dir)
    logging.info(f"Model saved to: {save_dir}")


if __name__ == "__main__":
    main()
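
For completeness, a minimal round-trip sketch (illustration only, not part of this commit) of the save_pretrained/from_pretrained pattern the script relies on, shown with a small stock diffusers model; MagicDrive's controlnet and unet follow the same pattern provided they subclass diffusers' ModelMixin:

# Round-trip sketch (illustration, not from this commit): any diffusers
# ModelMixin subclass saved with save_pretrained can be reloaded with the
# same class's from_pretrained, which is how the directories written above
# are consumed later.
import tempfile

import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32, in_channels=3, out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
with tempfile.TemporaryDirectory() as d:
    model.save_pretrained(d)                   # writes config.json + weights
    reloaded = UNet2DModel.from_pretrained(d)  # rebuilds from that directory
    assert torch.equal(model.conv_in.weight, reloaded.conv_in.weight)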
