ホームページ  >  記事  >  テクノロジー周辺機器  >  TensorFlow と Keras を使用して猫と犬の写真の深層学習分類器を作成する

TensorFlow と Keras を使用して猫と犬の写真の深層学習分類器を作成する

PHPz
PHPz転載
2023-05-16 09:34:161238ブラウズ

TensorFlow と Keras を使用して猫と犬の写真の深層学習分類器を作成する

この記事では、TensorFlow と Keras を使用して、猫と犬の画像を区別できる画像分類器を作成します。これを行うには、TensorFlow データセットの cat_vs_dogs データセットを使用します。データセットは 25,000 枚の猫と犬のラベル付き画像で構成されており、そのうち 80% がトレーニングに、10% が検証に、10% がテストに使用されます。

データのロード

TensorFlow データセットを使用してデータセットをロードすることから始めます。データ セットをトレーニング セット、検証セット、テスト セットに分割し、それぞれデータの 80%、10%、10% を占め、データ セット内のいくつかのサンプル画像を表示する関数を定義します。

<code>import tensorflow as tfimport matplotlib.pyplot as pltimport tensorflow_datasets as tfds# 加载数据(train_data, validation_data, test_data), info = tfds.load('cats_vs_dogs', split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'], with_info=True, as_supervised=True)# 获取图像的标签label_names = info.features['label'].names# 定义一个函数来显示一些样本图像plt.figure(figsize=(10, 10))for i, (image, label) in enumerate(train_data.take(9)):ax = plt.subplot(3, 3, i + 1)plt.imshow(image)plt.title(label_names[label])plt.axis('off')</code>

TensorFlow と Keras を使用して猫と犬の写真の深層学習分類器を作成する

データの前処理

モデルをトレーニングする前に、データを前処理する必要があります。画像は 150x150 ピクセルの均一なサイズにサイズ変更され、ピクセル値は 0 と 1 の間で正規化され、データはバッチでモデルにインポートできるようにバッチ処理されます。

<code>IMG_SIZE = 150</code>
<code>def format_image(image, label):image = tf.cast(image, tf.float32) / 255.0# Normalize the pixel valuesimage = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))# Resize to the desired sizereturn image, labelbatch_size = 32train_data = train_data.map(format_image).shuffle(1000).batch(batch_size)validation_data = validation_data.map(format_image).batch(batch_size)test_data = test_data.map(format_image).batch(batch_size)</code>

TensorFlow と Keras を使用して猫と犬の写真の深層学習分類器を作成する

##モデルの構築

この記事では、事前トレーニングされた MobileNet V2 モデルを基本として使用しますそして、分類のためにグローバル平均プーリング層とコンパクト層を追加します。この記事では、ベース モデルの重みをフリーズして、トレーニング中に最上層の重みのみが更新されるようにします。

<code>base_model = tf.keras.applications.MobileNetV2(input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False, weights='imagenet')base_model.trainable = False</code>
<code>global_average_layer = tf.keras.layers.GlobalAveragePooling2D()prediction_layer = tf.keras.layers.Dense(1)model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer])model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy'])</code>

トレーニング モデル

この記事では、3 サイクルにわたってモデルをトレーニングし、各サイクルの認証後に検証セットでテストします。トレーニング後にモデルを保存して、将来のテストで使用できるようにします。

<code>global_average_layer = tf.keras.layers.GlobalAveragePooling2D()prediction_layer = tf.keras.layers.Dense(1)model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer])model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy'])</code>
<code>history = model.fit(train_data,epochs=3,validation_data=validation_data)</code>

TensorFlow と Keras を使用して猫と犬の写真の深層学習分類器を作成する

モデル履歴

Mobilenet V2 レイヤーがどのように機能するかを知りたい場合は、次の図がこのレイヤーの結果です。

TensorFlow と Keras を使用して猫と犬の写真の深層学習分類器を作成する

#モデルの評価トレーニングが完了すると、モデルはテスト セットで評価されます。新しいデータでどのように動作するかを確認してください。

<code>loaded_model = tf.keras.models.load_model('cats_vs_dogs.h5')test_loss, test_accuracy = loaded_model.evaluate(test_data)</code>
<code>print('Test accuracy:', test_accuracy)</code>

予測最後に、この記事では、モデルを使用してテスト セット内のいくつかのサンプル画像を予測し、結果を表示します。

<code>for image , _ in test_.take(90) : passpre = loaded_model.predict(image)plt.figure(figsize = (10 , 10))j = Nonefor value in enumerate(pre) : plt.subplot(7,7,value[0]+1)plt.imshow(image[value[0]])plt.xticks([])plt.yticks([])if value[1] > pre.mean() :j = 1color = 'blue' if j == _[value[0]] else 'red'plt.title('dog' , color = color)else : j = 0color = 'blue' if j == _[value[0]] else 'red'plt.title('cat' , color = color)plt.show()</code>

TensorFlow と Keras を使用して猫と犬の写真の深層学習分類器を作成する 完了しました! TensorFlow と Keras を使用して、猫と犬の画像を区別できる画像分類器を作成しました。いくつかの調整と微調整を行うことで、このアプローチは他の画像分類問題にも適用できます。

以上がTensorFlow と Keras を使用して猫と犬の写真の深層学習分類器を作成するの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事は51cto.comで複製されています。侵害がある場合は、admin@php.cn までご連絡ください。