
2/10/2024

PyTorch Lightning: save .pth alongside .ckpt for top k

 


Here is a custom ModelCheckpoint that also saves a plain .pth state_dict for each of the top-k checkpoints:

.

import os
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

class CustomModelCheckpoint(ModelCheckpoint):
    def __init__(self, save_top_k_pth=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_top_k_pth = save_top_k_pth
        # Keep track of saved .pth files to manage the top K
        self.saved_pth_files = []

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Construct the .pth path manually (simplified example)
        epoch = trainer.current_epoch
        metric_score = "{:.2f}".format(trainer.callback_metrics['val_loss'].item())
        filename = f"model-epoch={epoch}-val_loss={metric_score}.pth"
        dirpath = self.dirpath if self.dirpath else trainer.default_root_dir
        pth_path = os.path.join(dirpath, filename)

        # Save a plain state_dict alongside the Lightning .ckpt
        torch.save(pl_module.state_dict(), pth_path)
        self.saved_pth_files.append(pth_path)

        # Keep only the newest save_top_k_pth .pth files
        while len(self.saved_pth_files) > self.save_top_k_pth:
            oldest_pth = self.saved_pth_files.pop(0)
            if os.path.exists(oldest_pth):
                os.remove(oldest_pth)

        # Ensure the superclass hook still runs
        return super().on_save_checkpoint(trainer, pl_module, checkpoint)

..


Call it in the training process:

.

from pytorch_lightning import Trainer, loggers

# config is your own settings object
logger = loggers.TensorBoardLogger(save_dir="lightning_logs", name=config.model_version)

# Define the checkpoint callback
checkpoint_callback = CustomModelCheckpoint(
    monitor='val_loss',
    dirpath=f"{logger.save_dir}/{logger.name}/version_{logger.version}",
    filename='model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=2,       # Keep the top 2 .ckpt checkpoints
    save_top_k_pth=2,   # Also keep the top 2 .pth files
    mode='min'
)

trainer = Trainer(max_epochs=config.num_epochs, accelerator='gpu',
                  devices=1, callbacks=[checkpoint_callback],
                  logger=logger, log_every_n_steps=10)

..



The top-k saved files (.ckpt and .pth) show up in the folder.
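
If you later want to use one of the plain .pth files, you can load the state_dict directly, without going through Lightning's checkpoint machinery. A minimal sketch, assuming the same LightningModule class and config used in training (the class name and filename below are just illustrative examples of the pattern saved above):

.

import torch

# Recreate the module with the same architecture used in training
# (MyLightningModule / config are hypothetical names)
model = MyLightningModule(config)

# Load the plain state_dict saved by CustomModelCheckpoint
state_dict = torch.load("model-epoch=9-val_loss=0.12.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

..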

Thank you.

🙇🏻‍♂️

9/05/2023

Saving an additional file during PyTorch Lightning training.

If you want to save an additional file in the folder where PyTorch Lightning automatically saves the latest or best checkpoints, add the following method to your training class.

.

# Training class using PyTorch Lightning
class my_trainer(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()

..

Add this method to the class to save the additional file:

.

# Inside my_trainer; requires `import os` and
# `from pytorch_lightning.callbacks import ModelCheckpoint` at the top of the file
def on_save_checkpoint(self, checkpoint):
    # Call the parent method first (optional)
    super().on_save_checkpoint(checkpoint)

    # Custom code to save additional files next to the checkpoints
    dirpath = None
    for callback in self.trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            dirpath = callback.dirpath
            break

    if dirpath is not None:
        additional_filepath = os.path.join(dirpath, "my_additional_file.txt")
        with open(additional_filepath, "w") as f:
            f.write("Some additional data")
        print(f"Saved additional file to {additional_filepath}")
    else:
        print("Could not find ModelCheckpoint dirpath to save additional file.")

..
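
As an alternative, the checkpoint argument passed to this hook is the checkpoint dictionary itself, so small pieces of extra state can be stored inside the .ckpt file instead of a separate file. A minimal sketch (the key name "my_extra_data" is just an example):

.

def on_save_checkpoint(self, checkpoint):
    super().on_save_checkpoint(checkpoint)
    # Stored keys travel inside the .ckpt file automatically
    checkpoint["my_extra_data"] = {"note": "some additional data"}

def on_load_checkpoint(self, checkpoint):
    super().on_load_checkpoint(checkpoint)
    # Restore the extra state when the checkpoint is loaded
    self.my_extra_data = checkpoint.get("my_extra_data")

..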

ok, now try it!

Good luck!


www.marearts.com

10/28/2022

OneCycleLR setup in PyTorch Lightning

 

Add a configure_optimizers member function to your PyTorch Lightning model class.

Refer to the code:


..

def configure_optimizers(self):
    # Build the optimizer class named in the config, e.g. "AdamW"
    optimizer = getattr(torch.optim, self.cfg.optimizer)
    self.optimizer = optimizer(self.parameters(), lr=float(self.cfg.lr))

    # Steps per epoch: total training batches split across GPUs
    steps_per_epoch = int(self.cfg.train_dataloader_len / self.cfg.gpus)

    epochs = self.cfg.epochs
    self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
        self.optimizer, max_lr=self.cfg.lr,
        anneal_strategy='linear', div_factor=100,
        steps_per_epoch=steps_per_epoch, pct_start=(1 / epochs),
        epochs=epochs)

    # 'interval': 'step' tells Lightning to step the scheduler every batch
    sched = {
        'scheduler': self.scheduler,
        'interval': 'step',
    }
    return [self.optimizer], [sched]

..
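
To verify that the OneCycle schedule is really stepping every batch, you can attach Lightning's LearningRateMonitor callback and watch the learning-rate curve in TensorBoard. A minimal sketch:

.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor

# Log the current learning rate at every scheduler step
lr_monitor = LearningRateMonitor(logging_interval='step')

trainer = Trainer(max_epochs=10, callbacks=[lr_monitor])

..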




Thank you. 🙇🏻‍♂️

www.marearts.com

10/14/2022

PyTorch Lightning: set the validation interval

 

Refer to the example:

..

# default used by the Trainer
trainer = Trainer(val_check_interval=1.0)
# check validation set 4 times during a training epoch
trainer = Trainer(val_check_interval=0.25)
# check validation set every 1000 training batches in the current epoch
trainer = Trainer(val_check_interval=1000)

..
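
If you instead want to validate less often than once per epoch, the Trainer also accepts check_val_every_n_epoch. For example:

.

# run the validation loop only once every 2 training epochs
trainer = Trainer(check_val_every_n_epoch=2)

..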


Details are here: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html


www.marearts.com