I'm having trouble implementing generator pretraining for a GAN in PyTorch Lightning. Is it correct to define two classes, such as PretrainGenerator(pl.LightningModule) and GanForecastTask(pl.LightningModule)?
If so, how can both of them update the same generator? Will the code below update the same generator? Thanks in advance :)
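```python
import pytorch_lightning as pl

# `model`, `discriminator`, `config`, `dm` and `callbacks` are defined elsewhere
generator = model
discriminator = discriminator

# Both tasks receive the same generator instance
pre_task = PretrainGenerator(config, generator=generator)
task = GanForecastTask(config, generator=generator, discriminator=discriminator)

# Stage 1: pretrain the generator on its own
pre_trainer = pl.Trainer(max_epochs=config['train']['pretrain_epochs'])
pre_train_dataloader = dm.pre_train_dataloader()
pre_trainer.fit(pre_task, pre_train_dataloader)

# Stage 2: adversarial GAN training
trainer = pl.Trainer(max_epochs=config['train']['max_epochs'], callbacks=callbacks)
trainer.fit(task, dm)
```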
I decided to define two pl.LightningModule classes, but I'm not sure whether that's the right approach.
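On the question of whether both tasks update the same generator, here is a minimal sketch using plain `torch.nn.Module` wrappers. `PretrainWrapper` and `GanWrapper` are hypothetical stand-ins for `PretrainGenerator` and `GanForecastTask`, and the `nn.Linear` layers are placeholder models. It illustrates that passing one object to both constructors stores a reference to the same module rather than a copy, so an optimizer step taken through one wrapper changes the weights the other wrapper sees.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two LightningModules in the question:
# each wrapper keeps a reference to a generator passed in from outside.
class PretrainWrapper(nn.Module):
    def __init__(self, generator: nn.Module):
        super().__init__()
        self.generator = generator  # a reference, not a copy

class GanWrapper(nn.Module):
    def __init__(self, generator: nn.Module, discriminator: nn.Module):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

# Placeholder models (assumed shapes, for illustration only)
generator = nn.Linear(4, 1)
discriminator = nn.Linear(1, 1)

pre_task = PretrainWrapper(generator)
task = GanWrapper(generator, discriminator)

# Snapshot the weights as seen through the GAN wrapper
before = task.generator.weight.detach().clone()

# One supervised "pretraining" step that only uses pre_task's parameters
opt = torch.optim.SGD(pre_task.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(pre_task.generator(x), y)
opt.zero_grad()
loss.backward()
opt.step()

# Both wrappers hold the same module, so the GAN wrapper sees the update
print(pre_task.generator is task.generator)        # True
print(torch.equal(before, task.generator.weight))  # False
```

The same reference semantics should carry over to two LightningModules, with one caveat: each `Trainer.fit` call builds its own optimizers from that module's `configure_optimizers`, so optimizer state from the pretraining stage is not carried into the GAN stage, only the generator's weights are.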