
[AI from Scratch] Episode 283: Image Classification Using Transfer Learning

Recap and Today’s Theme

Hello! In the previous episode, we discussed data augmentation, which improves model performance even with limited data by using techniques such as rotation, shifting, and noise addition.

Today, we will dive into Transfer Learning for image classification. Transfer learning utilizes pre-trained models, making it a highly efficient approach in deep learning, especially when training on small datasets. This episode explains the basic concepts of transfer learning and demonstrates its implementation using pre-trained models.

What is Transfer Learning?

Transfer Learning is a technique where a model trained on a large dataset is adapted for a new task. Instead of starting from scratch, transfer learning leverages the feature extraction capabilities of a pre-trained model to efficiently learn new tasks.
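To make this concrete, here is a minimal sketch that loads VGG16 without its classification head and uses it purely as a feature extractor on a single image. The image path 'sample.jpg' is a placeholder, and the 150x150 input size is just an illustrative choice.

import numpy as np
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing import image

# Load VGG16 trained on ImageNet, without its fully connected classifier
feature_extractor = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))

# 'sample.jpg' is a placeholder path for any local image
img = image.load_img('sample.jpg', target_size=(150, 150))
x = preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))

# For a 150x150 input, VGG16 returns a (1, 4, 4, 512) feature map
features = feature_extractor.predict(x)
print(features.shape)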

Benefits of Transfer Learning

  • Reduced Training Time: Using a pre-trained model significantly reduces training time since there is no need to train the model from scratch.
  • High Accuracy with Limited Data: Pre-trained models, having been trained on large datasets, can achieve high accuracy even with smaller datasets for the new task.
  • Reusing Powerful Models: Popular models like VGG, ResNet, and Inception are trained to capture general image features, making them highly adaptable to different tasks.

Popular Pre-Trained Models

Here are some commonly used pre-trained models (a short loading sketch follows the list):

  1. VGG (Visual Geometry Group): A simple yet powerful model used for image classification.
  2. ResNet (Residual Network): Known for its ability to train very deep networks effectively.
  3. Inception: A model that uses different-sized filters in parallel, capturing multiple features simultaneously.
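All three are available in tf.keras.applications and can be loaded with ImageNet weights in a single line each. The input shapes below are illustrative; InceptionV3 simply requires larger inputs than the other two.

from tensorflow.keras.applications import VGG16, ResNet50, InceptionV3

# Each constructor downloads the ImageNet weights on first use
vgg = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
resnet = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
inception = InceptionV3(weights='imagenet', include_top=False, input_shape=(299, 299, 3))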

Implementing Image Classification with Transfer Learning

Let’s implement transfer learning for image classification using a pre-trained model, VGG16, with Keras and TensorFlow.

1. Installing Required Libraries

First, install the tensorflow library (matplotlib is also needed for plotting the training history):

pip install tensorflow matplotlib

2. Transfer Learning Implementation with VGG16

The following code demonstrates how to use the pre-trained VGG16 model for transfer learning:

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam

# Load pre-trained VGG16 model without the top layers
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))

# Freeze the base model layers
base_model.trainable = False

# Build a new model for transfer learning
model = models.Sequential([
    base_model,
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(1, activation='sigmoid')  # For binary classification
])

# Compile the model
model.compile(optimizer=Adam(),
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Data augmentation for training
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True
)

test_datagen = ImageDataGenerator(rescale=1./255)

# Load training and validation data
train_generator = train_datagen.flow_from_directory(
    'data/train',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary'
)

validation_generator = test_datagen.flow_from_directory(
    'data/validation',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary'
)

# Train the model
history = model.fit(
    train_generator,
    epochs=10,
    validation_data=validation_generator
)

# Plot training history
import matplotlib.pyplot as plt

plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

Explanation

  • VGG16 Pre-trained Model: The VGG16 model is loaded without its fully connected layers (top layers) using the include_top=False argument.
  • Freezing Layers: By setting base_model.trainable = False, the pre-trained layers are frozen, meaning they won’t be updated during training.
  • New Layers: New fully connected layers are added to adapt the model to a binary classification task.
  • Data Augmentation: The training data is augmented to increase variation and improve model generalization.
  • Training: The model is trained on the new dataset, reusing the pre-trained VGG16 model’s feature extraction capabilities. A small prediction sketch follows below.
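After training, the model can also classify individual images. The snippet below is a small sketch: 'data/test/sample.jpg' is a hypothetical path, and the class-name-to-index mapping comes from how flow_from_directory indexed the training folders.

import numpy as np
from tensorflow.keras.preprocessing import image

# 'data/test/sample.jpg' is a placeholder; substitute any image you want to classify
img = image.load_img('data/test/sample.jpg', target_size=(150, 150))
x = image.img_to_array(img) / 255.0   # apply the same rescaling as the generators
x = np.expand_dims(x, axis=0)         # add the batch dimension

prob = float(model.predict(x)[0][0])  # sigmoid output: probability of class 1
print(train_generator.class_indices)  # mapping from folder name to class index
print(f'Predicted class index: {int(prob > 0.5)} (P(class 1) = {prob:.3f})')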

3. Evaluating and Fine-Tuning the Model

After training, the model can be further evaluated and fine-tuned:

# Evaluate the model on the validation data
validation_loss, validation_acc = model.evaluate(validation_generator, verbose=2)
print(f"Validation accuracy: {validation_acc:.4f}")

Fine-Tuning: As a next step, you can unfreeze some of the pre-trained layers and retrain them on the new dataset at a low learning rate, which often improves accuracy further. A sketch follows below.
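A minimal fine-tuning sketch, continuing from the model defined above: only the last convolutional block of VGG16 (layers whose names start with 'block5') is unfrozen, and the model is recompiled with a much lower learning rate so the pre-trained weights change only slightly.

# Unfreeze only the last convolutional block of VGG16
base_model.trainable = True
for layer in base_model.layers:
    layer.trainable = layer.name.startswith('block5')

# Recompile with a low learning rate to avoid destroying pre-trained features
model.compile(optimizer=Adam(learning_rate=1e-5),
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Continue training for a few additional epochs
fine_tune_history = model.fit(
    train_generator,
    epochs=5,
    validation_data=validation_generator
)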

Benefits and Applications of Transfer Learning

1. High Accuracy with Small Datasets

Transfer learning enables high accuracy even with small datasets. This is because the pre-trained model has already learned general image features that can be applied to the new task.

2. Medical Image Analysis

Transfer learning is frequently used in medical applications, where labeled data is often scarce. Pre-trained models like ResNet are adapted for tasks like disease diagnosis using X-ray or MRI images.
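As an illustrative sketch (not code from this article), the same pattern carries over to a multi-class setting with a ResNet50 backbone; the assumed four diagnostic categories and the 224x224 input size are hypothetical.

from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50

NUM_CLASSES = 4  # assumed number of diagnostic categories

# Frozen ResNet50 backbone with global average pooling and a small classification head
backbone = ResNet50(weights='imagenet', include_top=False,
                    input_shape=(224, 224, 3), pooling='avg')
backbone.trainable = False

medical_model = models.Sequential([
    backbone,
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(NUM_CLASSES, activation='softmax')  # one probability per class
])

medical_model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])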

3. Object Detection in Autonomous Vehicles

In autonomous vehicles, transfer learning is used to detect objects like road signs, pedestrians, and other vehicles in real-time, using models pre-trained on large datasets.

Challenges and Future of Transfer Learning

Challenges

  • Choosing the Right Pre-Trained Model: Not all pre-trained models are suitable for every task. For specialized fields (e.g., medical or satellite images), the pre-trained model may not generalize well.
  • Overfitting: When the dataset is too small, the model might overfit. In such cases, regularization techniques like dropout and data augmentation are crucial.

Future Trends

  • Larger Pre-Trained Models: Just as BERT and GPT have revolutionized NLP, larger pre-trained models for image recognition are emerging, offering improved accuracy and versatility.
  • Multi-Modal Transfer Learning: Transfer learning across multiple modalities, such as combining images with text or audio, is becoming more prominent in complex tasks.

Summary

In this episode, we explored transfer learning and demonstrated how to use a pre-trained model for image classification. Transfer learning allows for fast and efficient training of high-accuracy models, even with limited data. Next time, we will cover Fine-Tuning, a technique to further enhance the model by retraining certain layers.

Next Episode Preview

Next time, we will dive into Fine-Tuning to further improve model performance by selectively retraining parts of the pre-trained model.


Notes

  • VGG16: A popular pre-trained model developed by the Visual Geometry Group.
  • ResNet: A deep learning model that uses residual connections to train very deep networks effectively.