2/10/2024

PyTorch Lightning: saving .pth files alongside the top-k .ckpt checkpoints

 


Here is the custom checkpoint callback:

.

class CustomModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that also exports the model weights as ``.pth`` files.

    Each time the ``on_save_checkpoint`` hook fires, the module's
    ``state_dict`` is written next to the regular ``.ckpt`` files, and only
    the ``save_top_k_pth`` best exports (ranked by ``val_loss`` according to
    this callback's ``mode``) are kept on disk.

    Args:
        save_top_k_pth: Number of ``.pth`` files to keep. ``0`` disables the
            export entirely; a negative value keeps every exported file.
        *args: Forwarded unchanged to ``ModelCheckpoint``.
        **kwargs: Forwarded unchanged to ``ModelCheckpoint``.
    """

    def __init__(self, save_top_k_pth=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_top_k_pth = save_top_k_pth
        # Paths of the .pth files currently kept on disk.
        self.saved_pth_files = []
        # Unrounded val_loss for each kept path; used to decide which file
        # to evict once more than ``save_top_k_pth`` files exist.
        self._pth_scores = {}

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        """Export a ``.pth`` state dict, prune extras, then defer to the parent.

        The export is skipped when it is disabled (``save_top_k_pth == 0``)
        or when ``val_loss`` has not been logged yet, so a checkpoint that is
        saved before the first validation run does not raise ``KeyError``.
        """
        metric = trainer.callback_metrics.get('val_loss')
        if self.save_top_k_pth != 0 and metric is not None:
            self._export_pth(trainer, pl_module, float(metric))

        # Ensure to call the superclass method
        return super().on_save_checkpoint(trainer, pl_module, checkpoint)

    def _export_pth(self, trainer, pl_module, score):
        """Write the state dict for the current epoch and prune the worst files."""
        epoch = trainer.current_epoch
        metric_score = "{:.2f}".format(score)
        filename = f"model-epoch={epoch}-val_loss={metric_score}.pth"
        dirpath = self.dirpath if self.dirpath else trainer.default_root_dir
        pth_path = os.path.join(dirpath, filename)

        # The hook can run before anything has created the directory.
        os.makedirs(dirpath, exist_ok=True)
        torch.save(pl_module.state_dict(), pth_path)

        # The same epoch/score pair yields the same path; track it only once
        # so a later eviction cannot delete a file that is still counted.
        if pth_path not in self.saved_pth_files:
            self.saved_pth_files.append(pth_path)
        self._pth_scores[pth_path] = score

        if self.save_top_k_pth < 0:
            # Negative means "keep everything".
            return

        while len(self.saved_pth_files) > self.save_top_k_pth:
            worst_pth = max(self.saved_pth_files, key=self._badness)
            self.saved_pth_files.remove(worst_pth)
            self._pth_scores.pop(worst_pth, None)
            try:
                os.remove(worst_pth)
            except FileNotFoundError:
                # Already removed by someone else; nothing left to clean up.
                pass

    def _badness(self, pth_path):
        """Return a sort key where a larger value means a worse export."""
        score = self._pth_scores.get(pth_path)
        if score is None:
            # Untracked paths are evicted first.
            return float("inf")
        return score if self.mode == "min" else -score

..


Use it in the training script:

.

# TensorBoard logger for this run, grouped under the model version name.
logger = loggers.TensorBoardLogger(
    save_dir="lightning_logs",
    name=config.model_version,
)

# Checkpoints are written into the directory built from the logger's
# save_dir, name and version.
checkpoint_dir = f"{logger.save_dir}/{logger.name}/version_{logger.version}"

# Checkpoint callback monitoring val_loss (lower is better).
checkpoint_callback = CustomModelCheckpoint(
    save_top_k_pth=2,  # Also save top 2 .pth files
    dirpath=checkpoint_dir,
    filename='model-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=2,  # Top 2 checkpoints
)

# Single-GPU trainer wired to the logger and the checkpoint callback.
trainer = Trainer(
    max_epochs=config.num_epochs,
    accelerator='gpu',
    devices=1,
    logger=logger,
    callbacks=[checkpoint_callback],
    log_every_n_steps=10,
)

..



After training, the top-k files (both .ckpt and .pth) show up in the checkpoint folder.

Thank you.

🙇🏻‍♂️

No comments:

Post a Comment