
About the tf.train.batch function in TensorFlow

不言 (original article), 2018-04-24 14:13:43

This article introduces the use of the tf.train.batch function in TensorFlow, shared here for reference.

For the past two days I have been studying the queue-based data reading in TensorFlow. To be honest, it is hard to understand, perhaps because I had no prior experience with it: I started with Theano and wrote everything myself. After reading the documentation and related material, and consulting some junior fellow students back in China, I finally have some feel for it. Put simply, the computation graph reads data from a pipeline. The input pipeline is built with ready-made methods, and so is the reading. Because several threads may read from the same pipeline, thread-management operations (a coordinator and queue runners) are needed so that the reads do not get mixed up. Today I ran a simple experiment in the lab: feed in ordered data and check whether it comes out in order. It did, so here is the code:

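import tensorflow as tf  # TensorFlow 1.x API (in TF 2.x these live under tf.compat.v1)
import numpy as np

def generate_data():
  # 25 samples: ordered labels 0..24 and random 5x5x3 "images"
  num = 25
  label = np.asarray(range(0, num))
  images = np.random.random([num, 5, 5, 3])
  print('label size: {}, image size: {}'.format(label.shape, images.shape))
  return label, images

def get_batch_data():
  label, images = generate_data()
  images = tf.cast(images, tf.float32)
  label = tf.cast(label, tf.int32)
  # Slice the tensors into single (image, label) examples.
  # shuffle=False keeps the original order; num_epochs=1 stops after one pass.
  input_queue = tf.train.slice_input_producer([images, label], num_epochs=1, shuffle=False)
  # Group single examples into batches of 10; one thread keeps them in order.
  image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)
  return image_batch, label_batch

image_batch, label_batch = get_batch_data()
with tf.Session() as sess:
  # num_epochs is tracked in a local variable, which must be initialized.
  sess.run(tf.local_variables_initializer())
  coord = tf.train.Coordinator()
  # Start the background threads that fill the queues.
  threads = tf.train.start_queue_runners(sess, coord)
  i = 0
  try:
    while not coord.should_stop():
      image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
      i += 1
      for j in range(10):
        print(image_batch_v.shape, label_batch_v[j])
  except tf.errors.OutOfRangeError:
    # Raised once the epoch is used up. The last 5 samples cannot fill a
    # batch of 10, so they are dropped (allow_smaller_final_batch=False).
    print("done")
  finally:
    coord.request_stop()
  coord.join(threads)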

Note that slice_input_producer shuffles by default (shuffle=True), so pass shuffle=False if you want the samples in their original order.
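To make the effect of the shuffle flag concrete, here is a small pure-Python model of the order in which sample indices leave the producer. It is only a sketch: producer_order is a hypothetical helper, not a TensorFlow function, and it assumes that with shuffle=True each epoch is an independent permutation of all indices, so every sample still appears exactly once per epoch.

```python
import random

def producer_order(num_examples, num_epochs, shuffle, seed=None):
    """Hypothetical model of the index order produced by slice_input_producer."""
    rng = random.Random(seed)
    order = []
    for _ in range(num_epochs):
        epoch = list(range(num_examples))
        if shuffle:  # shuffle=True is the default in slice_input_producer
            rng.shuffle(epoch)  # a fresh permutation for every epoch
        order.extend(epoch)
    return order

print(producer_order(5, 2, shuffle=False))  # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
print(producer_order(5, 2, shuffle=True, seed=1))
```

With shuffle=False the labels come out as 0, 1, 2, ..., which is what the experiment above showed.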

A few more comments on this code:

1: slice_input_producer has a num_epochs parameter that controls how many epochs of data it produces. Once the specified number of epochs has been produced, reading from the pipeline raises tf.errors.OutOfRangeError. This is useful for controlling the number of training epochs. When num_epochs is set, remember to run tf.local_variables_initializer(), because the epoch counter is a local variable.

2: the output of this method is a single example (one image and its label), so we can process that image with TensorFlow ops such as normalization or cropping. The processed single image is then fed to the batch method, which returns a batch of images for training or testing.
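Point 1 can be illustrated without TensorFlow. The sketch below is a simplified, single-threaded model: OutOfRange stands in for tf.errors.OutOfRangeError, and EpochLimitedBatcher and count_batches are hypothetical helpers that imitate slice_input_producer(num_epochs=..., shuffle=False) followed by tf.train.batch with its default allow_smaller_final_batch=False.

```python
class OutOfRange(Exception):
    """Stand-in for tf.errors.OutOfRangeError in this sketch."""

class EpochLimitedBatcher:
    """Hypothetical model: an ordered, epoch-limited stream cut into full batches."""

    def __init__(self, num_examples, batch_size, num_epochs):
        # Indices 0..num_examples-1, repeated num_epochs times (shuffle=False).
        self._stream = (i for _ in range(num_epochs) for i in range(num_examples))
        self._batch_size = batch_size

    def next_batch(self):
        batch = []
        for i in self._stream:  # resumes where the previous call stopped
            batch.append(i)
            if len(batch) == self._batch_size:
                return batch
        # Input exhausted before a full batch could be formed: drop it and stop.
        raise OutOfRange("epoch limit reached")

def count_batches(num_examples, batch_size, num_epochs):
    """Count training steps until the end-of-data error, like the loop above."""
    batcher = EpochLimitedBatcher(num_examples, batch_size, num_epochs)
    steps = 0
    try:
        while True:
            batcher.next_batch()
            steps += 1
    except OutOfRange:
        return steps

print(count_batches(25, 10, 1))  # 2
print(count_batches(25, 10, 2))  # 5
```

With 25 samples and batch_size=10, num_epochs=1 gives only 2 steps because the last 5 samples never form a full batch, while num_epochs=2 gives 5 steps. Note that nothing separates the epochs: the third batch in the second case mixes the end of epoch 1 with the start of epoch 2.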
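Point 2, processing each single example before batching, can be mirrored in NumPy to show the shapes involved. per_image_standardize reproduces the formula documented for tf.image.per_image_standardization, (x - mean) / max(stddev, 1/sqrt(N)); random_crop and preprocess_then_batch are hypothetical helpers, and the 4x4 crop size is an arbitrary choice for this sketch.

```python
import numpy as np

def per_image_standardize(image):
    """NumPy version of the formula documented for tf.image.per_image_standardization."""
    adjusted_std = max(image.std(), 1.0 / np.sqrt(image.size))
    return (image - image.mean()) / adjusted_std

def random_crop(image, height, width, rng):
    """Hypothetical helper: take a random height x width window from an HWC image."""
    top = rng.integers(0, image.shape[0] - height + 1)
    left = rng.integers(0, image.shape[1] - width + 1)
    return image[top:top + height, left:left + width, :]

def preprocess_then_batch(images, batch_size, rng):
    """Hypothetical helper: process each example on its own, then stack into batches."""
    processed = [per_image_standardize(random_crop(img, 4, 4, rng)) for img in images]
    # Only full batches are kept; a trailing partial batch is dropped.
    num_full = len(processed) // batch_size
    return [np.stack(processed[k * batch_size:(k + 1) * batch_size]) for k in range(num_full)]

rng = np.random.default_rng(0)
images = rng.random((25, 5, 5, 3))  # same shape as generate_data() above
batches = preprocess_then_batch(images, 10, rng)
print(len(batches), batches[0].shape)  # 2 (10, 4, 4, 3)
```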

The difference between tf.train.batch and tf.train.shuffle_batch

tf.train.batch([example, label], batch_size=batch_size, capacity=capacity): [example, label] is a single sample and its label; batch_size is the number of samples in each returned batch; capacity is the maximum number of elements the queue can hold. This function combines the samples into batches in the order they arrive, which is the original order as long as a single thread enqueues them (num_threads=1).
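The meaning of capacity can be seen in a tiny producer/consumer model built on Python's standard queue module: one thread fills a bounded FIFO queue (put blocks while it already holds capacity items) and the consumer takes batch_size items at a time. fifo_batches is a hypothetical helper that imitates tf.train.batch with num_threads=1; it is not how TensorFlow is implemented internally.

```python
import queue
import threading

_CLOSED = object()  # sentinel: the producer has no more examples

def fifo_batches(examples, batch_size, capacity):
    """Hypothetical model of tf.train.batch with a single enqueuing thread."""
    q = queue.Queue(maxsize=capacity)  # put() blocks while the queue is full

    def producer():
        for example in examples:
            q.put(example)
        q.put(_CLOSED)

    thread = threading.Thread(target=producer, daemon=True)
    thread.start()

    batches, batch = [], []
    while True:
        item = q.get()
        if item is _CLOSED:
            break
        batch.append(item)
        if len(batch) == batch_size:
            batches.append(batch)
            batch = []
    thread.join()
    # A trailing partial batch is discarded, like allow_smaller_final_batch=False.
    return batches

print(fifo_batches(range(25), batch_size=10, capacity=64))
```

Because there is a single producer thread, the batches come out in order whatever the capacity; capacity only bounds how far the producer can run ahead of the consumer.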

tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue): the shared parameters have the same meaning as above. The difference is the min_after_dequeue parameter, the minimum number of elements that must remain in the queue after a dequeue. You must make sure it is smaller than capacity, otherwise an error occurs; the TensorFlow documentation recommends capacity = min_after_dequeue + (num_threads + a small safety margin) * batch_size. A batch is only output once enough elements are queued that at least min_after_dequeue remain after taking it, and its samples are drawn at random from the queue. In other words, this function returns a batch of samples in random order rather than in sequential order.
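The role of min_after_dequeue can likewise be modelled in plain Python. shuffle_batches below is a hypothetical, single-threaded sketch: it keeps a buffer of at most capacity elements, draws each batch uniformly at random from that buffer, and only does so while at least min_after_dequeue elements would remain behind; once the input runs out, the remainder is drained in full batches. It rejects capacity < min_after_dequeue + batch_size because in this model the buffer could then never hold enough elements to release a batch.

```python
import random

def shuffle_batches(examples, batch_size, capacity, min_after_dequeue, seed=0):
    """Hypothetical single-threaded model of a shuffling batch queue."""
    if capacity < min_after_dequeue + batch_size:
        raise ValueError("capacity must be at least min_after_dequeue + batch_size")
    rng = random.Random(seed)
    source = iter(examples)
    buffer, batches = [], []
    exhausted = False
    while True:
        # Producer side: top the buffer up to capacity while input remains.
        while not exhausted and len(buffer) < capacity:
            try:
                buffer.append(next(source))
            except StopIteration:
                exhausted = True
        # Consumer side: keep min_after_dequeue elements behind until the input ends.
        threshold = batch_size if exhausted else min_after_dequeue + batch_size
        if len(buffer) < threshold:
            break  # leftover partial batch is dropped
        batch = []
        for _ in range(batch_size):
            # Remove one uniformly random element from the buffer.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            batch.append(buffer.pop())
        batches.append(batch)
    return batches

for batch in shuffle_batches(range(30), batch_size=10, capacity=20, min_after_dequeue=5):
    print(batch)
```

Since the buffer is first filled to capacity, the first batch can only contain samples from the first capacity inputs; a larger min_after_dequeue (with a matching capacity) therefore mixes the data better, at the cost of more memory and a slower start.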

Both functions return a batch of samples and their labels; the difference is that tf.train.batch keeps them in order, while tf.train.shuffle_batch returns them in random order.


Statement:
The content of this article is voluntarily contributed by netizens, and the copyright belongs to the original author. This site does not assume corresponding legal responsibility. If you find any content suspected of plagiarism or infringement, please contact admin@php.cn