How to use neural networks for classification in Python?
Classifying large amounts of data by hand is a very time-consuming and difficult task. A neural network can automate the job once it has been trained on labeled examples. Python is a good choice for this because it offers many mature, easy-to-use neural network libraries. This article explains how to use a neural network for classification in Python.
Before building a classifier, let's briefly review what a neural network is. A neural network is a computational model that learns the relationship between inputs and outputs from a large number of examples, and then uses what it has learned to predict properties of new, unseen data. Neural networks perform very well on classification problems and can be applied to many kinds of data, such as images, emails, and audio.
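To make "learning the relationship between inputs and outputs" more concrete, here is a small plain-Python sketch of the computation inside one fully connected layer, the building block used later in this article. It is an illustration, not TensorFlow code: the `dense_forward` helper and the weight and bias values are made up for the example. Training is the process of adjusting such weights and biases until the network's outputs match the known labels.

```python
# Plain-Python sketch (not TensorFlow) of what one fully connected ("dense")
# layer computes: each output neuron takes a weighted sum of all inputs,
# adds a bias, and applies an activation function.
# The weights and biases are made-up illustration values; a trained network
# learns them from many labeled input/output examples.


def relu(z: float) -> float:
    """ReLU activation: keep positive values, replace negatives with 0."""
    return max(0.0, z)


def dense_forward(inputs, weights, biases, activation=relu):
    """Return activation(sum_i weights[j][i] * inputs[i] + biases[j]) for each neuron j."""
    outputs = []
    for neuron_weights, bias in zip(weights, biases):
        z = sum(w * x for w, x in zip(neuron_weights, inputs)) + bias
        outputs.append(activation(z))
    return outputs


x = [0.5, -1.0, 2.0]          # three input features
W = [[0.2, 0.4, 0.1],         # weights of output neuron 0
     [-0.3, 0.0, 0.5]]        # weights of output neuron 1
b = [0.0, -0.1]               # one bias per output neuron

print(dense_forward(x, W, b))
```

Neuron 0's weighted sum is slightly negative, so ReLU outputs 0 for it, while neuron 1 passes its positive sum (about 0.75) through unchanged.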
Classification is one of the main applications of neural networks: the goal is to assign each piece of data to one of several categories. In image recognition, for example, a neural network can sort images into categories such as cats, dogs, or cars; the image is the input and the predicted category (class label) is the output. Classification is usually solved with supervised learning, meaning the network is trained on examples whose correct categories are already known.
Python offers many neural network libraries, such as TensorFlow, Keras, and PyTorch. In this article we use TensorFlow, an open-source machine learning library developed by the Google Brain team, together with its built-in Keras API (tensorflow.keras). TensorFlow is a very popular framework used in a large number of machine learning projects, and the Keras API makes it relatively easy to learn and use.
If you have not installed TensorFlow yet, open a terminal or command prompt and run the following command:
pip install tensorflow
After the installation completes, you can import and use the TensorFlow library.
Data preparation is a critical step in any classification task, because the data must be converted into a numerical format that the neural network can process. In this article we use the popular MNIST dataset, a collection of grayscale images of handwritten digits, each labeled with the digit (0-9) it shows. MNIST ships with TensorFlow's Keras API, so you can load it directly with the following code:
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
This code loads the training data into x_train and y_train, which are used to train the neural network, and the test data into x_test and y_test, which are used to evaluate it. x_train and x_test contain the digit images, while y_train and y_test contain the corresponding labels (the integers 0-9).
Next, let's inspect the shape of each array to learn more about the dataset:
print('x_train shape:', x_train.shape)
print('y_train shape:', y_train.shape)
print('x_test shape:', x_test.shape)
print('y_test shape:', y_test.shape)
The output looks like this:
x_train shape: (60000, 28, 28)
y_train shape: (60000,)
x_test shape: (10000, 28, 28)
y_test shape: (10000,)
This shows that the training set contains 60,000 handwritten-digit images, each 28 x 28 pixels, with one label per image. The test set contains 10,000 images of the same size.
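The arrays returned by mnist.load_data() hold 8-bit pixel values from 0 to 255. A common extra preparation step, not used in the code of this article, is to scale them into the range 0 to 1, which usually makes training with optimizers such as Adam more stable. The sketch below shows the idea in plain Python on a tiny hypothetical 2 x 3 "image" (the `normalize_image` helper is illustrative only).

```python
# Sketch: scale 8-bit grayscale pixel values (0-255) into the range [0.0, 1.0].
# The tiny 2x3 "image" is a hypothetical stand-in for one 28x28 MNIST digit.

MAX_PIXEL = 255  # largest value an 8-bit (uint8) pixel can hold


def normalize_image(image):
    """Return a copy of a 2-D list of pixel values, each divided by 255."""
    return [[pixel / MAX_PIXEL for pixel in row] for row in image]


tiny_image = [
    [0, 128, 255],
    [64, 0, 32],
]

scaled = normalize_image(tiny_image)
print(scaled)
```

With the NumPy arrays from mnist.load_data(), the same step is a single vectorized expression: x_train = x_train / 255.0 and x_test = x_test / 255.0. If you apply it, apply the same scaling to any image you later pass to model.predict.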
With the data prepared, the next step is to choose a neural network model. We will use a very simple model: a Flatten layer followed by two fully connected (Dense) layers. The first Dense layer has 128 neurons and the second has 10 neurons, one for each digit class. The code is as follows:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
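model = Sequential()
model.add(Flatten(input_shape=(28, 28)))     # 28x28 image -> vector of 784 values
model.add(Dense(128, activation='relu'))     # hidden layer with 128 neurons
model.add(Dense(10, activation='softmax'))   # output layer: one probability per digit 0-9
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',  # labels are plain integers 0-9
              metrics=['accuracy'])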
Here we first create a Sequential model and add a Flatten layer, which flattens each 28x28 image into a one-dimensional array of 784 values. Next we add a fully connected layer with 128 neurons that uses ReLU as its activation function. Finally, we add a fully connected layer with 10 neurons that uses the softmax activation function, so the model outputs a probability for each of the 10 digits. The model is compiled with the Adam optimizer and the sparse categorical cross-entropy loss, which suits this task because the labels are plain integers rather than one-hot vectors.
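The numbers behind this model can be checked by hand. The plain-Python sketch below is an illustration, not TensorFlow's implementation: it counts the trainable parameters of the two Dense layers (the total should agree with what model.summary() reports), then shows how softmax turns 10 raw scores into probabilities and how sparse categorical cross-entropy scores them against an integer label. The logits are made-up values and the helper names are hypothetical.

```python
import math

# Plain-Python sketch mirroring what the Keras layers compute (not TensorFlow code).

IMAGE_PIXELS = 28 * 28  # Flatten turns each 28x28 image into 784 values


def dense_param_count(n_inputs: int, n_units: int) -> int:
    """Weights (one per input per unit) plus one bias per unit."""
    return n_inputs * n_units + n_units


def softmax(logits):
    """Turn raw scores into positive probabilities that sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(probs, label: int) -> float:
    """Loss for one example whose true class is the integer `label`."""
    return -math.log(probs[label])


hidden_params = dense_param_count(IMAGE_PIXELS, 128)  # first Dense layer
output_params = dense_param_count(128, 10)            # second Dense layer
print("Trainable parameters:", hidden_params + output_params)

# Hypothetical raw scores (logits) for the 10 digit classes of one image.
logits = [1.0, 0.5, 3.0, 0.1, -1.0, 0.0, 0.2, 0.3, -0.5, 0.4]
probs = softmax(logits)
print("Most likely digit:", probs.index(max(probs)))
print("Loss if the true digit is 2:", sparse_categorical_crossentropy(probs, 2))
print("Loss if the true digit is 7:", sparse_categorical_crossentropy(probs, 7))
```

The loss is small when the model assigns a high probability to the correct digit and grows as that probability shrinks, which is what training pushes down.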
With the data and the model ready, we can train the model on the training data using the following code:
history = model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test))
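This code trains the model for 10 epochs (an epoch is one full pass over the 60,000 training images, processed in mini-batches of 32 images by default) and reports validation metrics on the test set after each epoch. Strictly speaking, the test set should stay unseen until the final evaluation, so holding out part of the training data instead (for example with validation_split=0.1) is the more careful choice. After training completes, we can evaluate the model with the following code: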
test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)
The output shows the model's loss and accuracy on the test set.
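test_acc is a single number between 0 and 1. With integer labels and the sparse categorical cross-entropy loss, Keras resolves metrics=['accuracy'] to sparse categorical accuracy: the share of examples whose highest-probability class equals the true label. The plain-Python sketch below computes that quantity by hand on made-up predictions for 3 classes (a real MNIST model has 10); the helper names are illustrative.

```python
# Sketch of what the "accuracy" metric reports: the fraction of examples whose
# highest-probability class equals the true integer label.
# The probability rows and labels are made-up values for illustration.


def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(prob_rows, labels):
    """Share of rows whose argmax matches the corresponding label."""
    correct = sum(1 for probs, label in zip(prob_rows, labels) if argmax(probs) == label)
    return correct / len(labels)


predictions = [
    [0.1, 0.7, 0.2],  # predicts class 1
    [0.8, 0.1, 0.1],  # predicts class 0
    [0.3, 0.3, 0.4],  # predicts class 2
]
true_labels = [1, 0, 1]

print("Accuracy:", accuracy(predictions, true_labels))
```

Here two of the three predictions match their labels, so the accuracy is two thirds.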
Once the model is trained and evaluated, we can use it to predict labels for new data. The following code predicts the label of a single test image:
import numpy as np
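image_index = 7777  # indices start at 0; any index from 0 to 9999 is valid
img = x_test[image_index]
img = np.expand_dims(img, axis=0)  # add a batch dimension: (28, 28) -> (1, 28, 28)
predictions = model.predict(img)   # array of shape (1, 10): one probability per digit
print(predictions)
print("Predicted label:", np.argmax(predictions))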
In the author's run, the image was predicted to be the digit 2. To check whether a prediction is correct, compare it with the true label stored in y_test[image_index].
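predictions has shape (1, 10): one row, because expand_dims added a batch dimension of size 1, holding the softmax probability of each digit 0-9. np.argmax without an axis argument returns the index of the largest value in the flattened array, which for a single row is the predicted digit. The plain-Python sketch below reads such a row and also lists the three most likely digits, which can help when the model is unsure; the probabilities and the `top_k` helper are hypothetical.

```python
# Sketch of how to read model.predict output for a single image.
# The output has one row of 10 probabilities per input image, so for a batch
# of one we take row 0. The probabilities below are hypothetical.


def top_k(probs, k=3):
    """Return the k most likely (digit, probability) pairs, best first."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


predictions = [[0.01, 0.02, 0.90, 0.03, 0.00, 0.01, 0.00, 0.02, 0.01, 0.00]]

row = predictions[0]  # the batch contains a single image
predicted_digit = max(range(len(row)), key=lambda i: row[i])
print("Predicted label:", predicted_digit)
for digit, p in top_k(row):
    print(f"digit {digit}: {p:.2%}")
```

To classify several images at once, pass a batch such as x_test[:5] to model.predict and take the argmax of each row.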
In this article, we showed how to use a neural network for classification in Python: we used TensorFlow to build and train a simple model, and the MNIST dataset to train, evaluate, and make predictions. The same approach can be adapted to other image classification tasks by changing the input shape, the number of output classes, and the layers of the model as needed. Neural networks are an effective way to classify large amounts of data, and libraries such as TensorFlow make it fast to develop and train such models.