diff --git a/nanodo/train.py b/nanodo/train.py index 0b25caf..decea3f 100644 --- a/nanodo/train.py +++ b/nanodo/train.py @@ -45,7 +45,7 @@ if TYPE_CHECKING: - import ml_collections + import ml_collections # pylint: disable=g-bad-import-order PyTree = Any @@ -178,7 +178,7 @@ def _checkpoint(): if c.checkpoint: step = trainer.step logging.info("Saving last checkpoint step %d", step) - ckpt_mngr.save(step, {"state": trainer.state, "data": train_iter}) + ckpt_mngr.save(step, {"state": trainer.state, "data": train_iter}) # pylint: disable=undefined-variable def _process_metrics(step, microbatch_metrics): if microbatch_metrics and step % c.write_train_metrics_every_steps == 0: @@ -230,7 +230,7 @@ def _process_metrics(step, microbatch_metrics): _process_metrics(c.opt.num_train_steps, pending_microbatch_metrics) if c.checkpoint: - ckpt_mngr.close() + ckpt_mngr.close() # pylint: disable=undefined-variable class Trainer: