Maison >développement back-end >Tutoriel Python >Introduction à la méthode de lecture des données personnalisées pour le projet de classificateur Tensorflow (exemple de code)
Cet article vous présente une introduction à la méthode de lecture des données personnalisées pour le projet de classificateur Tensorflow (exemple de code). Il a une certaine valeur de référence. Les amis dans le besoin peuvent s'y référer, j'espère que cela vous aidera. toi.
Lecture des données personnalisées du projet de classificateur Tensorflow
Après avoir tapé le code du projet de classificateur selon la démo sur le site officiel de Tensorflow, l'opération a réussi Pas mal non plus. . Mais en fin de compte, je dois encore entraîner mes propres données, j'ai donc essayé de me préparer au chargement de données personnalisées. Cependant, fashion_mnist.load_data() n'est apparu que dans la démo sans processus de lecture détaillé. Ensuite, j'ai trouvé quelques informations et expliqué le. processus de lecture. Enregistré ici.
Tout d'abord, permettez-moi de mentionner les modules qui doivent être utilisés :
import os import keras import matplotlib.pyplot as plt from PIL import Image from keras.preprocessing.image import ImageDataGenerator from sklearn.model_selection import train_test_split
Projet de classificateur d'images, déterminez d'abord quelle sera la résolution de l'image que vous souhaitez traiter, la exemple voici 30 pixels :
IMG_SIZE_X = 30 IMG_SIZE_Y = 30
Déterminez ensuite le répertoire de vos photos :
image_path = r'D:\Projects\ImageClassifier\data\set' path = ".\data" # 你也可以使用相对路径的方式 # image_path =os.path.join(path, "set")
La structure sous le répertoire est la suivante :
Le label.txt correspondant est le suivant :
动漫 风景 美女 物语 樱花
Ensuite, il est connecté au labels.txt, comme suit :
label_name = "labels.txt" label_path = os.path.join(path, label_name) class_names = np.loadtxt(label_path, type(""))
Pour le bien de simplicité ici, la fonction loadtxt de numpy est directement utilisée pour le charger directement.
Après cela, les données d'image sont officiellement traitées et les commentaires sont écrits à l'intérieur :
re_load = False re_build = False # re_load = True re_build = True data_name = "data.npz" data_path = os.path.join(path, data_name) model_name = "model.h5" model_path = os.path.join(path, model_name) count = 0 # 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。 if not os.path.exists(data_path) or re_load: labels = [] images = [] print('Handle images') # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取 for index, name in enumerate(class_names): # 这里是拼接后的子目录path classpath = os.path.join(image_path, name) # 先判断一下是否是目录 if not os.path.isdir(classpath): continue # limit是测试时候用的这里可以去除 limit = 0 for image_name in os.listdir(classpath): if limit >= max_size: break # 这里是拼接后的待处理的图片path imagepath = os.path.join(classpath, image_name) count = count + 1 limit = limit + 1 # 利用Image打开图片 img = Image.open(imagepath) # 缩放到你最初确定要处理的图片分辨率大小 img = img.resize((IMG_SIZE_X, IMG_SIZE_Y)) # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量 img = img.convert("L") # 转为numpy数组 img = np.array(img) # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放 img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y)) # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引 labels.append([index]) # 添加到images中,最后统一处理 images.append(img) # 循环中一些状态的输出,可以去除 print("{} class: {} {} limit: {} {}" .format(count, index + 1, class_names[index], limit, imagepath)) # 最后一次性将images和labels都转换成numpy数组 npy_data = np.array(images) npy_labels = np.array(labels) # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储 np.savez(data_path, x=npy_data, y=npy_labels) print("Save images by npz") else: # 如果存在序列化号的数据,便直接读取,提高速度 npy_data = np.load(data_path)["x"] npy_labels = np.load(data_path)["y"] print("Load images by npz") image_data = npy_data labels_data = npy_labels
À ce stade, le traitement et le prétraitement des données originales sont terminés. la dernière étape est nécessaire et les résultats renvoyés par fashion_mnist.load_data()
dans la démo sont les mêmes. Le code est le suivant :
# 最后一步就是将原始数据分成训练数据和测试数据 train_images, test_images, train_labels, test_labels = \ train_test_split(image_data, labels_data, test_size=0.2, random_state=6)
La méthode d'impression des informations pertinentes est également jointe ici :
print("_________________________________________________________________") print("%-28s %-s" % ("Name", "Shape")) print("=================================================================") print("%-28s %-s" % ("Image Data", image_data.shape)) print("%-28s %-s" % ("Labels Data", labels_data.shape)) print("=================================================================") print('Split train and test data,p=%') print("_________________________________________________________________") print("%-28s %-s" % ("Name", "Shape")) print("=================================================================") print("%-28s %-s" % ("Train Images", train_images.shape)) print("%-28s %-s" % ("Test Images", test_images.shape)) print("%-28s %-s" % ("Train Labels", train_labels.shape)) print("%-28s %-s" % ("Test Labels", test_labels.shape)) print("=================================================================")
N'oubliez pas de normaliser après cela :
print("Normalize images") train_images = train_images / 255.0 test_images = test_images / 255.0
Enfin, le code complet de lecture des données personnalisées est joint :
import os import keras import matplotlib.pyplot as plt from PIL import Image from keras.layers import * from keras.models import * from keras.optimizers import Adam from keras.preprocessing.image import ImageDataGenerator from sklearn.model_selection import train_test_split os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 支持中文 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 re_load = False re_build = False # re_load = True re_build = True epochs = 50 batch_size = 5 count = 0 max_size = 2000000000 IMG_SIZE_X = 30 IMG_SIZE_Y = 30 np.random.seed(9277) image_path = r'D:\Projects\ImageClassifier\data\set' path = ".\data" data_name = "data.npz" data_path = os.path.join(path, data_name) model_name = "model.h5" model_path = os.path.join(path, model_name) label_name = "labels.txt" label_path = os.path.join(path, label_name) class_names = np.loadtxt(label_path, type("")) print('Load class names') if not os.path.exists(data_path) or re_load: labels = [] images = [] print('Handle images') for index, name in enumerate(class_names): classpath = os.path.join(image_path, name) if not os.path.isdir(classpath): continue limit = 0 for image_name in os.listdir(classpath): if limit >= max_size: break imagepath = os.path.join(classpath, image_name) count = count + 1 limit = limit + 1 img = Image.open(imagepath) img = img.resize((30, 30)) img = img.convert("L") img = np.array(img) img = np.reshape(img, (1, 30, 30)) # img = skimage.io.imread(imagepath, as_grey=True) # if img.shape[2] != 3: # print("{} shape is {}".format(image_name, img.shape)) # continue # data = transform.resize(img, (IMG_SIZE_X, IMG_SIZE_Y)) labels.append([index]) images.append(img) print("{} class: {} {} limit: {} {}" .format(count, index + 1, class_names[index], limit, imagepath)) npy_data = np.array(images) npy_labels = np.array(labels) np.savez(data_path, x=npy_data, y=npy_labels) print("Save images by npz") else: npy_data = np.load(data_path)["x"] npy_labels = np.load(data_path)["y"] print("Load images by npz") image_data = npy_data labels_data = npy_labels print("_________________________________________________________________") print("%-28s %-s" % ("Name", "Shape")) print("=================================================================") print("%-28s %-s" % ("Image Data", image_data.shape)) print("%-28s %-s" % ("Labels Data", labels_data.shape)) print("=================================================================") train_images, test_images, train_labels, test_labels = \ train_test_split(image_data, labels_data, test_size=0.2, random_state=6) print('Split train and test data,p=%') print("_________________________________________________________________") print("%-28s %-s" % ("Name", "Shape")) print("=================================================================") print("%-28s %-s" % ("Train Images", train_images.shape)) print("%-28s %-s" % ("Test Images", test_images.shape)) print("%-28s %-s" % ("Train Labels", train_labels.shape)) print("%-28s %-s" % ("Test Labels", test_labels.shape)) print("=================================================================") # 归一化 # 我们将这些值缩小到 0 到 1 之间,然后将其馈送到神经网络模型。为此,将图像组件的数据类型从整数转换为浮点数,然后除以 255。以下是预处理图像的函数: # 务必要以相同的方式对训练集和测试集进行预处理: print("Normalize images") train_images = train_images / 255.0 test_images = test_images / 255.0
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!