>백엔드 개발 >파이썬 튜토리얼 >연구 노트 TF024: TensorFlow는 손으로 쓴 숫자를 인식하기 위해 Softmax 회귀(회귀)를 구현합니다.

연구 노트 TF024: TensorFlow는 손으로 쓴 숫자를 인식하기 위해 Softmax 회귀(회귀)를 구현합니다.

PHP中文网
PHP中文网원래의
2017-07-10 18:13:291393검색

TensorFlow는 Softmax Regression을 구현하여 손으로 쓴 숫자를 인식합니다. MNIST (Mixed National Institute of Standards and Technology 데이터베이스), 단순 머신 비전 데이터 세트, 28X28 픽셀 손으로 쓴 숫자, 회색조 값 정보만, 공백 부분은 0, 손글씨는 색 농도에 따라 [0, 1]에서 가져옴 , 784차원, 2차원 공간정보를 버리고, 대상을 0부터 9까지 10가지 범주로 나눈다. 데이터 로딩, data.read_data_sets, 55,000개 샘플, 테스트 세트 10,000개 샘플, 검증 세트 5,000개 샘플. 샘플 주석 정보, 라벨, 10차원 벡터, 10가지 원-핫 인코딩. 훈련 세트는 모델을 훈련시키고, 검증 세트는 효과를 테스트하며, 테스트 세트는 모델을 평가합니다(정확도, 재현율, F1 점수).

알고리즘 설계, Softmax Regression은 필기 숫자 인식 분류 모델을 학습하고 범주 확률을 추정하며 최대 확률 수를 모델 출력 결과로 사용합니다. 클래스 확률을 결정하기 위해 클래스 기능이 추가됩니다. 모델 학습 및 훈련은 가중치를 조정합니다. Softmax, 다양한 특성 계산 exp 함수, 표준화됨(모든 범주의 출력 확률 값은 1임) y = 소프트맥스(Wx+b).

NumPy는 C, 포트란을 사용하고 openblas 및 mkl 행렬 연산 라이브러리를 호출합니다. TensorFlow의 조밀하고 복잡한 작업은 Python 외부에서 수행됩니다. 계산 그래프를 정의합니다. 계산 작업은 계산된 데이터를 매번 Python 외부에서 실행할 필요가 없습니다.

텐서 흐름을 tf로 가져오고 TensorFlow 라이브러리를 로드합니다. less = tf.InteractiveSession(), InteractiveSession을 생성하고 이를 기본 세션으로 등록합니다. 서로 다른 세션의 데이터와 작업은 서로 독립적입니다. x = tf.placeholder(tf.float32, [None,784]), 입력 데이터를 수신할 플레이스홀더를 생성합니다. 첫 번째 매개변수는 데이터 유형이고 두 번째 매개변수는 텐서 형태 데이터 크기를 나타냅니다. 없음 입력 수에는 제한이 없으며 각 입력은 784차원 벡터입니다.

Tensor는 데이터를 저장하고 사용하면 사라집니다. 변수는 모델 학습 반복에서 지속적이고 오랫동안 존재하며 각 반복에서 업데이트됩니다. Softmax Regression 모델의 Variable 객체 가중치와 편향은 0으로 초기화됩니다. 모델 훈련은 적절한 값을 자동으로 학습합니다. 복잡한 네트워크의 경우 초기화 방법이 중요합니다. w = tf.Variable(tf.zeros([784, 10])), 784개의 특징 차원, 10개의 카테고리. 라벨은 원-핫 인코딩 후의 10차원 벡터입니다.

Softmax 회귀 알고리즘, y = tf.nn.softmax(tf.matmul(x, W) + b). tf.nn에는 수많은 신경망 구성 요소가 포함되어 있습니다. tf.matmul, 행렬 곱셈 함수. TensorFlow는 손실이 정의된 한 자동으로 전방 및 후방 콘텐츠를 구현하며, 훈련은 자동으로 경사하강법을 도출하고 Softmax Regression 모델 매개변수의 자동 학습을 완료합니다.

문제 모델의 분류 정확도를 설명하기 위해 손실 함수를 정의합니다. Loss가 작을수록 모델 분류 결과가 실제 값과 비교되는 값이 작아지고 정확도가 높아집니다. 모델의 초기 매개변수는 모두 0이므로 초기 손실이 발생합니다. 훈련 목표는 손실을 줄이고 전역 최적 또는 국소 최적 솔루션을 찾는 것입니다. 교차 엔트로피, 손실 함수는 분류 문제에 일반적으로 사용됩니다. y 예측 확률 분포, y' 실제 확률 분포(레이블 원-핫 인코딩)는 실제 확률 분포에 대한 모델 예측의 정확도를 결정합니다. cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), 감소_indices=[1])). 자리 표시자를 정의하고 실제 레이블을 입력합니다. tf.reduce_sum은 합계를 계산하고, tf.reduce_mean은 각 배치 데이터 결과의 평균을 계산합니다.

최적화 알고리즘인 확률적 경사하강법 SGD(Stochastic Gradient Descent)를 정의합니다. 계산 그래프를 기반으로 한 자동 도출, Back Propagation 알고리즘을 기반으로 한 학습, 각 라운드마다 매개변수를 반복적으로 업데이트하여 손실을 줄입니다. 각 라운드에서 피드 데이터를 반복하기 위해 캡슐화된 최적화 프로그램이 제공됩니다. TensorFlow는 백그라운드에서 작업을 자동으로 보완하여 역전파 및 경사하강법을 구현합니다. train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy). tf.train.GradientDescentOptimizer를 호출하고, 학습 속도를 0.5로 설정하고, 최적화 목표 교차 엔트로피를 설정하고, 훈련 작업 train_step을 가져옵니다.

tf.global_variables_initializer().run(). TensorFlow 전역 매개변수 초기화 프로그램 tf.golbal_variables_initializer.

batch_xs,batch_ys = mnist.train.next_batch(100). 훈련작업 train_step. 매번 훈련 세트에서 100개의 샘플이 무작위로 선택되어 미니 배치를 형성하고 자리 표시자에 공급되며 train_step 훈련 샘플이 호출됩니다. 훈련에 적은 수의 샘플을 사용하면 확률적 경사하강법을 사용하여 수렴이 더 빨라집니다. 모든 샘플은 매번 학습되므로 많은 양의 계산이 필요하고 로컬 최적에서 벗어나기가 어렵습니다.

corright_prediction = tf.equal(tf.argmax(y,1), tf.argmzx(y_,1)), 모델 정확도를 확인합니다. tf.argmax는 텐서에서 최대값 시퀀스 번호를 찾고, tf.argmax(y,1)는 최대 예측 번호 확률을 찾고, tf.argmax(y_,1)은 샘플의 실제 번호 범주를 찾습니다. tf.equal은 예측된 숫자 범주가 올바른지 확인하고 계산 분류 연산이 올바른지 여부를 반환합니다.

accuracy = tf.reduce_mean(tf.cast(corright_prediction,tf.float32))는 모든 샘플의 예측 정확도를 계산합니다. tf.cast는 right_prediction 출력 값 유형을 변환합니다.

print(accuracy.eval({x: mnist.test.images,y_: mnist.test.labels})). 데이터 특성 테스트, 라벨 입력 평가 프로세스, 모델 테스트 세트의 정확도 계산. Softmax Regression MNIST 데이터 분류 및 인식, 테스트 세트의 평균 정확도는 약 92%입니다.

TensorFlow는 간단한 기계 알고리즘 단계를 구현합니다.
1、알고리즘 공식과 신경망 순방향 계산을 정의합니다.
2、손실을 정의하고, 옵티마이저를 선택하고, 옵티마이저를 지정하여 손실을 최적화합니다.
3、반복적인 훈련 데이터.
4、테스트 세트와 검증 세트의 정확도를 평가합니다.

정의된 공식은 단지 계산 그래프일 뿐입니다. 실행 메소드와 피드 데이터가 호출될 때만 계산이 실행됩니다.

으아악


참고 자료:
"TensorFlow 실습"

상담 비용 지불을 환영합니다(시간당 150위안), my WeChat: qingxingfengzi

위 내용은 연구 노트 TF024: TensorFlow는 손으로 쓴 숫자를 인식하기 위해 Softmax 회귀(회귀)를 구현합니다.의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.