Maison  >  Article  >  développement back-end  >  À propos de la fonction tf.train.batch dans Tensorflow

À propos de la fonction tf.train.batch dans Tensorflow

不言
不言original
2018-04-24 14:13:435667parcourir

Cet article présente principalement l'utilisation de la fonction tf.train.batch dans Tensorflow. Maintenant, je le partage avec vous et le donne comme référence. Venez jeter un œil ensemble

Cela fait deux jours que je regarde la file d'attente de lecture des données dans tensorflow. Pour être honnête, c'est vraiment difficile à comprendre. Peut-être que je n'ai aucune expérience dans ce domaine auparavant. J'ai utilisé Theano au début et j'ai tout écrit moi-même. Après ces deux jours passés à examiner des documents et des informations connexes, j'ai également consulté de jeunes camarades en Chine. J'ai un petit sentiment aujourd'hui. Pour faire simple, le graphique de calcul lit les données d'un pipeline d'entrée utilise une méthode prête à l'emploi, et la même est utilisée pour la lecture. Afin de garantir que la lecture des données d'un canal ne sera pas compliquée lors de l'utilisation de plusieurs threads, des opérations liées à la gestion des threads sont requises lors de la lecture à ce stade. Aujourd'hui, j'ai fait une opération simple au laboratoire, qui consistait à donner une donnée ordonnée et à voir si elle était ordonnée. Il s'est avéré que c'était le cas, j'ai donc donné le code directement :

import tensorflow as tf
import numpy as np

def generate_data():
  num = 25
  label = np.asarray(range(0, num))
  images = np.random.random([num, 5, 5, 3])
  print('label size :{}, image size {}'.format(label.shape, images.shape))
  return label, images

def get_batch_data():
  label, images = generate_data()
  images = tf.cast(images, tf.float32)
  label = tf.cast(label, tf.int32)
  input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
  image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)
  return image_batch, label_batch

image_batch, label_batch = get_batch_data()
with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess, coord)
  i = 0
  try:
    while not coord.should_stop():
      image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
      i += 1
      for j in range(10):
        print(image_batch_v.shape, label_batch_v[j])
  except tf.errors.OutOfRangeError:
    print("done")
  finally:
    coord.request_stop()
  coord.join(threads)
.

Rappelez-vous la méthode slice_input_producer, qui nécessite une lecture aléatoire par défaut.

De plus, je voudrais commenter ce code

1 : il y a un paramètre 'num_epochs' dans slice_input_producer, qui contrôle le nombre d'époques pendant lesquelles la méthode slice_input_producer fonctionnerait. La méthode exécute les époques spécifiées, elle signalerait l'OutOfRangeRrror. Je pense que cela serait utile pour notre contrôle des époques d'entraînement

2 : la sortie de cette méthode est une seule image, nous pourrions l'utiliser. image unique avec l'API tensorflow, telle que la normalisation, les recadrages, etc., alors cette image unique est transmise à la méthode par lots, un lot d'images pour la formation ou les tests serait reçu.

tf La différence entre .train.batch et tf.train.shuffle_batch

tf.train.batch([exemple, étiquette], batch_size=batch_size, capacité=capacité) : [exemple, étiquette ] représente un échantillon et une étiquette d'échantillon, qui peut être un échantillon et une étiquette d'échantillon, et batch_size est le nombre d'échantillons dans un ensemble d'échantillons par lots renvoyé. La capacité est la capacité de la file d'attente. Ceci est principalement combiné en un lot

tf.train.shuffle_batch([example, label], batch_size=batch_size,capacity=capacity, min_after_dequeue) dans l'ordre. Les paramètres ici ont la même signification que ci-dessus. La différence est le paramètre min_after_dequeue. Vous devez vous assurer que ce paramètre est inférieur à la valeur du paramètre de capacité, sinon une erreur se produira. Cela signifie que lorsque les éléments de la file d'attente sont plus grands que lui, un lot désordonné sera généré. En d’autres termes, le résultat de sortie de cette fonction est un lot d’échantillons disposés dans le désordre et non disposés dans l’ordre.

Les valeurs de retour de la fonction ci-dessus sont tous les échantillons et étiquettes d'échantillons d'un lot, mais l'un est en ordre et l'autre est aléatoire

Recommandations associées :

Tensorflow utilise des indicateurs pour définir les paramètres de ligne de commande

Utilisation de l'économiseur de Tensorflow

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn