How to do custom EarlyStopping?❓ [QUESTION] #380

Open
ThePauliPrinciple opened this issue Oct 26, 2023 · 4 comments
Labels
question Further information is requested

Comments

@ThePauliPrinciple

I would like to do some custom early stopping, e.g. based on a file existing, or on getting close to the walltime limit on a compute cluster.

Is there some way to specify a custom early stopping class?
I tried using early_stopping and early_stopping_conds arguments of the trainer (or in the config.yaml), but could not make anything happen.

I was able to accomplish what I wanted through an on-end-epoch callback:

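from pathlib import Path


class FileStopCallback:
    """End-of-epoch callback that stops training when a stop file exists."""

    def __init__(self, stop_file: Path):
        self.stop_file = stop_file

    def __call__(self, trainer):
        if self.stop_file.is_file():
            # The first line of the stop file is used as the stop reason.
            with open(self.stop_file, 'r') as f:
                reason = f.readline().strip()
            # stop_cond is a read-only property, so it is overridden on the class.
            trainer.__class__.stop_cond = True
            trainer.stop_arg = f"Early stopping: stop file detected with reason: {reason}"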

But it seems rather hackish (you can't set trainer.stop_cond directly because it is a property without a setter).
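The behavior described here can be reproduced outside nequip. The following is a minimal sketch using a toy `ToyTrainer` class (a hypothetical stand-in, not nequip's actual trainer) with a read-only `stop_cond` property. It shows that assigning on the instance raises `AttributeError`, while assigning on the class replaces the property for every instance of that class.

```python
class ToyTrainer:
    """Hypothetical stand-in for a trainer with a read-only stop_cond property."""

    def __init__(self):
        self.stop_arg = None

    @property
    def stop_cond(self):
        # A real trainer would evaluate its stopping criteria here.
        return False


a = ToyTrainer()
b = ToyTrainer()

# Assigning on the instance fails because the property defines no setter.
try:
    a.stop_cond = True
    instance_assignment_failed = False
except AttributeError:
    instance_assignment_failed = True

# Assigning on the class replaces the property object with a plain attribute,
# so every instance of the class (not only `a`) now reports True.
a.__class__.stop_cond = True
```

One consequence of the class-level assignment is that any other trainer of the same class in the same process is affected as well, which is part of why the workaround feels fragile.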

ThePauliPrinciple added the question label Oct 26, 2023
@Linux-cpp-lisp
Collaborator

Hi @ThePauliPrinciple,

Thanks for your nice question and work with our code!

Re "checking if I get close to a walltime on a compute cluster": we do have support for a fixed walltime bound: https://github.com/mir-group/nequip/blob/develop/configs/full.yaml#L210-L211. But if you want to query the job scheduler, for example, that will have to be custom, of course.
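As a rough illustration of what a custom walltime check could look like, here is a minimal sketch. The class `WalltimeStopper` and its parameters are hypothetical and not part of nequip. It assumes the time limit is known in advance and stops once the remaining time drops below a safety margin; querying a job scheduler would replace the fixed `limit_seconds`.

```python
import time


class WalltimeStopper:
    """Hypothetical helper: signals a stop when the job nears its time limit."""

    def __init__(self, limit_seconds, margin_seconds=600.0, clock=time.monotonic):
        self.limit_seconds = limit_seconds
        self.margin_seconds = margin_seconds
        self.clock = clock  # injectable so the logic can be tested without waiting
        self.start = clock()

    def remaining(self):
        # Seconds left before the time limit is reached.
        return self.limit_seconds - (self.clock() - self.start)

    def should_stop(self):
        # Stop while there is still `margin_seconds` left to save a checkpoint.
        return self.remaining() <= self.margin_seconds


# Demonstration with a simulated clock instead of real elapsed time.
fake_now = [0.0]
stopper = WalltimeStopper(3600.0, margin_seconds=600.0, clock=lambda: fake_now[0])
early = stopper.should_stop()   # checked at t = 0 s
fake_now[0] = 3100.0
late = stopper.should_stop()    # checked at t = 3100 s, 500 s remaining
```

The margin should be at least as long as one epoch plus the time needed to write a checkpoint, since the check only runs between epochs.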

I've just added support for custom early stopping conditions on branch: https://github.com/mir-group/nequip/tree/feature-custom-early-stop with an example at https://github.com/mir-group/nequip-example-extension/tree/earlystop. Please give this a try and let me know if it works for you, and I'll merge it down.

If this doesn't fully solve the issue (or even if it does), it might be a more complicated workflow than I'm anticipating, and maybe we should have a quick call to discuss---please feel free to send me an email at the address listed in my profile.

@ThePauliPrinciple
Author

ThePauliPrinciple commented Oct 31, 2023

This looks good to me.

Passing the trainer object to the stopper might be useful to some, although for my use case I am only interested in "external" information.

I'm not exactly certain what the comment about restarting means, in particular, when is a stopper considered "stateful"?

The original early stopper also returned values to immediately debug/print, maybe that's also nice to add.

@Linux-cpp-lisp
Collaborator

Great!

A stopper is "stateful" when it maintains a state like, say, how many epochs the validation loss hasn't improved (like the patience setting) or what the minimum observed value was (see https://github.com/mir-group/nequip/blob/feature-custom-early-stop/nequip/train/early_stopping.py#L120-L121). If it only depends on the current arguments to the object, and not any state stored in your custom object, then it's not stateful. (State of the trainer, if that was passed in, will be correctly preserved across restarts.)
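To make the distinction concrete, here is a small sketch contrasting the two kinds of stopper. Both classes, and the `state_dict` / `load_state_dict` method names, are hypothetical illustrations rather than the interface on the feature branch. The point is only that the patience-based stopper gives a different answer after a restart unless its counters are saved and restored.

```python
from pathlib import Path


class FileStopper:
    """Stateless: the answer depends only on what exists on disk right now."""

    def __init__(self, stop_file):
        self.stop_file = Path(stop_file)

    def __call__(self, metrics):
        return self.stop_file.is_file()


class PatienceStopper:
    """Stateful: remembers the best value seen and epochs since it improved."""

    def __init__(self, key, patience):
        self.key = key
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def __call__(self, metrics):
        value = metrics[self.key]
        if value < self.best:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

    # Hypothetical save/restore hooks for carrying state across a restart.
    def state_dict(self):
        return {"best": self.best, "bad_epochs": self.bad_epochs}

    def load_state_dict(self, state):
        self.best = state["best"]
        self.bad_epochs = state["bad_epochs"]


stopper = PatienceStopper(key="validation_loss", patience=2)
history = [stopper({"validation_loss": v}) for v in (1.0, 0.8, 0.9)]

# Simulated restart: one stopper restores the saved state, the other starts fresh.
restored = PatienceStopper(key="validation_loss", patience=2)
restored.load_state_dict(stopper.state_dict())
stops_after_restart = restored({"validation_loss": 0.85})

fresh = PatienceStopper(key="validation_loss", patience=2)
fresh_stops = fresh({"validation_loss": 0.85})
```

With the state restored, the fourth epoch counts as the second epoch without improvement and triggers the stop; the fresh stopper has forgotten the earlier minimum and keeps training.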

Re "The original early stopper also returned values to immediately debug/print, maybe that's also nice to add.": what do you mean, exactly?

@ThePauliPrinciple
Author

return stop, stop_args, debug_args

Here debug_args is returned, which is printed to the log:
if self.early_stopping_conds is not None and hasattr(self, "mae_dict"):
    early_stop, early_stop_args, debug_args = self.early_stopping_conds(
        self.mae_dict
    )
    if debug_args is not None:
        self.logger.debug(debug_args)
    if early_stop:
        self.stop_arg = early_stop_args
        return True
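A custom condition could follow the same three-value convention. The sketch below is an assumption about how that might look, not the interface on the feature branch: `file_stop_condition` and `check` are hypothetical names, and `check` only mimics the logging pattern in the snippet above.

```python
import logging
import tempfile
from pathlib import Path

logger = logging.getLogger("early_stop_sketch")


def file_stop_condition(stop_file):
    """Hypothetical condition returning (stop, stop_args, debug_args)."""
    stop_file = Path(stop_file)
    debug_args = f"checked for stop file {stop_file}"
    if not stop_file.is_file():
        return False, None, debug_args
    lines = stop_file.read_text().splitlines()
    reason = lines[0] if lines else "(no reason given)"
    message = f"Early stopping: stop file detected with reason: {reason}"
    return True, message, debug_args


def check(stop_file):
    # Mirrors the pattern above: log debug_args, then act on the stop flag.
    stop, stop_args, debug_args = file_stop_condition(stop_file)
    if debug_args is not None:
        logger.debug(debug_args)
    return stop, stop_args


# Demonstration in a temporary directory: check before and after creating the file.
with tempfile.TemporaryDirectory() as tmp:
    stop_file = Path(tmp) / "STOP"
    before = check(stop_file)
    stop_file.write_text("walltime almost reached\n")
    after = check(stop_file)
```

Returning `debug_args` on every call, including when no stop is triggered, makes it possible to confirm from the log that the condition is actually being evaluated each epoch.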
