Add a `configure_optimizers` member function to your PyTorch Lightning model class.
Refer to the code below:
..
def configure_optimizers(self):
    """Build the optimizer and a per-step OneCycleLR schedule from ``self.cfg``.

    Reads from ``self.cfg``:
        optimizer: name of a class in ``torch.optim`` (e.g. ``"Adam"``).
        lr: peak learning rate. It may arrive as a string (PyYAML, for
            example, loads ``1e-3`` as ``str``), so it is cast to float.
        train_dataloader_len: length of the training dataloader. This is
            assumed to be the number of batches before splitting across
            devices; TODO confirm against the caller.
        gpus: number of devices the batches are split across. A falsy value
            (0 / None, e.g. a CPU run) is treated as one device.
        epochs: number of training epochs.

    Side effects:
        Stores the created objects on ``self.optimizer`` and
        ``self.scheduler``.

    Returns:
        ``([optimizer], [scheduler_config])``, the list form accepted by
        PyTorch Lightning, with the scheduler stepped on every optimizer step.
    """
    import math

    lr = float(self.cfg.lr)
    epochs = int(self.cfg.epochs)
    # Avoid ZeroDivisionError when no GPUs are configured.
    num_devices = self.cfg.gpus or 1

    optimizer_cls = getattr(torch.optim, self.cfg.optimizer)
    self.optimizer = optimizer_cls(self.parameters(), lr=lr)

    # Round up rather than truncate. When the data is sharded with padding,
    # each device runs ceil(len / devices) batches. OneCycleLR raises if it
    # is stepped more than epochs * steps_per_epoch times, and int() could
    # also produce 0 steps.
    steps_per_epoch = math.ceil(self.cfg.train_dataloader_len / num_devices)

    self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
        self.optimizer,
        max_lr=lr,  # same float as the optimizer's lr (a str here would fail)
        anneal_strategy='linear',
        div_factor=100,  # initial lr = max_lr / 100
        steps_per_epoch=steps_per_epoch,
        pct_start=1 / epochs,  # warm up over the first epoch's worth of steps
        epochs=epochs,
    )
    sched = {
        'scheduler': self.scheduler,
        'interval': 'step',  # step the scheduler after every optimizer step
    }
    return [self.optimizer], [sched]
..
Thank you. 🙇🏻‍♂️
www.marearts.com