>  기사  >  백엔드 개발  >  Tensorflow의 tf.train.batch 함수 정보

Tensorflow의 tf.train.batch 함수 정보

不言
不言원래의
2018-04-24 14:13:435723검색

이 글에서는 주로 Tensorflow의 tf.train.batch 함수 사용법을 소개하고 있으며, 참고용으로 공유해 드립니다. 같이 구경하러 가세요

지난 이틀 동안 텐서플로우에서 데이터 읽기 큐를 살펴봤는데 솔직히 정말 이해하기 어렵습니다. 아마도 나는 처음에는 Theano를 사용했고 모든 것을 직접 작성해 본 적이 없었을 것입니다. 이틀 동안 서류와 관련 정보를 검토한 후 중국에 있는 후배들과도 상담을 했습니다. 오늘은 좀 기분이 좋네요. 간단히 말하면, 계산 그래프는 파이프라인에서 데이터를 읽어옵니다. 입력 파이프라인은 이미 만들어진 방법을 사용하며, 읽기도 마찬가지입니다. 여러 스레드를 사용할 때 파이프에서 데이터를 읽는 것이 지저분해지지 않도록 이때 읽을 때 스레드 관리 관련 작업이 필요합니다. 오늘은 주문한 데이터를 제공하고 주문되었는지 확인하는 간단한 작업을 수행했습니다. 순서대로 코드를 제공했습니다.

import tensorflow as tf
import numpy as np

def generate_data():
  num = 25
  label = np.asarray(range(0, num))
  images = np.random.random([num, 5, 5, 3])
  print('label size :{}, image size {}'.format(label.shape, images.shape))
  return label, images

def get_batch_data():
  label, images = generate_data()
  images = tf.cast(images, tf.float32)
  label = tf.cast(label, tf.int32)
  input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
  image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)
  return image_batch, label_batch

image_batch, label_batch = get_batch_data()
with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess, coord)
  i = 0
  try:
    while not coord.should_stop():
      image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
      i += 1
      for j in range(10):
        print(image_batch_v.shape, label_batch_v[j])
  except tf.errors.OutOfRangeError:
    print("done")
  finally:
    coord.request_stop()
  coord.join(threads)

slice_input_producer를 기억하세요. 방법, 기본값은 섞는 것입니다.

게다가 이 코드에 대해 설명하고 싶습니다.

1: 이 메서드가 지정된 에포크를 실행할 때 슬라이스_input_producer 메서드가 작동하는 에포크 수를 제어하는 ​​'num_epochs' 매개변수가 있습니다. OutOfRangeRrror. 훈련 에포크를 제어하는 ​​데 유용할 것 같습니다.

2: 이 방법의 출력은 하나의 단일 이미지이므로 정규화, 자르기 등과 같은 tensorflow API를 사용하여 이 단일 이미지를 작동할 수 있습니다. 그러면 이 단일 이미지는 배치 방식으로 피드되며, 훈련 또는 테스트용 이미지 배치가 수신됩니다. 예: label], 배치_크기=배치_크기, 용량=용량): [예: 레이블]은 샘플 및 샘플 레이블을 나타냅니다. 샘플이고 샘플 라벨인 배치_크기는 반환된 배치 샘플 세트의 샘플 수입니다. 용량은 대기열의 용량입니다. 이는 주로

tf.train.shuffle_batch([예, 라벨], 배치_크기=batch_size, 용량=용량, min_after_dequeue) 순서로 배치로 결합됩니다. 여기의 매개변수는 위와 동일한 의미를 갖습니다. 차이점은 min_after_dequeue 매개변수입니다. 이 매개변수가 용량 매개변수의 값보다 작은지 확인해야 합니다. 그렇지 않으면 오류가 발생합니다. 이는 대기열의 요소가 그보다 크면 무질서한 배치가 출력된다는 것을 의미합니다. 즉, 이 함수의 출력 결과는 순서대로 정렬된 샘플이 아닌 순서대로 정렬된 샘플 묶음입니다.

위 함수 반환 값은 모두 일괄 처리의 샘플 및 샘플 레이블이지만 하나는 순서대로이고 다른 하나는 무작위입니다

관련 권장 사항:

tensorflow 플래그를 사용하여 명령줄 매개 변수를 정의하는 방법


Tensorflow의 Saver 사용

위 내용은 Tensorflow의 tf.train.batch 함수 정보의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.