Maison > Article > développement back-end > Comment générer et lire des fichiers Tensorflow TFRecords
Cet article présente principalement la méthode de génération et de lecture des fichiers Tensorflow TFRecords. Il a une certaine valeur de référence. Maintenant, je le partage avec vous. Les amis dans le besoin peuvent s'y référer
TensorFlow fournit le format TFRecords. les données de manière uniforme. En théorie, TFRecords peut stocker n'importe quelle forme de données.
Les données du fichier TFRecords sont stockées au format tf.train.Example Protocol Buffer. Le code suivant donne la définition de tf.train.Example.
message Example { Features features = 1; }; message Features { map<string, Feature> feature = 1; }; message Feature { oneof kind { BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; } };
Ce qui suit présente comment générer et lire des fichiers tfrecords :
Introduisez d'abord la génération de tfrecords , allez directement au code :
from random import shuffle import numpy as np import glob import tensorflow as tf import cv2 import sys import os # 因为我装的是CPU版本的,运行起来会有'warning',解决方法入下,眼不见为净~ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' shuffle_data = True image_path = '/path/to/image/*.jpg' # 取得该路径下所有图片的路径,type(addrs)= list addrs = glob.glob(image_path) # 标签数据的获得具体情况具体分析,type(labels)= list labels = ... # 这里是打乱数据的顺序 if shuffle_data: c = list(zip(addrs, labels)) shuffle(c) addrs, labels = zip(*c) # 按需分割数据集 train_addrs = addrs[0:int(0.7*len(addrs))] train_labels = labels[0:int(0.7*len(labels))] val_addrs = addrs[int(0.7*len(addrs)):int(0.9*len(addrs))] val_labels = labels[int(0.7*len(labels)):int(0.9*len(labels))] test_addrs = addrs[int(0.9*len(addrs)):] test_labels = labels[int(0.9*len(labels)):] # 上面不是获得了image的地址么,下面这个函数就是根据地址获取图片 def load_image(addr): # A function to Load image img = cv2.imread(addr) img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 这里/255是为了将像素值归一化到[0,1] img = img / 255. img = img.astype(np.float32) return img # 将数据转化成对应的属性 def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def _float_feature(value): return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) # 下面这段就开始把数据写入TFRecods文件 train_filename = '/path/to/train.tfrecords' # 输出文件地址 # 创建一个writer来写 TFRecords 文件 writer = tf.python_io.TFRecordWriter(train_filename) for i in range(len(train_addrs)): # 这是写入操作可视化处理 if not i % 1000: print('Train data: {}/{}'.format(i, len(train_addrs))) sys.stdout.flush() # 加载图片 img = load_image(train_addrs[i]) label = train_labels[i] # 创建一个属性(feature) feature = {'train/label': _int64_feature(label), 'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))} # 创建一个 example protocol buffer example = tf.train.Example(features=tf.train.Features(feature=feature)) # 将上面的example protocol buffer写入文件 writer.write(example.SerializeToString()) writer.close() sys.stdout.flush()
Ce qui précède présente uniquement la génération du fichier train.tfrecords, et le reste de la validation et le test permettront de tirer des conclusions. .
Ensuite, nous présenterons la lecture des fichiers tfrecords :
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' data_path = 'train.tfrecords' # tfrecords 文件的地址 with tf.Session() as sess: # 先定义feature,这里要和之前创建的时候保持一致 feature = { 'train/image': tf.FixedLenFeature([], tf.string), 'train/label': tf.FixedLenFeature([], tf.int64) } # 创建一个队列来维护输入文件列表 filename_queue = tf.train.string_input_producer([data_path], num_epochs=1) # 定义一个 reader ,读取下一个 record reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) # 解析读入的一个record features = tf.parse_single_example(serialized_example, features=feature) # 将字符串解析成图像对应的像素组 image = tf.decode_raw(features['train/image'], tf.float32) # 将标签转化成int32 label = tf.cast(features['train/label'], tf.int32) # 这里将图片还原成原来的维度 image = tf.reshape(image, [224, 224, 3]) # 你还可以进行其他一些预处理.... # 这里是创建顺序随机 batches(函数不懂的自行百度) images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, min_after_dequeue=10) # 初始化 init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess.run(init_op) # 启动多线程处理输入数据 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) .... #关闭线程 coord.request_stop() coord.join(threads) sess.close()
Recommandations associées :
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!