Maison >Périphériques technologiques >IA >Créez un classificateur d'apprentissage profond pour les photos de chats et de chiens à l'aide de TensorFlow et Keras.

Créez un classificateur d'apprentissage profond pour les photos de chats et de chiens à l'aide de TensorFlow et Keras.

PHPz
PHPzavant
2023-05-16 09:34:161315parcourir

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Dans cet article, nous utiliserons TensorFlow et Keras pour créer un classificateur d'images capable de différencier les images de chats et de chiens. Pour ce faire, nous utiliserons l'ensemble de données cats_vs_dogs de l'ensemble de données TensorFlow. L'ensemble de données se compose de 25 000 images étiquetées de chats et de chiens, dont 80 % sont utilisées pour l'entraînement, 10 % pour la validation et 10 % pour les tests.

Chargement des données

Nous commençons par charger l'ensemble de données à l'aide des ensembles de données TensorFlow. Divisez l'ensemble de données en ensemble d'entraînement, ensemble de validation et ensemble de test, représentant respectivement 80 %, 10 % et 10 % des données, et définissez une fonction pour afficher quelques exemples d'images dans l'ensemble de données.

<code>import tensorflow as tfimport matplotlib.pyplot as pltimport tensorflow_datasets as tfds# 加载数据(train_data, validation_data, test_data), info = tfds.load('cats_vs_dogs', split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'], with_info=True, as_supervised=True)# 获取图像的标签label_names = info.features['label'].names# 定义一个函数来显示一些样本图像plt.figure(figsize=(10, 10))for i, (image, label) in enumerate(train_data.take(9)):ax = plt.subplot(3, 3, i + 1)plt.imshow(image)plt.title(label_names[label])plt.axis('off')</code>

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Prétraitement des données

Avant d'entraîner le modèle, les données doivent être prétraitées. L'image sera redimensionnée à une taille uniforme de 150 x 150 pixels, les valeurs des pixels seront normalisées entre 0 et 1 et les données seront traitées par lots afin de pouvoir être importées dans le modèle par lots.

<code>IMG_SIZE = 150</code>
<code>def format_image(image, label):image = tf.cast(image, tf.float32) / 255.0# Normalize the pixel valuesimage = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))# Resize to the desired sizereturn image, labelbatch_size = 32train_data = train_data.map(format_image).shuffle(1000).batch(batch_size)validation_data = validation_data.map(format_image).batch(batch_size)test_data = test_data.map(format_image).batch(batch_size)</code>

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Création du modèle

Cet article utilisera le modèle MobileNet V2 pré-entraîné comme modèle de base et y ajoutera une couche de pooling moyenne globale et une couche compacte pour la classification. Cet article va geler les poids du modèle de base afin que seuls les poids de la couche supérieure soient mis à jour pendant l'entraînement.

<code>base_model = tf.keras.applications.MobileNetV2(input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False, weights='imagenet')base_model.trainable = False</code>
<code>global_average_layer = tf.keras.layers.GlobalAveragePooling2D()prediction_layer = tf.keras.layers.Dense(1)model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer])model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy'])</code>

Entraînement du modèle

Cet article entraînera le modèle pendant 3 cycles et le validera sur l'ensemble de validation après chaque cycle. Nous sauvegarderons le modèle après la formation afin de pouvoir l'utiliser lors de futurs tests.

<code>global_average_layer = tf.keras.layers.GlobalAveragePooling2D()prediction_layer = tf.keras.layers.Dense(1)model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer])model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy'])</code>
<code>history = model.fit(train_data,epochs=3,validation_data=validation_data)</code>

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Historique du modèle

Si vous voulez savoir comment fonctionne la couche Mobilenet V2, l'image ci-dessous est le résultat de cette couche.

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Évaluer le modèle

Une fois la formation terminée, le modèle sera évalué sur l'ensemble de test pour voir ses performances sur les nouvelles données.

<code>loaded_model = tf.keras.models.load_model('cats_vs_dogs.h5')test_loss, test_accuracy = loaded_model.evaluate(test_data)</code>
<code>print('Test accuracy:', test_accuracy)</code>

Prédiction

Enfin, cet article utilisera le modèle pour prédire quelques exemples d'images dans l'ensemble de test et afficher les résultats.

<code>for image , _ in test_.take(90) : passpre = loaded_model.predict(image)plt.figure(figsize = (10 , 10))j = Nonefor value in enumerate(pre) : plt.subplot(7,7,value[0]+1)plt.imshow(image[value[0]])plt.xticks([])plt.yticks([])if value[1] > pre.mean() :j = 1color = 'blue' if j == _[value[0]] else 'red'plt.title('dog' , color = color)else : j = 0color = 'blue' if j == _[value[0]] else 'red'plt.title('cat' , color = color)plt.show()</code>

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Fait ! Nous avons créé un classificateur d'images capable de différencier les images de chats et de chiens à l'aide de TensorFlow et Keras. Avec quelques ajustements et ajustements, cette approche peut également être appliquée à d’autres problèmes de classification d’images.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Cet article est reproduit dans:. en cas de violation, veuillez contacter admin@php.cn Supprimer