【0から学ぶAI】第232回:転移学習の実践 〜事前学習済みモデルを活用する方法を紹介

目次

前回のおさらいと今回のテーマ

こんにちは!前回は、データ拡張の実践について学び、Kerasを使って画像データの拡張方法を紹介しました。データ拡張により、モデルがより多様なデータを学習しやすくなり、過学習を防ぐ効果も確認できました。

今回は、転移学習(Transfer Learning)について解説します。転移学習は、すでに訓練済みのモデル(事前学習済みモデル)を活用して新たなタスクに適用する方法です。この技術は、少ないデータでも高精度なモデルを構築する際に非常に有効です。特に、画像認識分野で多くの事前学習済みモデルが利用されており、効率的に高性能なモデルを構築できます。それでは、実際にKerasを使って転移学習の方法を見ていきましょう。

転移学習とは?

転移学習(Transfer Learning)は、他のデータセットやタスクで学習されたモデルの知識を、新たなタスクに活用する手法です。事前学習済みモデルは、通常、大規模なデータセット(例:ImageNet)で訓練されており、すでに画像の特徴を深く学習しています。そのため、これらのモデルを利用することで、少ないデータでも高精度な分類や認識が可能になります。

転移学習のメリット

  1. 高速な学習: すでに多くのパラメータが最適化されているため、新たなデータセットに対して迅速に適応できます。
  2. 少ないデータでも高精度: 十分なデータがない場合でも、事前学習済みモデルを使用することで精度を向上させられます。
  3. 計算リソースの節約: モデルの構築や訓練にかかる時間とリソースを大幅に削減できます。

転移学習の実装

それでは、Kerasを使って実際に転移学習を実装してみましょう。今回は、事前学習済みのMobileNetV2を利用し、CIFAR-10データセットに適用して画像分類を行います。

1. 必要なライブラリのインポート

まず、TensorFlowとKerasのライブラリをインポートします。

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
import matplotlib.pyplot as plt

2. データセットの準備

Kerasに組み込まれているCIFAR-10データセットを読み込み、画像データを正規化します。

# CIFAR-10データセットの読み込み
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# 画像データを0-1にスケーリング
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

3. 事前学習済みモデルの読み込みと設定

次に、Kerasが提供するMobileNetV2モデルを読み込みます。ここでは、事前学習済みの重み(imagenet)を使用し、最終層を除去して特徴抽出部分のみを使用します。

# 事前学習済みのMobileNetV2モデルの読み込み
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

# ベースモデルの凍結(学習を行わない)
base_model.trainable = False
  • weights=’imagenet’: ImageNetデータセットで訓練された重みを使用します。
  • include_top=False: 最終の全結合層を除外し、特徴抽出部分のみを使用します。
  • trainable=False: ベースモデルの重みを固定し、再訓練しないように設定します。

4. 新しい出力層の追加

ベースモデルの上に、新たな分類層を追加します。これにより、CIFAR-10の10クラスに対応するモデルを構築します。

# モデルの定義
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])
  • GlobalAveragePooling2D: 全結合層の前に使用するプーリング層で、特徴マップの空間次元を平均化します。
  • Dense: 全結合層を追加し、10クラスの分類を行います。

5. モデルのコンパイル

次に、モデルをコンパイルします。

# モデルのコンパイル
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
  • optimizer=’adam’: 最適化アルゴリズムにAdamを使用します。
  • loss=’sparse_categorical_crossentropy’: CIFAR-10のラベルは整数のため、sparse_categorical_crossentropyを使用します。
  • metrics=[‘accuracy’]: モデルの評価指標として正解率を使用します。

6. モデルの訓練

データセットを使ってモデルを訓練します。

# モデルの訓練
history = model.fit(x_train, y_train, epochs=10, batch_size=64, validation_split=0.2)
  • epochs=10: モデルを10エポックにわたって訓練します。
  • batch_size=64: バッチサイズを64に設定し、訓練速度を最適化します。
  • validation_split=0.2: 訓練データの20%を検証データとして使用します。

7. モデルの評価

テストデータを使って、モデルの性能を評価します。

# モデルの評価
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_accuracy:.2f}")

8. 微調整(Fine-Tuning)

モデルの精度をさらに向上させたい場合、ベースモデルの一部を微調整(ファインチューニング)することが可能です。ファインチューニングでは、ベースモデルの一部の層を再度学習させ、新しいデータセットに適応させます。

# 特定の層以降を再度訓練可能にする
base_model.trainable = True

# 微調整のための再コンパイル
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),  # 微調整のため学習率を小さく設定
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 微調整の訓練
history_fine = model.fit(x_train, y_train, epochs=5, batch_size=64, validation_split=0.2)
  • trainable=True: ベースモデル全体または一部の層を再訓練可能に設定します。
  • 学習率を低く設定: 微調整では既存の重みを大きく変更しないよう、学習率を低めに設定します。

転移学習の活用事例

転移学習は、以下のような様々な分野で活用されています。

  1. 画像分類: 少量のデータしかない場合でも、ImageNetなどの事前学習済みモデルを使用して高精度な画像分類が可能です。
  2. 物体検出: 物体の位置や種類を特定するタスクにも、転移学習で効率的にモデルを構築できます。
  3. 自然言語処理(NLP): BERTやGPTのような事前学習済みのモデルを使い、テキスト分類や文章生成、機械翻訳などのタスクで活用されています。

まとめ

今回は、Kerasを使った転移学習について解説しました。事前学習済みのモデルを活用することで、少ないデータでも高精度なモデルを迅速に構築できる

ことがわかりました。転移学習はディープラーニングの強力なツールであり、特に画像認識や自然言語処理の分野で広く利用されています。これを活用して、効率的にモデル開発を進めていきましょう!

次回予告

次回は、モデルのデプロイ方法として、学習済みモデルを実際のアプリケーションで使用する方法について解説します。実際の環境でモデルを動かし、予測を行うまでの流れを学びましょう!


注釈

  • 事前学習済みモデル: 大規模なデータセットであらかじめ訓練されたモデル。新しいタスクに転用することで、少ないデータで高精度を実現。
  • ファインチューニング: 既存のモデルの一部の層を再度訓練し、新しいタスクに適応させる技術。
よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!

この記事を書いた人

株式会社PROMPTは生成AIに関する様々な情報を発信しています。
記事にしてほしいテーマや調べてほしいテーマがあればお問合せフォームからご連絡ください。
---
PROMPT Inc. provides a variety of information related to generative AI.
If there is a topic you would like us to write an article about or research, please contact us using the inquiry form.

コメント

コメントする

目次