I am trying to make a GAN with a multiscale discriminator. How do I call the parameters of the discriminator if the discriminator is composed of three other discriminators? Multiscale discriminator code:
class MultiscaleDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
use_sigmoid=False, num_D=3):
super(MultiscaleDiscriminator, self).__init__()
self.input_nc=input_nc
self.ndf=ndf
self.n_layers=n_layers
self.norm_layer=norm_layer
self.use_sigmoid=use_sigmoid
self.num_D=num_D
self.discriminators = []
# instantiate the discriminators
for i in range(self.num_D):
self.discriminators.append(
self.singleD_forward(int(input_nc / (2 ** i)))
)
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
self.inputs = [1, 1]
def singleD_forward(self, input_nc):
return NLayerDiscriminator(
input_nc=input_nc,
ndf=self.ndf,
n_layers=self.n_layers,
norm_layer=self.norm_layer,
use_sigmoid=self.use_sigmoid)
def forward(self, inputs):
is_conditioned = isinstance(inputs, list)
if is_conditioned:
(
fake_or_real,
condition,
) = inputs # inputs is a tuple containing the generated images and the conditions
else:
fake_or_real = inputs
condition = None
outs = []
features = []
fake_or_real_i = fake_or_real
condition_i = condition
for i, discriminator in enumerate(self.discriminators):
# compute value of the i-th discriminator
out, feat = discriminator(
[fake_or_real_i, condition_i]
)
# append output and features
outs.append(out)
features.extend(feat)
# reduce input size
if i != len(self.discriminators) - 1:
fake_or_real_i = self.downsample(fake_or_real_i)
condition_i = (
self.downsample(condition_i) if condition_i is not None else None
)
# handle output values
if return_features:
return outs, features
return outs
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
super(NLayerDiscriminator, self).__init__()
self.getIntermFeat = getIntermFeat
self.n_layers = n_layers
kw = 4
padw = int(np.ceil((kw-1.0)/2))
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
norm_layer(nf), nn.LeakyReLU(0.2, True)
]]
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
if use_sigmoid:
sequence += [[nn.Sigmoid()]]
if getIntermFeat:
for n in range(len(sequence)):
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
else:
sequence_stream = []
for n in range(len(sequence)):
sequence_stream += sequence[n]
self.model = nn.Sequential(*sequence_stream)
def forward(self, input):
if self.getIntermFeat:
res = [input]
for n in range(self.n_layers+2):
model = getattr(self, 'model'+str(n))
res.append(model(res[-1]))
return res[1:]
else:
return self.model(input)
When I run:
disc_A = MultiscaleDiscriminator(dim_A).to(device)
disc_A_opt = torch.optim.Adam(disc_A.parameters(), lr=lr, betas=(0.5, 0.999))
I get:
ValueError Traceback (most recent call last)
Cell In[16], line 18
14 disc_B = disc_B.apply(weights_init)
16 gen_opt = torch.optim.Adam(list(gen_AB.parameters()) + list(gen_BA.parameters()), lr=lr, betas=(0.5, 0.999))
---> 18 disc_A_opt = torch.optim.Adam(disc_A.parameters(), lr=lr, betas=(0.5, 0.999))
20 disc_B_opt = torch.optim.Adam(disc_B.parameters(), lr=lr, betas=(0.5, 0.999))
22 '''
23 # Feel free to change pretrained to False if you're training the model from scratch
24 pretrained = False
(...)
37 disc_A = disc_A.apply(weights_init)
38 disc_B = disc_B.apply(weights_init)'''
File ~\anaconda3\Lib\site-packages\torch\optim\adam.py:45, in Adam.__init__(self, params, lr, betas, eps, weight_decay, amsgrad, foreach, maximize, capturable, differentiable, fused)
39 raise ValueError(f"Invalid weight_decay value: {weight_decay}")
41 defaults = dict(lr=lr, betas=betas, eps=eps,
42 weight_decay=weight_decay, amsgrad=amsgrad,
43 maximize=maximize, foreach=foreach, capturable=capturable,
44 differentiable=differentiable, fused=fused)
---> 45 super().__init__(params, defaults)
47 if fused:
48 if differentiable:
File ~\anaconda3\Lib\site-packages\torch\optim\optimizer.py:261, in Optimizer.__init__(self, params, defaults)
259 param_groups = list(params)
260 if len(param_groups) == 0:
--> 261 raise ValueError("optimizer got an empty parameter list")
262 if not isinstance(param_groups[0], dict):
263 param_groups = [{'params': param_groups}]
ValueError: optimizer got an empty parameter list