How do I make my discriminator and generator loss converge in DCGAN?

I am trying to build a DCGAN for T-shirt design generation. I started by building a DCGAN on the MNIST dataset, and once that worked I felt confident enough to try a more complex use case.

I created the model following the architectural guidelines from the DCGAN paper (Radford et al., 2015), which are as follows:

Architecture guidelines for stable Deep Convolutional GANs

  • Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator); a minimal sketch of this contrast follows the list.
  • Use batchnorm in both the generator and the discriminator.
  • Remove fully connected hidden layers for deeper architectures.
  • Use ReLU activation in generator for all layers except for the output, which uses Tanh.
  • Use LeakyReLU activation in the discriminator for all layers
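
For concreteness, here is how I read the first guideline in Keras: a strided convolution downsamples in the discriminator instead of pooling, and a fractionally strided (transposed) convolution upsamples in the generator. These are just illustrative blocks with arbitrary shapes; my full model is further down.

from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Input, LeakyReLU, ReLU
from tensorflow.keras.models import Sequential

# Downsampling block: stride 2 halves the spatial size (no MaxPooling2D)
down_block = Sequential([
    Input(shape=(128, 128, 3)),
    Conv2D(64, (5, 5), strides=(2, 2), padding='same'),  # 128x128 -> 64x64
    LeakyReLU(alpha=0.2),
])

# Upsampling block: fractional stride doubles the spatial size
up_block = Sequential([
    Input(shape=(8, 8, 512)),
    Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same'),  # 8x8 -> 16x16
    ReLU(),
])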

I'm having trouble telling whether this is mode collapse, vanishing gradients, or a convergence failure. The generated outputs are all the same, so it looks like mode collapse. If it is, how do I solve it? I've tried experimenting with learning rates, the architecture, label smoothing, and training the fake and real batches separately instead of stacking them. (I should add that I stopped training early many times, without letting it complete, whenever I saw d_loss or g_loss hit 0 or fluctuate wildly.)
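
On the label smoothing: my code below softens both sides (real = 0.9, fake = 0.1), while the variant I've seen recommended is one-sided, softening only the real labels and leaving the fake labels at exactly 0. A sketch of both, with made-up helper names:

import numpy as np

def real_labels(num_samples):
    # smoothed real labels: 0.9 instead of 1.0
    return np.ones((num_samples, 1)) - 0.1

def fake_labels(num_samples, one_sided=True):
    # one-sided smoothing keeps fake labels at exactly 0.0;
    # my current code adds 0.1 to them instead
    return np.zeros((num_samples, 1)) + (0.0 if one_sided else 0.1)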

These are some images from the dataset:

[Image: real (training) images]

And these are the generated images:

[Image: generated images]

This is the plot of the d_real loss, d_fake loss, and gan loss against the number of epochs.

[Image: loss plot]

And this is the code:

import matplotlib.pyplot as plt
import numpy as np
import os

from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Dropout, Flatten, BatchNormalization, Input, Reshape, LeakyReLU, ReLU
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import img_to_array, load_img

def discriminator(input_shape=(128, 128, 3)):
    model = Sequential([
        Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=input_shape, kernel_initializer='glorot_uniform'),  # no batchnorm on the first layer, per the DCGAN paper
        LeakyReLU(alpha=0.2),

        Conv2D(128, (5, 5), strides=(2, 2), padding='same', kernel_initializer='glorot_uniform'),
        BatchNormalization(momentum=0.5),
        LeakyReLU(alpha=0.2),

        Conv2D(256, (5, 5), strides=(2, 2), padding='same', kernel_initializer='glorot_uniform'),
        BatchNormalization(momentum=0.5),
        LeakyReLU(alpha=0.2),

        Conv2D(512, (5, 5), strides=(2, 2), padding='same', kernel_initializer='glorot_uniform'),
        BatchNormalization(momentum=0.5),
        LeakyReLU(alpha=0.2),

        # Output => 8 * 8 * 512
        Flatten(),
        Dense(1, activation='sigmoid')
    ])

    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=None)

    return model

def generator():
    model = Sequential([
        Dense(8*8*512, input_shape=(100,), kernel_initializer='glorot_uniform'),
        Reshape((8, 8, 512)),
        BatchNormalization(momentum=0.5),
        ReLU(),

        Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', kernel_initializer='glorot_uniform'),
        BatchNormalization(momentum=0.5),
        ReLU(),

        Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', kernel_initializer='glorot_uniform'),
        BatchNormalization(momentum=0.5),
        ReLU(),

        Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', kernel_initializer='glorot_uniform'),
        BatchNormalization(momentum=0.5),
        ReLU(),

        Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', activation='tanh', kernel_initializer='glorot_uniform'),
    ])

    return model

def gan(gen_model, disc_model):
    disc_model.trainable = False

    model = Sequential([
        gen_model,
        disc_model
    ])

    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)

    return model

def load_dataset(directory="/content/tshirt-resized", target_size=(128, 128)):
    images = []
    for filename in os.listdir(directory):
        img = load_img(os.path.join(directory, filename), target_size=target_size)
        images.append(img_to_array(img))

    dataset = np.array(images)
    dataset = dataset.astype('float32')
    dataset = (dataset - 127.5) / 127.5  # scale to [-1, 1] to match the generator's tanh output

    return dataset

def generate_real_samples(dataset, num_samples):
    ix = np.random.randint(0, dataset.shape[0], num_samples)

    X = dataset[ix]
    y = np.ones((num_samples, 1)) - 0.1  # label smoothing: real labels = 0.9

    return X, y

def generate_latent_points(num_samples): # gen model input
    x_input = np.random.randn(100 * num_samples)
    x_input = x_input.reshape(num_samples, 100)

    return x_input

def generate_fake_samples(gen_model, num_samples): # gen model output
    x_input = generate_latent_points(num_samples)

    X = gen_model.predict(x_input)
    y = np.zeros((num_samples, 1)) + 0.1  # label smoothing: fake labels = 0.1

    return X, y

def train(gen_model, disc_model, gan_model, dataset, epochs=200, batch_size=128):
    num_batches_per_epoch = int(dataset.shape[0] / batch_size)
    d1_loss_hist = []
    d2_loss_hist = []
    gan_loss_hist = []
    for i in range(epochs):
        for j in range(num_batches_per_epoch):
            X_real, y_real = generate_real_samples(dataset, batch_size // 2)
            d1_loss = disc_model.train_on_batch(X_real, y_real)

            X_fake, y_fake = generate_fake_samples(gen_model, batch_size // 2)
            d2_loss = disc_model.train_on_batch(X_fake, y_fake)

            X_gan = generate_latent_points(batch_size)
            y_gan = np.ones((batch_size, 1))
            gan_loss = gan_model.train_on_batch(X_gan, y_gan)

            print('>%d, %d/%d, d1=%.3f, d2=%.3f, g=%.3f' % (i+1, j+1, num_batches_per_epoch, d1_loss, d2_loss, gan_loss))

        d1_loss_hist.append(d1_loss)
        d2_loss_hist.append(d2_loss)
        gan_loss_hist.append(gan_loss)
        if (i+1) % 50 == 0:
            filename = 'generator_model_%03d.keras' % (i + 1)
            gen_model.save(filename)

    plt.plot(d1_loss_hist, label='d-real')
    plt.plot(d2_loss_hist, label='d-fake')
    plt.plot(gan_loss_hist, label='gan')
    plt.legend()
    plt.savefig("Loss Plot.png")

disc_model = discriminator()
gen_model = generator()
gan_model = gan(gen_model, disc_model)
dataset = load_dataset()

train(gen_model, disc_model, gan_model, dataset)

1 Answer

Answer by Karan Dhingra:

In the gan function, you set disc_model.trainable = False, which means your discriminator is never being trained.

You need to modify your train function to enable discriminator training for the two discriminator updates and disable it again when optimizing the generator:

for i in range(epochs):
    for j in range(num_batches_per_epoch):
        # enable discriminator training
        disc_model.trainable = True
        X_real, y_real = generate_real_samples(dataset, batch_size // 2)
        d1_loss = disc_model.train_on_batch(X_real, y_real)

        X_fake, y_fake = generate_fake_samples(gen_model, batch_size // 2)
        d2_loss = disc_model.train_on_batch(X_fake, y_fake)

        # disable discriminator training
        disc_model.trainable = False
        X_gan = generate_latent_points(batch_size)
        y_gan = np.ones((batch_size, 1))
        gan_loss = gan_model.train_on_batch(X_gan, y_gan)
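
One caveat worth verifying on your Keras version when applying this: the trainable flag is read when a model is compiled, so flipping it inside the loop only takes effect for models compiled after the flip (or after calling compile() again). A sketch of the compile-time snapshot this relies on, reusing the functions defined in the question:

# `trainable` is captured at compile() time:
disc_model = discriminator()            # compiled while trainable=True, so
                                        # disc_model.train_on_batch updates it
gan_model = gan(gen_model, disc_model)  # compiled after trainable=False, so
                                        # gan_model.train_on_batch updates only
                                        # the generator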