MNIST 데이터 세트는 손으로 쓴 숫자로 구성되며 60,000개의 훈련 샘플과 10,000개의 테스트 샘플을 포함합니다. 각 샘플은 0부터 9까지의 숫자를 나타내는 28x28픽셀 회색조 이미지입니다.
CNN(Convolutional Neural Network)은 딥러닝에서 이미지 분류에 사용되는 모델입니다. 컨볼루셔널 레이어와 풀링 레이어를 통해 이미지 특징을 추출하고 분류를 위해 완전 연결 레이어를 사용합니다.
아래에서는 Python과 TensorFlow를 사용하여 MNIST 데이터 세트를 분류하는 간단한 CNN 모델을 구현하는 방법을 소개하겠습니다.
먼저 필요한 라이브러리와 MNIST 데이터 세트를 가져와야 합니다.
import tensorflow as tf from tensorflow.keras.datasets import mnist # 加载MNIST数据集 (x_train, y_train), (x_test, y_test) = mnist.load_data()
다음으로 이미지 데이터를 정규화하고 레이블 데이터를 원-핫 인코딩 형식으로 변환해야 합니다.
# 归一化图像数据 x_train = x_train / 255.0 x_test = x_test / 255.0 # 将标签数据转换为独热编码格式 y_train = tf.keras.utils.to_categorical(y_train, num_classes=10) y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
그런 다음 CNN 모델. 이 모델에는 두 개의 컨벌루션 레이어와 두 개의 풀링 레이어, 그리고 완전 연결 레이어가 포함되어 있습니다. 분류를 위해 마지막 계층에서는 ReLU 활성화 함수와 Softmax 활성화 함수를 사용합니다. 코드는 다음과 같습니다.
model = tf.keras.models.Sequential([ # 第一个卷积层 tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)), tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), # 第二个卷积层 tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'), tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), # 将特征图展平 tf.keras.layers.Flatten(), # 全连接层 tf.keras.layers.Dense(units=128, activation='relu'), # 输出层 tf.keras.layers.Dense(units=10, activation='softmax') ])
다음으로 모델을 컴파일하고 손실 함수, 최적화 도구 및 평가 지표를 지정해야 합니다.
# 编译模型 model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
마지막으로 모델을 훈련하고 테스트합니다.
# 训练模型 model.fit(x_train.reshape(-1, 28, 28, 1), y_train, epochs=5, batch_size=32) # 测试模型 score = model.evaluate(x_test.reshape(-1, 28, 28, 1), y_test, verbose=0) print('Test loss:', score[0]) print('Test accuracy:', score[1])
전체 코드를 실행한 후, 모델을 볼 수 있습니다. 테스트 정확도는 약 99%입니다.
요약하자면, 컨볼루션 신경망을 사용하여 MNIST 데이터세트를 분류하는 단계는 다음과 같습니다.
1 MNIST 데이터세트를 로드하고 정규화 및 원-핫 인코딩을 포함한 전처리를 수행합니다.
2 . 컨벌루션 레이어, 풀링 레이어 및 완전 연결 레이어를 포함하는 CNN 모델을 정의하고
3 모델을 컴파일하고 옵티마이저 및 평가 지수를 지정합니다.
4. 모델을 만들고 테스트 세트에서 테스트해 보세요.
위는 구체적인 상황에 따라 수정 및 최적화가 가능한 간단한 예시입니다.
위 내용은 컨벌루션 신경망을 사용하여 손으로 쓴 숫자 인식의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!