
Data Augmentation

Deep learning is premised on having large amounts of data. When the data you need is plentiful that is fine, but in real-world settings it is often difficult to collect as many samples as you would like.

In such cases, a common technique is to inflate the number of training samples by applying artificial operations such as translation, rotation, and scaling to the training images. Because the network then rarely sees exactly the same image twice, generalization performance can be expected to improve.

In this chapter we review the representative operations of this technique, Data Augmentation, and then check how accuracy changes before and after applying it.

In TensorFlow, tensorflow.keras.preprocessing.image.ImageDataGenerator provides a variety of augmentation methods.

In ordinary training, a specified number of images is selected from the dataset to form a mini-batch. With ImageDataGenerator, augmentation is applied to each selected image before the mini-batch is assembled. Which operations to apply is specified via arguments when the instance is created, and because the transformations are performed on the fly rather than written to disk, there is no concern about consuming disk space.
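This ordering — select a mini-batch first, then transform each selected image — can be sketched in plain NumPy. The helper below is a hypothetical illustration of the idea, not the actual ImageDataGenerator implementation:

```python
import numpy as np

def augmented_batches(x, batch_size, augment, rng):
    # shuffle, slice into mini-batches, then transform each selected image,
    # mimicking the order of operations described above (illustration only)
    idx = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        batch = x[idx[start:start + batch_size]]
        yield np.stack([augment(img) for img in batch])

rng = np.random.default_rng(0)
data = np.arange(6 * 4 * 4 * 3, dtype=np.float32).reshape(6, 4, 4, 3)
flip = lambda img: img[:, ::-1]  # horizontal flip as a stand-in augmentation
batches = list(augmented_batches(data, batch_size=4, augment=flip, rng=rng))
print([b.shape for b in batches])  # → [(4, 4, 4, 3), (2, 4, 4, 3)]
```

Because the transformation happens while each batch is produced, nothing ever needs to be saved to disk.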

Typical operations include:

  • Rotation
  • Shift
  • Zoom
  • Shear
  • Horizontal flip
  • Vertical flip
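The simplest of these operations can be reproduced with plain NumPy indexing; the toy array below stands in for an image (this is only to build intuition — ImageDataGenerator applies such transforms randomly per image):

```python
import numpy as np

img = np.arange(12).reshape(3, 4)        # toy 3x4 "image"
h_flip = img[:, ::-1]                    # horizontal flip (mirror left-right)
v_flip = img[::-1, :]                    # vertical flip (upside down)
shifted = np.roll(img, shift=1, axis=1)  # horizontal shift with wrap-around fill
```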

Outline of This Chapter

  • Build the baseline model
  • Examine each operation
  • Save augmented images
  • Check accuracy with data augmentation

Building the Baseline Model

As in the previous chapter, let's start by building a baseline model.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
      

Preparing the Dataset

This chapter again uses the 10-class classification dataset called CIFAR-10, which is available in tf.keras.datasets.

(x_train, t_train), (x_test, t_test) = tf.keras.datasets.cifar10.load_data()
      

Now let's display 25 of the images we will be working with (here, the first 25 training images).

Label  Class
0      airplane
1      automobile
2      bird
3      cat
4      deer
5      dog
6      frog
7      horse
8      ship
9      truck

The goal is to classify images into the ten classes in the table above. The low 32 × 32 resolution is one of the reasons CIFAR-10 is so often used as a practice problem for image classification.
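When inspecting predictions later, it helps to map a label index back to its name. The class_names list below is a small helper defined here for convenience (it is not shipped with the dataset itself):

```python
# CIFAR-10 label index -> class name, in the order of the table above
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

print(class_names[7])  # → horse
```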

# plot the first 25 training images
plt.figure(figsize=(12,12))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.imshow(x_train[i])
      
<Figure size 864x864 with 25 Axes>
# normalize pixel values to [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0
      
x_train.shape, x_test.shape, t_train.shape, t_test.shape
      
((50000, 32, 32, 3), (10000, 32, 32, 3), (50000, 1), (10000, 1))

Defining and Training the Model

import os
import random

def reset_seed(seed=0):
    os.environ['PYTHONHASHSEED'] = '0'
    random.seed(seed)         # fix the seed of Python's random module
    np.random.seed(seed)      # fix the NumPy seed
    tf.random.set_seed(seed)  # fix the TensorFlow seed
      
from tensorflow.keras import models, layers
      
# fix the seeds
reset_seed(0)

# build the model
model = models.Sequential([
    layers.Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
]) 
  
# set up the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# compile the model
model.compile(loss='sparse_categorical_crossentropy', 
              optimizer=optimizer,
              metrics=['accuracy'])

model.summary()
      
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_9 (Conv2D)            (None, 32, 32, 32)        896
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 16, 16, 32)        0
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 16, 16, 64)        18496
_________________________________________________________________
max_pooling2d_10 (MaxPooling (None, 8, 8, 64)          0
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 8, 8, 128)         73856
_________________________________________________________________
max_pooling2d_11 (MaxPooling (None, 4, 4, 128)         0
_________________________________________________________________
flatten_3 (Flatten)          (None, 2048)              0
_________________________________________________________________
dense_6 (Dense)              (None, 128)               262272
_________________________________________________________________
dense_7 (Dense)              (None, 10)                1290
=================================================================
Total params: 356,810
Trainable params: 356,810
Non-trainable params: 0
_________________________________________________________________
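The parameter counts in the summary can be verified by hand: a Conv2D layer has kernel_height × kernel_width × input_channels weights per output channel plus one bias each, and a Dense layer has inputs × outputs weights plus one bias per output. The helper functions below exist only for this arithmetic:

```python
def conv2d_params(k_h, k_w, c_in, c_out):
    # weights per filter (k_h * k_w * c_in) times c_out filters, plus biases
    return k_h * k_w * c_in * c_out + c_out

def dense_params(n_in, n_out):
    # weight matrix plus one bias per output unit
    return n_in * n_out + n_out

total = (conv2d_params(3, 3, 3, 32)      # conv2d_9:  896
         + conv2d_params(3, 3, 32, 64)   # conv2d_10: 18,496
         + conv2d_params(3, 3, 64, 128)  # conv2d_11: 73,856
         + dense_params(2048, 128)       # dense_6:   262,272
         + dense_params(128, 10))        # dense_7:   1,290
print(total)  # → 356810
```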
# training settings
batch_size = 1024
epochs = 100

# run training
history = model.fit(x_train, t_train, 
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, t_test))
      
Train on 50000 samples, validate on 10000 samples
Epoch 1/100
50000/50000 [==============================] - 3s 68us/sample - loss: 1.9287 - accuracy: 0.3006 - val_loss: 1.6475 - val_accuracy: 0.4054
Epoch 2/100
50000/50000 [==============================] - 3s 59us/sample - loss: 1.5343 - accuracy: 0.4490 - val_loss: 1.4547 - val_accuracy: 0.4773
...
Epoch 99/100
50000/50000 [==============================] - 3s 59us/sample - loss: 0.0061 - accuracy: 1.0000 - val_loss: 2.1307 - val_accuracy: 0.7208
Epoch 100/100
50000/50000 [==============================] - 3s 58us/sample - loss: 0.0056 - accuracy: 1.0000 - val_loss: 2.1479 - val_accuracy: 0.7210
(intermediate epochs elided: training accuracy climbs steadily to 1.0 while validation accuracy plateaus around 0.72 and validation loss starts rising after roughly epoch 30)
results = pd.DataFrame(history.history)
results.tail(1)
      
    loss      accuracy  val_loss  val_accuracy
99  0.005615  0.99998   2.147902  0.721
results[['loss', 'val_loss']].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7f4631195e80>
<Figure size 432x288 with 1 Axes>
               Train  Val
Base Accuracy  0.999  0.721
Base Loss      0.006  2.148

The model is clearly overfitting. Using the scores above as the baseline, let's check whether applying data augmentation improves generalization performance.

Examining Each Operation

Before implementing anything concretely, let's look at the following representative augmentation operations one by one.

  • Rotation
  • Shift
  • Zoom
  • Shear
  • Horizontal flip
  • Vertical flip

First, take a single CIFAR-10 image to use as a sample.

# sample image
img = x_train[4]
      

Next, let's write a function that applies a transformation and displays the image before and after. Turning frequently repeated steps into a function up front makes them convenient to reuse.

# module for data augmentation
from tensorflow.keras.preprocessing.image import ImageDataGenerator
      
# function to visualize an image before and after augmentation
def show(img, datagen):
    # reshape to (batch_size, height, width, channel)
    img_batch = img.reshape(1, 32, 32, 3)

    # datagen.flow() yields mini-batches with augmentation applied on the fly;
    # here we feed a one-sample dataset with batch_size=1
    for img_augmented in datagen.flow(img_batch, batch_size=1):
        # drop the batch_size dimension
        out = img_augmented.reshape(32, 32, 3)
        break

    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.title('before')
    plt.imshow(img)

    plt.subplot(1, 2, 2)
    plt.title('after')
    plt.imshow(out)
      

Rotation

Rotation is specified with the rotation_range argument: a rotation angle is drawn at random from [-rotation_range, +rotation_range] degrees.

datagen = ImageDataGenerator(rotation_range=60)

show(img, datagen)
      
<Figure size 720x720 with 2 Axes>

Translation (Shift)

width_shift_range and height_shift_range specify random horizontal and vertical shifts, respectively. A float below 1 is interpreted as a fraction of the image width or height.

datagen = ImageDataGenerator(width_shift_range=0.5,
                             height_shift_range=0.5)

show(img, datagen)
      
<Figure size 720x720 with 2 Axes>

Shear

Shear distorts the rectangular image into a parallelogram. It is specified with the shear_range argument, given as a shear angle in degrees.

datagen = ImageDataGenerator(shear_range=30)

show(img, datagen)
      
<Figure size 720x720 with 2 Axes>
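Geometrically, a shear by angle θ maps a point (x, y) to (x + y·tan θ, y), which is exactly what turns the square image into a parallelogram. A minimal sketch of that mapping:

```python
import numpy as np

theta = np.deg2rad(30)             # shear angle of 30 degrees
shear = np.array([[1.0, np.tan(theta)],
                  [0.0, 1.0]])     # 2x2 shear matrix

top_corner = np.array([0.0, 1.0])  # a point on the top edge of the unit square
sheared = shear @ top_corner       # slides sideways by tan(30 deg), y unchanged
```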

Zoom

Zoom is specified with the zoom_range argument. Passing a single float such as 0.5 scales the image by a factor drawn at random from [1 − 0.5, 1 + 0.5], i.e. anywhere from half to one-and-a-half times the original size.

datagen = ImageDataGenerator(zoom_range=0.5)

show(img, datagen)
      
<Figure size 720x720 with 2 Axes>
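Under that rule, one random factor is drawn from [1 − zoom_range, 1 + zoom_range] per image. A sketch of the sampling (the exact mechanism inside Keras may differ; this only illustrates the interval):

```python
import numpy as np

rng = np.random.default_rng(0)
zoom_range = 0.5
lo, hi = 1 - zoom_range, 1 + zoom_range  # the interval [0.5, 1.5]
factor = rng.uniform(lo, hi)             # scale factor applied to the image
```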

Horizontal Flip

Specified with the horizontal_flip argument; when set to True, images are randomly mirrored left-to-right.

datagen = ImageDataGenerator(horizontal_flip=True)

show(img, datagen)
      
<Figure size 720x720 with 2 Axes>

Vertical Flip

Specified with the vertical_flip argument; when set to True, images are randomly flipped upside down.

datagen = ImageDataGenerator(vertical_flip=True)

show(img, datagen)
      
<Figure size 720x720 with 2 Axes>

Option: fill_mode

When a transformation leaves empty regions in the image, the fill_mode argument controls how those pixels are filled in. It accepts the following four values; the default is nearest.

  • constant
  • nearest
  • reflect
  • wrap
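These modes behave like the correspondingly named np.pad modes ('edge' plays the role of nearest), so their effect is easy to see on a 1-D example:

```python
import numpy as np

row = np.array([1, 2, 3])
print(np.pad(row, 2, mode='constant'))  # → [0 0 1 2 3 0 0]
print(np.pad(row, 2, mode='edge'))      # → [1 1 1 2 3 3 3]  (like nearest)
print(np.pad(row, 2, mode='reflect'))   # → [3 2 1 2 3 2 1]
print(np.pad(row, 2, mode='wrap'))      # → [2 3 1 2 3 1 2]
```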
datagen = ImageDataGenerator(width_shift_range=0.5, fill_mode='constant')
show(img, datagen)
      
<Figure size 720x720 with 2 Axes>
datagen = ImageDataGenerator(width_shift_range=0.5, fill_mode='reflect')
show(img, datagen)
      
<Figure size 720x720 with 2 Axes>
datagen = ImageDataGenerator(width_shift_range=0.5, fill_mode='wrap')
show(img, datagen)
      
<Figure size 720x720 with 2 Axes>

Saving Augmented Images

Augmented images can be saved as follows.

# module for saving images
from tensorflow.keras.preprocessing import image
      

Images can be saved with tensorflow.keras.preprocessing.image.save_img(), which takes two main arguments:

  • path: the file path to save the image to
  • x: the image (a NumPy array) to save

Refer to the official documentation for details. Set these two arguments and save the image. Here we show how to save a single image.

# prepare a sample image
img = x_train[4]

# define the augmentation
datagen = ImageDataGenerator(vertical_flip=True)

# reshape to (batch_size, height, width, channel)
img_batch = img.reshape(1, 32, 32, 3)

# save only one image this time
max_img_num = 1
counts = 1
for img_augmented in datagen.flow(img_batch, batch_size=1):
    # drop the batch_size dimension
    img_augmented = img_augmented.reshape(32, 32, 3)
    # save the image; the counter in the file name keeps
    # multiple saved images from overwriting each other
    image.save_img(f'augmented_output_{counts}.png', img_augmented)
    # stop once max_img_num images have been saved
    if (counts % max_img_num) == 0:
        print('Finish!!')
        break
    counts += 1
      
Finish!!

That completes the save. To save several images, increase max_img_num — just make sure each file is written under a distinct name so earlier images are not overwritten — and specify whatever destination directory keeps your files organized.

Accuracy with Data Augmentation

This time, let's apply random horizontal and vertical flips and check whether generalization performance improves.

# define the augmentations to apply
datagen = ImageDataGenerator(
    horizontal_flip=True,
    vertical_flip=True)
      
# fix the seeds
reset_seed(0)

# build the model
model = models.Sequential([
    layers.Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
]) 
  
# set up the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
 
# compile the model
model.compile(loss='sparse_categorical_crossentropy', 
              optimizer=optimizer,
              metrics=['accuracy'])

model.summary()
      
batch_size = 1024
epochs = 100

# train from the augmenting generator; steps_per_epoch must cover the whole
# training set, so round the division up to an integer
history = model.fit_generator(datagen.flow(x_train, t_train, batch_size=batch_size),
                              steps_per_epoch=int(np.ceil(len(x_train) / batch_size)),
                              epochs=epochs,
                              validation_data=(x_test, t_test))
      
WARNING:tensorflow:sample_weight modes were coerced from ... to ['...']
Train for 48.828125 steps, validate on 10000 samples
Epoch 1/100
49/48 [==============================] - 3s 63ms/step - loss: 2.0579 - accuracy: 0.2475 - val_loss: 1.8347 - val_accuracy: 0.3389
Epoch 2/100
49/48 [==============================] - 3s 54ms/step - loss: 1.7161 - accuracy: 0.3777 - val_loss: 1.6143 - val_accuracy: 0.4064
...
Epoch 81/100
49/48 [==============================] - 3s 53ms/step - loss: 0.5790 - accuracy: 0.7957 - val_loss: 0.7808 - val_accuracy: 0.7313
Epoch 82/100
49/48 [==============================] - 3s 54ms/step - loss: 0.5662 - accuracy: 0.8022 - val_loss: 0.8006 - val_accuracy: 0.7261
(intermediate epochs elided: with augmentation, validation accuracy rises to roughly 0.73 while training accuracy stays near 0.80, a much smaller train/validation gap than the baseline)
val_accuracy: 0.7300 Epoch 84/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5473 - accuracy: 0.8094 - val_loss: 0.7751 - val_accuracy: 0.7311 Epoch 85/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5612 - accuracy: 0.8029 - val_loss: 0.7779 - val_accuracy: 0.7291 Epoch 86/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5566 - accuracy: 0.8074 - val_loss: 0.7734 - val_accuracy: 0.7359 Epoch 87/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5488 - accuracy: 0.8089 - val_loss: 0.7734 - val_accuracy: 0.7343 Epoch 88/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5473 - accuracy: 0.8078 - val_loss: 0.7660 - val_accuracy: 0.7361 Epoch 89/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5410 - accuracy: 0.8123 - val_loss: 0.7898 - val_accuracy: 0.7315 Epoch 90/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5346 - accuracy: 0.8137 - val_loss: 0.8019 - val_accuracy: 0.7253 Epoch 91/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5327 - accuracy: 0.8144 - val_loss: 0.8042 - val_accuracy: 0.7285 Epoch 92/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5285 - accuracy: 0.8153 - val_loss: 0.8235 - val_accuracy: 0.7248 Epoch 93/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5330 - accuracy: 0.8138 - val_loss: 0.7906 - val_accuracy: 0.7274 Epoch 94/100 49/48 [==============================] - 3s 55ms/step - loss: 0.5314 - accuracy: 0.8146 - val_loss: 0.7744 - val_accuracy: 0.7349 Epoch 95/100 49/48 [==============================] - 3s 55ms/step - loss: 0.5217 - accuracy: 0.8181 - val_loss: 0.7782 - val_accuracy: 0.7347 Epoch 96/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5286 - accuracy: 0.8164 - val_loss: 0.7853 - val_accuracy: 0.7378 Epoch 97/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5122 - accuracy: 0.8215 - val_loss: 0.7812 
- val_accuracy: 0.7371 Epoch 98/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5024 - accuracy: 0.8255 - val_loss: 0.7815 - val_accuracy: 0.7397 Epoch 99/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5051 - accuracy: 0.8233 - val_loss: 0.7794 - val_accuracy: 0.7363 Epoch 100/100 49/48 [==============================] - 3s 54ms/step - loss: 0.5028 - accuracy: 0.8252 - val_loss: 0.7952 - val_accuracy: 0.7314
results = pd.DataFrame(history.history)
results.tail(1)
      
        loss  accuracy  val_loss  val_accuracy
99  0.502889   0.82516  0.795237        0.7314
results[['loss', 'val_loss']].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7f46081e36d8>
<Figure size 432x288 with 1 Axes>
|  | Train | Val |
| --- | --- | --- |
| Base Accuracy | 0.999 | 0.721 |
| Base Loss | 0.006 | 2.148 |
| Augmentation Accuracy | 0.825 | 0.731 |
| Augmentation Loss | 0.503 | 0.795 |
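The comparison above can be assembled directly from the final-epoch metrics of the two training runs. Below is a minimal pandas sketch; the numeric values are copied from the logs in this chapter, and the dictionaries `base_final` and `aug_final` are illustrative stand-ins for the last row of each run's `history.history`:

```python
import pandas as pd

# Final-epoch metrics of each run (values copied from the logs above)
base_final = {'accuracy': 0.999, 'loss': 0.006,
              'val_accuracy': 0.721, 'val_loss': 2.148}
aug_final = {'accuracy': 0.825, 'loss': 0.503,
             'val_accuracy': 0.731, 'val_loss': 0.795}

# Rows: metric per run; columns: training vs. validation
summary = pd.DataFrame(
    {'Train': [base_final['accuracy'], base_final['loss'],
               aug_final['accuracy'], aug_final['loss']],
     'Val':   [base_final['val_accuracy'], base_final['val_loss'],
               aug_final['val_accuracy'], aug_final['val_loss']]},
    index=['Base Accuracy', 'Base Loss',
           'Augmentation Accuracy', 'Augmentation Loss'])
print(summary)
```

Laying the two runs out side by side like this makes the trade-off easy to read: augmentation lowers the training accuracy but improves the validation metrics.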

The validation accuracy has improved, and the gap between it and the training accuracy has narrowed, confirming that generalization performance improved.

Data augmentation is an important technique: although the individual operations are simple, it readily contributes to accuracy gains. It has become so standard that virtually every state-of-the-art model applies it, so be sure to use it in your own work.
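To see concretely what one augmentation step does, horizontal flipping (one of the transforms listed at the start of this chapter) can be sketched in plain NumPy. The function `random_horizontal_flip` below is an illustrative stand-in for what `ImageDataGenerator(horizontal_flip=True)` applies to each mini-batch, not the library's actual implementation:

```python
import numpy as np

def random_horizontal_flip(batch, p=0.5, rng=None):
    """Flip each image in an NHWC batch left-right with probability p."""
    rng = np.random.default_rng() if rng is None else rng
    mask = rng.random(len(batch)) < p   # choose which images to flip
    out = batch.copy()
    out[mask] = out[mask][:, :, ::-1, :]  # reverse the width axis
    return out

# Example: with p=1.0 every image in the batch is mirrored
x = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)
flipped = random_horizontal_flip(x, p=1.0)
print(np.array_equal(flipped, x[:, :, ::-1, :]))  # True
```

Because the flip is drawn independently per image and per mini-batch, the network rarely sees the exact same tensor twice, which is the source of the regularization effect measured above.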
