[AI from Scratch] Episode 290: Image Generation Using GANs — Understanding Generative Adversarial Networks

Recap and Today’s Theme

Hello! In the previous episode, we explored the implementation of U-Net, focusing on its role as a segmentation model for medical images and its performance in fields like healthcare and autonomous driving. U-Net excels at pixel-level image classification, making it highly effective for detailed tasks.

Today, we will dive into Generative Adversarial Networks (GANs) and their application in image generation. GANs have become a hot topic in AI due to their ability to generate realistic images, augment datasets, and more. In this article, we’ll explain the basic structure of GANs, how image generation works, and potential applications.

What is GAN (Generative Adversarial Network)?

GAN (Generative Adversarial Network) is a type of neural network proposed by Ian Goodfellow in 2014. It involves two networks: a generator and a discriminator, which compete with each other during training. This adversarial relationship gives GANs their name.

Structure of GAN

A GAN consists of two main components:

  1. Generator:
  • The generator takes noise (random data) as input and generates data that resembles real data (e.g., images). Its goal is to “trick” the discriminator into believing the generated data is real.
  2. Discriminator:
  • The discriminator tries to differentiate between real data and data generated by the generator. Its goal is to correctly identify whether the input data is real or fake.

These two networks compete in a game-like process, where the generator tries to fool the discriminator, and the discriminator tries to improve at identifying fake data.

How GANs Learn

GANs work through a minimax game. Here’s an overview of the process:

  • The generator tries to create data that looks as real as possible, aiming to deceive the discriminator.
  • The discriminator learns to recognize which data is real and which is generated (fake).

As they compete, the generator gradually improves at producing more realistic data, while the discriminator gets better at spotting fake data. Over time, this competition pushes the generator to produce data that is nearly indistinguishable from real data.
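Formally, the original 2014 paper expresses this competition as a minimax objective over a value function V(D, G), where G is the generator, D is the discriminator, p_data is the real data distribution, and p_z is the noise distribution:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

The discriminator is trained to maximize this value (assigning high probability to real samples and low probability to generated ones), while the generator is trained to minimize it by pushing D(G(z)) toward 1.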

Implementing GANs

Let’s build a simple GAN using Python and Keras (TensorFlow). In this example, we’ll generate handwritten digit images using the MNIST dataset.

Required Libraries

pip install tensorflow numpy matplotlib

Building a GAN

The following code implements a basic GAN. We define the generator and discriminator models and combine them for training.

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt

# Define the Generator model
def build_generator():
    model = models.Sequential([
        layers.Dense(128, input_dim=100, activation='relu'),
        layers.Dense(256, activation='relu'),
        layers.Dense(512, activation='relu'),
        layers.Dense(784, activation='tanh')  # Output is 28x28 = 784 dimensions
    ])
    return model

# Define the Discriminator model
def build_discriminator():
    model = models.Sequential([
        layers.Dense(512, input_dim=784, activation='relu'),
        layers.Dense(256, activation='relu'),
        layers.Dense(1, activation='sigmoid')  # Output is the probability that the input is real
    ])
    return model

# Build the models
generator = build_generator()
discriminator = build_discriminator()

# Compile the Discriminator model
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Define the GAN model (connecting Generator and Discriminator)
discriminator.trainable = False
gan_input = layers.Input(shape=(100,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)
gan = models.Model(gan_input, gan_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')

# Prepare data
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = (x_train.reshape(-1, 784).astype('float32') - 127.5) / 127.5  # Scale pixel values to [-1, 1] to match the generator's tanh output

# Train the GAN
def train_gan(epochs, batch_size):
    for epoch in range(epochs):
        # Sample random real images
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        real_images = x_train[idx]

        # Generate fake images
        noise = np.random.normal(0, 1, (batch_size, 100))
        fake_images = generator.predict(noise, verbose=0)

        # Train the Discriminator
        d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
        d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))

        # Train the Generator
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

        # Display progress
        if epoch % 1000 == 0:
            print(f"{epoch} [D loss: {0.5 * np.add(d_loss_real, d_loss_fake)[0]}] [G loss: {g_loss}]")

# Run the GAN training
train_gan(epochs=10000, batch_size=64)
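The script above imports matplotlib but never uses it. As a minimal sketch for inspecting results (the helper name, grid size, and layout below are illustrative choices, not part of the original code), generated digits can be displayed after training like this:

# Display a grid of digits sampled from the trained generator
def show_generated_images(n=16):
    noise = np.random.normal(0, 1, (n, 100))
    images = generator.predict(noise, verbose=0)
    images = images.reshape(-1, 28, 28)        # Unflatten to 28x28 pixels
    images = (images + 1) / 2.0                # Rescale from [-1, 1] (tanh) to [0, 1]
    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    for img, ax in zip(images, axes.flatten()):
        ax.imshow(img, cmap='gray')
        ax.axis('off')
    plt.show()

show_generated_images()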

Code Explanation

  • Generator Model: Takes 100-dimensional random noise as input and outputs a 784-dimensional vector, i.e., a flattened 28×28-pixel handwritten-digit image.
  • Discriminator Model: Takes a flattened 28×28 image as input and outputs the probability that it is real rather than generated.
  • GAN Model: Chains the generator and discriminator so that the generator can be trained to fool the discriminator. The discriminator's trainable flag is set to False before the combined model is compiled, so only the generator's weights are updated in this step; the discriminator continues to be trained separately through its own compiled model.

Training Flow

  1. The generator creates fake images from random noise.
  2. The discriminator receives both real and fake images and learns to distinguish between them.
  3. The generator is trained to improve its ability to create realistic images.

Applications of GANs

GANs have numerous applications beyond simple image generation. Below are a few examples:

1. Image Super-Resolution (SRGAN)

GANs can generate high-resolution images from low-resolution inputs. Super-resolution models such as SRGAN are used in smartphone and camera image pipelines to upscale photos while preserving fine detail.

2. Image Style Transfer (CycleGAN)

CycleGAN enables style transformation between image domains without paired training data, such as converting a photograph into a painting. This technique is used for artistic image generation and for tasks like converting day scenes into night scenes.

3. Data Augmentation

In scenarios where datasets are limited, GANs can generate synthetic data to improve model training. For example, GANs can be used to generate medical images for training AI models to recognize rare conditions.
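As a rough sketch of the idea (using the unconditional MNIST generator trained above; the sample count and the way the data is merged are arbitrary choices for illustration), synthetic images can be appended to the real training set like this:

# Generate synthetic samples and append them to the real training data
num_synthetic = 5000
noise = np.random.normal(0, 1, (num_synthetic, 100))
synthetic_images = generator.predict(noise, verbose=0)  # Shape: (5000, 784), values in [-1, 1]
augmented_x_train = np.concatenate([x_train, synthetic_images], axis=0)
print(augmented_x_train.shape)  # (65000, 784): 60,000 real images plus 5,000 synthetic ones

Note that an unconditional GAN like the one above does not produce class labels, so for supervised tasks a conditional GAN, which generates samples for a specified class, is typically used instead.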

Summary

In this episode, we explored image generation using GANs, covering the basic structure, implementation, and applications of GANs. GANs offer a powerful way to generate realistic data, with applications ranging from image enhancement to artistic style transfers and data augmentation.

Next Episode Preview

In the next episode, we will dive into style transfer, a technique for altering the style of an image using GANs. This exciting application of GANs opens new possibilities in image processing.


Notes

  • Discriminator: In a GAN, the model that learns to differentiate between real and generated data.
  • Generator: The model in a GAN that creates synthetic data from random noise.

Author of this article

PROMPT Inc. provides a variety of information related to generative AI.
If there is a topic you would like us to write an article about or research, please contact us using the inquiry form.
