Saving raw_model.state_dict() checkpoints #90

Open
anw-g01 opened this issue Jan 28, 2025 · 0 comments
anw-g01 commented Jan 28, 2025

The following line, which defines raw_model, is executed only once, before the training loop begins:

raw_model = model.module if ddp else model # always contains the "raw" unwrapped model

If I'm not mistaken, doesn't this create a subtle bug, since raw_model is never trained in the loop? Only model, which is either the plain model or the DDP()-wrapped model, is the instance that gets trained. Or is raw_model also updated during DDP training?
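A minimal sketch of the aliasing question above, using a hypothetical ToyModel and Wrapper as stand-ins for an nn.Module and the DDP wrapper (real DistributedDataParallel needs process-group initialization, so it isn't used here). If the wrapper stores a reference to the inner model rather than a copy, updates made through the wrapper remain visible through raw_model:

```python
class ToyModel:
    """Stand-in for an nn.Module with a single 'parameter'."""
    def __init__(self):
        self.weight = 0.0

    def state_dict(self):
        return {"weight": self.weight}


class Wrapper:
    """Stand-in for DDP: stores a reference to the inner module, not a copy."""
    def __init__(self, module):
        self.module = module


model = ToyModel()
ddp_model = Wrapper(model)
raw_model = ddp_model.module  # the same object as `model`

# "Train" through the wrapped model:
ddp_model.module.weight += 1.0

# The update is visible through raw_model, because no copy was ever made.
print(raw_model is model)       # True
print(raw_model.state_dict())   # {'weight': 1.0}
```

Whether the real code behaves this way depends on DDP keeping `module` as a live reference to the original model rather than a detached copy.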

If that logic holds, saving model checkpoints would currently be redundant, because raw_model.state_dict() would be static and the same weights would be saved every time:

checkpoint = {
    'model': raw_model.state_dict(),
    'config': raw_model.config,
    ...
    ...
}

As a suggestion, would it be more correct to use:

checkpoint = {
    'model': (model.module if ddp else model).state_dict(),
    ...
    ...
    ...
}