
[AI from Scratch] Episode 281: Building an Image Classification Model Using CNN

Recap and Today’s Theme

Hello! In the previous episode, we explained feature extraction methods like SIFT, SURF, and ORB for detecting key points in images. While these techniques are effective for object recognition and pattern matching, they require manual feature extraction.

Today, we will explore how to build an image classification model using Convolutional Neural Networks (CNNs). CNNs automatically extract features from image data and are widely used for classification and recognition tasks. In this episode, we’ll cover the basic structure of CNNs and how to implement a CNN for image classification using Python.

What is Image Classification?

Image classification refers to the task of categorizing an input image into a specific class. Examples include:

  • Animal classification: Sorting images into classes such as cat, dog, or bird.
  • Handwritten digit classification: Categorizing handwritten digits into classes from 0 to 9.

Applications of Image Classification

  • Autonomous Vehicles: Classifying objects like road signs, other vehicles, and pedestrians in real-time.
  • Medical Image Analysis: Detecting abnormalities in X-ray or MRI images.
  • Security Systems: Detecting and classifying specific people or objects in surveillance footage.

What is a CNN (Convolutional Neural Network)?

A CNN (Convolutional Neural Network) is a type of deep learning model that automatically extracts features from image data and performs classification based on those features. Unlike traditional methods, where features must be designed by hand, CNNs learn which features are relevant directly from the training data.

Basic Structure of CNN

CNNs typically consist of the following layers:

  1. Convolutional Layer: Applies filters (kernels) to the input image, creating feature maps that detect edges, patterns, and textures.
  2. Pooling Layer: Downsamples the feature maps to reduce their spatial size and the computation cost, which also makes the model more robust to small shifts, noise, and distortion.
  3. Fully Connected Layer: Flattens the extracted features into a vector and performs classification based on learned features. The class probabilities are calculated here.
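
To make the first two layer types concrete, here is a minimal NumPy sketch of what a convolution and a max-pooling step compute on a toy image. The helper names (conv2d_valid, max_pool2d) and the hand-picked edge kernel are illustrative assumptions; in a real CNN the kernel values are learned, and the framework uses far faster implementations.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` with stride 1 and no padding."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=np.float32)
    for i in range(out_h):
        for j in range(out_w):
            # Element-wise product of the window and the kernel, then sum.
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def max_pool2d(feature_map, size=2):
    """Keep the maximum of each non-overlapping size x size window."""
    h = feature_map.shape[0] // size
    w = feature_map.shape[1] // size
    trimmed = feature_map[:h * size, :w * size]
    return trimmed.reshape(h, size, w, size).max(axis=(1, 3))

# Toy 6x6 image: dark left half, bright right half (one vertical edge).
image = np.zeros((6, 6), dtype=np.float32)
image[:, 3:] = 1.0

# Hand-picked vertical-edge kernel (an assumption for illustration only).
kernel = np.array([[-1, 0, 1],
                   [-1, 0, 1],
                   [-1, 0, 1]], dtype=np.float32)

feature_map = np.maximum(conv2d_valid(image, kernel), 0)  # convolution + ReLU
pooled = max_pool2d(feature_map)

print(feature_map[0])  # [0. 3. 3. 0.] -> strong response where the edge is
print(pooled.shape)    # (2, 2)
```

The feature map is largest where the window straddles the edge, and pooling halves each spatial dimension while keeping that strong response.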

Implementing an Image Classification Model Using CNN

Let’s now build a CNN model for classifying handwritten digits from the MNIST dataset, a commonly used dataset for image classification.

1. Installing Required Libraries

First, install the tensorflow library:

pip install tensorflow

2. Building and Training the CNN Model

Here’s the code for building and training a CNN model using TensorFlow and Keras to classify the MNIST dataset:

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Reshape and normalize the data
x_train = x_train.reshape((x_train.shape[0], 28, 28, 1)).astype('float32') / 255
x_test = x_test.reshape((x_test.shape[0], 28, 28, 1)).astype('float32') / 255

# Build the CNN model
model = models.Sequential([
    # Declare the input shape explicitly: 28x28 pixels, 1 grayscale channel
    layers.Input(shape=(28, 28, 1)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

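# Train the model
# Note: for simplicity the test set is also used as validation data here.
# For a stricter evaluation, hold out part of the training set instead
# (for example with validation_split=0.1).
history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))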

# Plot the training results
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

Explanation:

  • Data Loading: Loads the MNIST dataset, which comes already split into 60,000 training images and 10,000 test images.
  • Data Normalization: Adds a channel dimension and scales pixel values from 0-255 to the 0-1 range, which helps training converge.
  • CNN Model: A CNN with three convolutional layers and two pooling layers is built, followed by fully connected layers for classification.
  • Training: The model is trained for five epochs.
  • Plotting Accuracy: The accuracy of the model is plotted to visualize the training progress.
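
As a cross-check on the architecture described above, the following sketch traces the feature-map size through each layer and counts the trainable parameters by hand. It assumes the Keras defaults used in the model (no padding, stride 1 for convolutions, non-overlapping 2x2 pooling); the helper names are hypothetical. The total should match what model.summary() reports.

```python
def conv_out(size, kernel=3):
    """Width/height after a convolution with no padding and stride 1."""
    return size - kernel + 1

def pool_out(size, pool=2):
    """Width/height after non-overlapping pooling (rounded down)."""
    return size // pool

def conv_params(kernel, in_channels, filters):
    """Kernel weights plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters

def dense_params(inputs, units):
    """Weight matrix plus one bias per unit."""
    return inputs * units + units

size = 28
size = conv_out(size)    # 26 after Conv2D(32)
size = pool_out(size)    # 13 after MaxPooling2D
size = conv_out(size)    # 11 after Conv2D(64)
size = pool_out(size)    # 5  after MaxPooling2D
size = conv_out(size)    # 3  after Conv2D(64)
flat = size * size * 64  # 576 values after Flatten

total = (conv_params(3, 1, 32)
         + conv_params(3, 32, 64)
         + conv_params(3, 64, 64)
         + dense_params(flat, 64)
         + dense_params(64, 10))

print(size, flat, total)  # 3 576 93322
```

Pooling layers add no parameters, so only the three convolutional layers and the two dense layers contribute to the total.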

3. Evaluating the Model

After training, the model is evaluated on the test data to check its accuracy:

# Evaluate the model on the test data
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"Test accuracy: {test_acc:.4f}")

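The model achieves high accuracy in classifying handwritten digits, typically reaching a test accuracy of around 98% or slightly higher. The exact figure varies from run to run because the weights are initialized randomly.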
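
The accuracy figure summarizes per-image predictions. The sketch below shows how softmax outputs turn into predicted labels, an accuracy value, and a confusion matrix. The probabilities and labels are made-up illustrative values for three classes, not output from the model above; with the trained model you would obtain the probabilities from model.predict(x_test) instead.

```python
import numpy as np

# Hypothetical softmax outputs for 4 images and 3 classes (illustrative only).
# Each row sums to 1, just like the rows returned by model.predict.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
true_labels = np.array([0, 1, 2, 1])

# The predicted class is the index of the highest probability in each row.
predicted = probs.argmax(axis=1)

# Accuracy is the fraction of images whose prediction matches the label.
accuracy = (predicted == true_labels).mean()

# Confusion matrix: rows are true classes, columns are predicted classes.
num_classes = probs.shape[1]
confusion = np.zeros((num_classes, num_classes), dtype=int)
np.add.at(confusion, (true_labels, predicted), 1)

print(predicted)  # [0 1 2 0]
print(accuracy)   # 0.75
print(confusion)
```

Off-diagonal entries in the confusion matrix show which classes get mistaken for which, something a single accuracy number hides.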

Improving the CNN Model

1. Data Augmentation

To improve the model’s performance, data augmentation is often applied. Data augmentation involves creating variations of the original images (e.g., rotations, shifts, zooms) to increase the diversity of the training data, improving the model’s generalization.
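
As a rough illustration of the idea, here is a NumPy-only sketch that creates randomly shifted copies of each image. The function names, the shift range, and the toy images are assumptions made for this example. It works on 2D arrays, so MNIST images with a trailing channel dimension would need that dimension dropped first. Keras also provides its own augmentation tools, which are not shown here.

```python
import numpy as np

def random_shift(image, max_shift, rng):
    """Shift a 2D image by up to max_shift pixels per axis, filling gaps with zeros."""
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    h, w = image.shape
    padded = np.pad(image, max_shift, mode="constant")
    top = max_shift - dy
    left = max_shift - dx
    return padded[top:top + h, left:left + w]

def augment_batch(images, copies, max_shift=2, seed=0):
    """Return `copies` randomly shifted variants of every image, grouped per image."""
    rng = np.random.default_rng(seed)
    shifted = [random_shift(img, max_shift, rng)
               for img in images
               for _ in range(copies)]
    return np.stack(shifted)

# Toy stand-ins for three 28x28 digits: a bright square on a dark background.
images = np.zeros((3, 28, 28), dtype=np.float32)
images[:, 10:18, 10:18] = 1.0
labels = np.array([0, 1, 2])

aug_images = augment_batch(images, copies=2)
aug_labels = np.repeat(labels, 2)  # each label repeated to match its copies

print(aug_images.shape)  # (6, 28, 28)
print(aug_labels)        # [0 0 1 1 2 2]
```

The labels are repeated in the same order as the images because a small shift does not change which digit an image shows.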

2. Transfer Learning

Transfer learning allows you to use a pre-trained model (e.g., VGG, ResNet) and adapt it to new datasets. This is effective for achieving high accuracy in a short time with limited data.
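
One common way to set this up in Keras is sketched below, assuming ResNet50 from tf.keras.applications as the pre-trained base. The function name build_transfer_model, the 224x224 RGB input size, and the pooling-plus-dense head are choices made for this illustration, not a prescribed recipe. With weights="imagenet" the pre-trained weights are downloaded on first use.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

def build_transfer_model(num_classes, input_shape=(224, 224, 3), weights="imagenet"):
    """Frozen ResNet50 feature extractor plus a new classification head.

    Hypothetical helper for illustration. Inputs are expected to be RGB images
    already passed through tf.keras.applications.resnet50.preprocess_input.
    """
    base = tf.keras.applications.ResNet50(
        include_top=False,      # drop the original 1000-class ImageNet head
        weights=weights,
        input_shape=input_shape,
    )
    base.trainable = False      # keep the pre-trained filters fixed

    inputs = layers.Input(shape=input_shape)
    x = base(inputs, training=False)        # run the base in inference mode
    x = layers.GlobalAveragePooling2D()(x)  # collapse each feature map to one value
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return models.Model(inputs, outputs)

# Example usage (downloads the ImageNet weights on first call):
# model = build_transfer_model(num_classes=5)
# model.compile(optimizer="adam",
#               loss="sparse_categorical_crossentropy",
#               metrics=["accuracy"])
```

Only the new dense layer is trained here, which is why this approach can work with a small dataset. Note that MNIST images are 28x28 grayscale, so they would have to be resized and converted to three channels before a model like this could use them.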

3. Using Different Architectures

Advanced CNN architectures like ResNet or Inception can be used to handle more complex classification tasks. These architectures include deeper layers and different strategies for feature extraction, which can enhance performance.

Applications of CNNs in Image Classification

1. Object Detection for Autonomous Vehicles

CNNs are used in self-driving cars to recognize road signs, other vehicles, and pedestrians, helping the vehicle make decisions for safe driving.

2. Medical Diagnosis Support Systems

CNNs help analyze MRI and X-ray images to detect abnormalities, supporting early diagnosis in the medical field.

3. Surveillance Systems

CNNs are used in security systems to detect and classify objects or individuals in surveillance footage, enhancing the effectiveness of monitoring.

Summary

In this episode, we explored how to build an image classification model using CNNs. CNNs are a powerful tool for automatically extracting features from image data and classifying objects or patterns. Next time, we will dive into data augmentation, learning how to enhance the performance of models by increasing the variety of training data.

Next Episode Preview

In the next episode, we will cover data augmentation in practice, learning how to increase the size and diversity of image data to improve model accuracy.


Notes

  • Convolutional Layer: A layer in a CNN that extracts local features like edges and patterns.
  • Pooling Layer: A layer that downsamples feature maps to reduce computation and improve robustness.

Author of this article

PROMPT Inc. provides a variety of information related to generative AI.
If there is a topic you would like us to write an article about or research, please contact us using the inquiry form.
