Heim  >  Artikel  >  Technologie-Peripheriegeräte  >  Erstellen Sie mit TensorFlow und Keras einen Deep-Learning-Klassifikator für Katzen- und Hundebilder

Erstellen Sie mit TensorFlow und Keras einen Deep-Learning-Klassifikator für Katzen- und Hundebilder

PHPz
PHPznach vorne
2023-05-16 09:34:161238Durchsuche

Erstellen Sie mit TensorFlow und Keras einen Deep-Learning-Klassifikator für Katzen- und Hundebilder

In diesem Artikel werden wir TensorFlow und Keras verwenden, um einen Bildklassifikator zu erstellen, der zwischen Bildern von Katzen und Hunden unterscheiden kann. Dazu verwenden wir den Datensatz cats_vs_dogs aus dem TensorFlow-Datensatz. Der Datensatz besteht aus 25.000 beschrifteten Bildern von Katzen und Hunden, von denen 80 % für das Training, 10 % für die Validierung und 10 % für Tests verwendet werden.

Daten laden

Wir beginnen mit dem Laden des Datensatzes mithilfe von TensorFlow-Datensätzen. Teilen Sie den Datensatz in Trainingssatz, Validierungssatz und Testsatz auf, die jeweils 80 %, 10 % und 10 % der Daten ausmachen, und definieren Sie eine Funktion zum Anzeigen einiger Beispielbilder im Datensatz.

<code>import tensorflow as tfimport matplotlib.pyplot as pltimport tensorflow_datasets as tfds# 加载数据(train_data, validation_data, test_data), info = tfds.load('cats_vs_dogs', split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'], with_info=True, as_supervised=True)# 获取图像的标签label_names = info.features['label'].names# 定义一个函数来显示一些样本图像plt.figure(figsize=(10, 10))for i, (image, label) in enumerate(train_data.take(9)):ax = plt.subplot(3, 3, i + 1)plt.imshow(image)plt.title(label_names[label])plt.axis('off')</code>

Erstellen Sie mit TensorFlow und Keras einen Deep-Learning-Klassifikator für Katzen- und Hundebilder

Daten vorverarbeiten

Vor dem Training des Modells müssen die Daten vorverarbeitet werden. Die Bildgröße wird auf eine einheitliche Größe von 150 x 150 Pixel geändert, die Pixelwerte werden zwischen 0 und 1 normalisiert und die Daten werden stapelweise verarbeitet, sodass sie stapelweise in das Modell importiert werden können.

<code>IMG_SIZE = 150</code>
<code>def format_image(image, label):image = tf.cast(image, tf.float32) / 255.0# Normalize the pixel valuesimage = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))# Resize to the desired sizereturn image, labelbatch_size = 32train_data = train_data.map(format_image).shuffle(1000).batch(batch_size)validation_data = validation_data.map(format_image).batch(batch_size)test_data = test_data.map(format_image).batch(batch_size)</code>

Erstellen Sie mit TensorFlow und Keras einen Deep-Learning-Klassifikator für Katzen- und Hundebilder

Aufbau des Modells

In diesem Artikel wird das vorab trainierte MobileNet V2-Modell als Basismodell verwendet und zur Klassifizierung eine globale Durchschnitts-Pooling-Schicht und eine kompakte Schicht hinzugefügt. In diesem Artikel werden die Gewichte des Basismodells eingefroren, sodass während des Trainings nur die Gewichte der obersten Schicht aktualisiert werden.

<code>base_model = tf.keras.applications.MobileNetV2(input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False, weights='imagenet')base_model.trainable = False</code>
<code>global_average_layer = tf.keras.layers.GlobalAveragePooling2D()prediction_layer = tf.keras.layers.Dense(1)model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer])model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy'])</code>

Training des Modells

In diesem Artikel wird das Modell für 3 Zyklen trainiert und nach jedem Zyklus auf dem Validierungssatz validiert. Wir werden das Modell nach dem Training speichern, damit wir es in zukünftigen Tests verwenden können.

<code>global_average_layer = tf.keras.layers.GlobalAveragePooling2D()prediction_layer = tf.keras.layers.Dense(1)model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer])model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy'])</code>
<code>history = model.fit(train_data,epochs=3,validation_data=validation_data)</code>

Erstellen Sie mit TensorFlow und Keras einen Deep-Learning-Klassifikator für Katzen- und Hundebilder

Modellhistorie

Wenn Sie wissen möchten, wie die Mobilenet V2-Schicht funktioniert, ist das Bild unten ein Ergebnis dieser Schicht.

Erstellen Sie mit TensorFlow und Keras einen Deep-Learning-Klassifikator für Katzen- und Hundebilder

Bewerten Sie das Modell

Nach Abschluss des Trainings wird das Modell am Testsatz bewertet, um zu sehen, wie es bei neuen Daten abschneidet.

<code>loaded_model = tf.keras.models.load_model('cats_vs_dogs.h5')test_loss, test_accuracy = loaded_model.evaluate(test_data)</code>
<code>print('Test accuracy:', test_accuracy)</code>

Vorhersage

Abschließend wird in diesem Artikel das Modell verwendet, um einige Beispielbilder im Testsatz vorherzusagen und die Ergebnisse anzuzeigen.

<code>for image , _ in test_.take(90) : passpre = loaded_model.predict(image)plt.figure(figsize = (10 , 10))j = Nonefor value in enumerate(pre) : plt.subplot(7,7,value[0]+1)plt.imshow(image[value[0]])plt.xticks([])plt.yticks([])if value[1] > pre.mean() :j = 1color = 'blue' if j == _[value[0]] else 'red'plt.title('dog' , color = color)else : j = 0color = 'blue' if j == _[value[0]] else 'red'plt.title('cat' , color = color)plt.show()</code>

Erstellen Sie mit TensorFlow und Keras einen Deep-Learning-Klassifikator für Katzen- und Hundebilder

Fertig! Mithilfe von TensorFlow und Keras haben wir einen Bildklassifizierer erstellt, der zwischen Bildern von Katzen und Hunden unterscheiden kann. Mit einigen Anpassungen und Feinabstimmungen kann dieser Ansatz auch auf andere Bildklassifizierungsprobleme angewendet werden.

Das obige ist der detaillierte Inhalt vonErstellen Sie mit TensorFlow und Keras einen Deep-Learning-Klassifikator für Katzen- und Hundebilder. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:51cto.com. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen