From 6d8eb317d32f407bfd8c39c52cc36915981b453a Mon Sep 17 00:00:00 2001 From: Nanodo Team Date: Tue, 28 May 2024 22:20:36 +0000 Subject: [PATCH] No public description PiperOrigin-RevId: 638040825 Change-Id: Ica5963804f365fe74b74201b1c41ee90f67aa86e --- nanodo/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nanodo/train.py b/nanodo/train.py index 0b25caf..decea3f 100644 --- a/nanodo/train.py +++ b/nanodo/train.py @@ -45,7 +45,7 @@ if TYPE_CHECKING: - import ml_collections + import ml_collections # pylint: disable=g-bad-import-order PyTree = Any @@ -178,7 +178,7 @@ def _checkpoint(): if c.checkpoint: step = trainer.step logging.info("Saving last checkpoint step %d", step) - ckpt_mngr.save(step, {"state": trainer.state, "data": train_iter}) + ckpt_mngr.save(step, {"state": trainer.state, "data": train_iter}) # pylint: disable=undefined-variable def _process_metrics(step, microbatch_metrics): if microbatch_metrics and step % c.write_train_metrics_every_steps == 0: @@ -230,7 +230,7 @@ def _process_metrics(step, microbatch_metrics): _process_metrics(c.opt.num_train_steps, pending_microbatch_metrics) if c.checkpoint: - ckpt_mngr.close() + ckpt_mngr.close() # pylint: disable=undefined-variable class Trainer: