How to pretrain the generator of a GAN in PyTorch Lightning


I'm having trouble implementing pretraining of the GAN generator in PyTorch Lightning. Is it the right approach to define two classes, such as PretrainGenerator(pl.LightningModule) and GanForecastTask(pl.LightningModule)?

If so, how can both of them update one and the same generator? Will the code below update the same generator? Thanks in advance :)

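```python
import pytorch_lightning as pl

# `model`, `discriminator`, `config`, `dm` and `callbacks` are defined elsewhere.
generator = model
discriminator = discriminator

# Both tasks receive a reference to the same generator object.
pre_task = PretrainGenerator(config, generator=generator)
task = GanForecastTask(config, generator=generator, discriminator=discriminator)

# Stage 1: pretrain the generator on its own
pre_trainer = pl.Trainer(max_epochs=config['train']['pretrain_epochs'])
pre_train_dataloader = dm.pre_train_dataloader()
pre_trainer.fit(pre_task, pre_train_dataloader)

# Stage 2: adversarial training with the discriminator
trainer = pl.Trainer(max_epochs=config['train']['max_epochs'], callbacks=callbacks)
trainer.fit(task, dm)
```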
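Whether this shares the generator comes down to Python object references: as long as each task stores the `generator` argument as an attribute (rather than copying or rebuilding it), both tasks point at the same `nn.Module`, so the weights learned during the first `fit` should be the starting point of the second. The sketch below checks this in plain PyTorch, assuming a toy `nn.Linear` in place of the real generator and a hypothetical `Wrapper` class in place of the two LightningModules. For contrast it also shows that a `copy.deepcopy` would not see the update, and how the weights could be handed over with `state_dict` if the two stages ever ran in separate scripts.

```python
import copy

import torch
import torch.nn as nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for a LightningModule that stores the generator it is given."""

    def __init__(self, generator: nn.Module):
        super().__init__()
        self.generator = generator  # stores a reference; no copy is made


# Toy generator with deterministic weights so the result is easy to check.
generator = nn.Linear(4, 1)
nn.init.constant_(generator.weight, 1.0)
nn.init.zeros_(generator.bias)

pre_task = Wrapper(generator)                  # stage 1 gets the object
task = Wrapper(generator)                      # stage 2 gets the very same object
detached = Wrapper(copy.deepcopy(generator))   # an independent copy, for contrast

# One plain SGD step on pre_task's generator, standing in for pretraining.
opt = torch.optim.SGD(pre_task.generator.parameters(), lr=0.1)
loss = pre_task.generator(torch.ones(2, 4)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()

print(task.generator is pre_task.generator)                               # True
print(torch.equal(task.generator.weight, pre_task.generator.weight))      # True
print(torch.equal(detached.generator.weight, pre_task.generator.weight))  # False

# If the stages run in separate scripts, hand the weights over explicitly instead.
fresh = nn.Linear(4, 1)
fresh.load_state_dict(pre_task.generator.state_dict())
print(torch.equal(fresh.weight, pre_task.generator.weight))               # True
```

The one thing to avoid is any step that replaces the object, for example constructing a new generator inside `GanForecastTask.__init__` or deep-copying the argument; either would silently discard the pretrained weights.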

I decided to define two pl.LightningModule classes, but I'm not sure whether this is the right way.
