The code I have is as follows:
import torch
import torch.nn as nn
from torchvision.utils import save_image

class Opt:
    def __init__(self):
        self.n_epochs = 10
        self.batch_size = 64
        self.lr = 0.0002
        self.b1 = 0.5        # Adam beta1
        self.b2 = 0.999      # Adam beta2
        self.latent_dim = 100
        self.img_size = 64
        self.channels = 3
        self.sample_interval = 400
        self.n_cpu = 14

opt = Opt()
img_shape = (opt.channels, opt.img_size, opt.img_size)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Generator(nn.Module):
    def __init__(self, image_channels, age_embedding_size, latent_dim):
        super(Generator, self).__init__()
        # Embedding layer for the target age
        self.age_embedding = nn.Embedding(age_embedding_size, latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + image_channels, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, image_channels),
            nn.Tanh()  # To ensure output is between -1 and 1
        )

    def forward(self, current_image, current_age, future_age):
        batch_size = current_image.size(0)
        age_embedded = self.age_embedding(future_age)
        age_embedded = age_embedded.view(batch_size, -1)  # Flatten to (batch, latent_dim)
        x = torch.cat((current_image.view(batch_size, -1), age_embedded), dim=1)
        generated_image = self.fc(x.view(batch_size, -1, 1, 1))
        return generated_image
class Discriminator(nn.Module):
    def __init__(self, image_channels, age_embedding_size):
        super(Discriminator, self).__init__()
        self.age_embedding = nn.Embedding(age_embedding_size, image_channels)
        self.fc = nn.Sequential(
            nn.Linear(image_channels + image_channels, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, image, current_age, future_age):
        batch_size = image.size(0)
        age_embedded = self.age_embedding(future_age)
        age_embedded = age_embedded.view(batch_size, -1)  # Flatten
        x = torch.cat((image.view(batch_size, -1), age_embedded), dim=1)
        validity = self.fc(x.view(batch_size, -1, 1, 1))
        return validity
# Initialize the generator and discriminator
generator = Generator(image_channels=opt.channels,
                      age_embedding_size=opt.latent_dim,
                      latent_dim=opt.latent_dim)
discriminator = Discriminator(image_channels=opt.channels,
                              age_embedding_size=opt.latent_dim)
generator.to(device)
discriminator.to(device)

# Loss and optimizers (Adam betas taken from opt.b1/opt.b2)
adversarial_loss = nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
for epoch in range(3):
    for i, batch in enumerate(dataloader):
        real_images = batch['original_image'].to(device)
        current_ages = batch['current_age'].to(device)
        future_ages = batch['desired_age'].to(device)
        print(real_images.shape, current_ages.shape, future_ages.shape)

        # Adversarial ground truths
        valid = torch.ones(real_images.size(0), 1).to(device)
        fake = torch.zeros(real_images.size(0), 1).to(device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Generate fake images
        fake_images = generator(real_images, current_ages, future_ages)
        # Discriminator loss for real images
        d_real_loss = adversarial_loss(discriminator(real_images, current_ages, future_ages), valid)
        # Discriminator loss for fake images (detached so no gradients reach the generator)
        d_fake_loss = adversarial_loss(discriminator(fake_images.detach(), current_ages, future_ages), fake)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()
        # Generate fake images again, this time keeping the graph for G's update
        gen_images = generator(real_images, current_ages, future_ages)
        # Generator loss: fool the discriminator into predicting "valid"
        g_loss = adversarial_loss(discriminator(gen_images, current_ages, future_ages), valid)
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        batches_done = epoch * len(dataloader) + i
        if (i + 1) % opt.sample_interval == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )
            # Save generated images
            save_image(gen_images.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
My dataset entries look like this:
{'original_image': 'CACD2000\\16_Christopher_Mintz-Plasse_0005.jpg',
'target_image': 'CACD2000\\16_Christopher_Mintz-Plasse_0016.jpg',
'current_age': 16,
'desired_age': 27},
{'original_image': 'CACD2000\\16_Chris_Brown_0003.jpg',
'target_image': 'CACD2000\\16_Chris_Brown_0004.jpg',
'current_age': 14,
'desired_age': 15},
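The dataloader used in the loop above is built from these records, roughly like the following sketch (FaceAgingDataset and records are stand-ins for my actual code, which differs in details):

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class FaceAgingDataset(Dataset):
    def __init__(self, records):
        self.records = records  # list of dicts like the ones above
        self.transform = transforms.Compose([
            transforms.ToTensor(),                       # (3, 250, 250), values in [0, 1]
            transforms.Normalize([0.5] * 3, [0.5] * 3),  # roughly [-1, 1], to match Tanh
        ])

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]
        return {
            'original_image': self.transform(Image.open(rec['original_image']).convert('RGB')),
            'current_age': torch.tensor(rec['current_age']),
            'desired_age': torch.tensor(rec['desired_age']),
        }

dataloader = DataLoader(FaceAgingDataset(records), batch_size=opt.batch_size,
                        shuffle=True, num_workers=opt.n_cpu)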
With this setup I am getting shape-mismatch errors from the matrix multiplications inside the network.
I am trying to build a face-aging model that takes one image and two ages and produces the aged face, but I am stuck on the architecture.
The line

print(real_images.shape, current_ages.shape, future_ages.shape)

prints

torch.Size([64, 3, 250, 250]) torch.Size([64]) torch.Size([64])
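If I am reading these shapes correctly, the first Linear layer in Generator.fc is the immediate problem. A quick check of the arithmetic (just my own diagnostic, using the shapes printed above):

# What the first nn.Linear in Generator.fc actually receives:
flat_image = 3 * 250 * 250     # flattened 250x250 RGB image -> 187500 features
actual_in = flat_image + 100   # plus the latent_dim-sized age embedding -> 187600
built_in = 100 + 3             # but the layer is nn.Linear(latent_dim + image_channels, 256) -> 103
print(actual_in, built_in)     # 187600 vs. 103

# On top of that, fc is applied to x.view(batch_size, -1, 1, 1); nn.Linear
# acts on the last dimension, which is 1 here, so the layer never even sees
# the concatenated features.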
I tried changing some of the dimensions as well as the embedding shapes, but I only ran into different errors and I am lost. How do I solve this?
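One of the variants I attempted looked roughly like this (a sketch; GeneratorV2 is just a placeholder name, and it assumes the images are resized to opt.img_size first, since at 250x250 the final Linear alone would need about 190 million weights):

import torch
import torch.nn as nn

class GeneratorV2(nn.Module):
    def __init__(self, image_channels, age_embedding_size, latent_dim, img_size):
        super().__init__()
        self.img_shape = (image_channels, img_size, img_size)
        flat = image_channels * img_size * img_size  # flattened image width
        self.age_embedding = nn.Embedding(age_embedding_size, latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(flat + latent_dim, 256),  # sized from the real input width
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, flat),  # emit a whole flattened image
            nn.Tanh(),
        )

    def forward(self, current_image, current_age, future_age):
        batch_size = current_image.size(0)
        age_embedded = self.age_embedding(future_age).view(batch_size, -1)
        x = torch.cat((current_image.view(batch_size, -1), age_embedded), dim=1)
        out = self.fc(x)  # keep x 2-D so Linear sees all the features
        return out.view(batch_size, *self.img_shape)

Even with the widths lined up like this, I am not convinced a plain stack of Linear layers is the right architecture for face images, which is part of what I am asking.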