cari
Rumahpembangunan bahagian belakangTutorial PythonTensorflow分类器项目自定义数据读入的方法介绍(代码示例)

本篇文章给大家带来的内容是关于Tensorflow分类器项目自定义数据读入的方法介绍(代码示例),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。

Tensorflow分类器项目自定义数据读入

在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

IMG_SIZE_X = 30
IMG_SIZE_Y = 30

其次确定你图片的方式目录:

image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
# 你也可以使用相对路径的方式
# image_path =os.path.join(path, "set")

目录下的结构如下:

210901301-5c57a5e09df2b_articlex.png

相应的label.txt如下:

动漫
风景
美女
物语
樱花

接下来是接在labels.txt,如下:

label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))

这里简便起见,直接利用了numpy的loadtxt函数直接加载。

之后便是正式处理图片数据了,注释就写在里面了:

re_load = False
re_build = False
# re_load = True
re_build = True

data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)

count = 0

# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
if not os.path.exists(data_path) or re_load:
    labels = []
    images = []
    print('Handle images')
    # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
    for index, name in enumerate(class_names):
        # 这里是拼接后的子目录path
        classpath = os.path.join(image_path, name)
        # 先判断一下是否是目录
        if not os.path.isdir(classpath):
            continue
        # limit是测试时候用的这里可以去除
        limit = 0
        for image_name in os.listdir(classpath):
            if limit >= max_size:
                break
            # 这里是拼接后的待处理的图片path
            imagepath = os.path.join(classpath, image_name)
            count = count + 1
            limit = limit + 1
            # 利用Image打开图片
            img = Image.open(imagepath)
            # 缩放到你最初确定要处理的图片分辨率大小
            img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))
            # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
            img = img.convert("L")
            # 转为numpy数组
            img = np.array(img)
            # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
            img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))
            # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
            labels.append([index])
            # 添加到images中,最后统一处理
            images.append(img)
            # 循环中一些状态的输出,可以去除
            print("{} class: {} {} limit: {} {}"
                  .format(count, index + 1, class_names[index], limit, imagepath))
    # 最后一次性将images和labels都转换成numpy数组
    npy_data = np.array(images)
    npy_labels = np.array(labels)
    # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
    np.savez(data_path, x=npy_data, y=npy_labels)
    print("Save images by npz")
else:
    # 如果存在序列化号的数据,便直接读取,提高速度
    npy_data = np.load(data_path)["x"]
    npy_labels = np.load(data_path)["y"]
    print("Load images by npz")
image_data = npy_data
labels_data = npy_labels

到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

# 最后一步就是将原始数据分成训练数据和测试数据
train_images, test_images, train_labels, test_labels = \
    train_test_split(image_data, labels_data, test_size=0.2, random_state=6)

这里将相关信息打印的方法也附上:

print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")

print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

之后别忘了归一化哟:

print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0

最后附上读取自定义数据的完整代码:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
re_load = False
re_build = False
# re_load = True
re_build = True
epochs = 50
batch_size = 5
count = 0
max_size = 2000000000
IMG_SIZE_X = 30
IMG_SIZE_Y = 30
np.random.seed(9277)
image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)
label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))
print('Load class names')
if not os.path.exists(data_path) or re_load:
    labels = []
    images = []
    print('Handle images')
    for index, name in enumerate(class_names):
        classpath = os.path.join(image_path, name)
        if not os.path.isdir(classpath):
            continue
        limit = 0
        for image_name in os.listdir(classpath):
            if limit >= max_size:
                break
            imagepath = os.path.join(classpath, image_name)
            count = count + 1
            limit = limit + 1
            img = Image.open(imagepath)
            img = img.resize((30, 30))
            img = img.convert("L")
            img = np.array(img)
            img = np.reshape(img, (1, 30, 30))
            # img = skimage.io.imread(imagepath, as_grey=True)
            # if img.shape[2] != 3:
            #     print("{} shape is {}".format(image_name, img.shape))
            #     continue
            # data = transform.resize(img, (IMG_SIZE_X, IMG_SIZE_Y))
            labels.append([index])
            images.append(img)
            print("{} class: {} {} limit: {} {}"
                  .format(count, index + 1, class_names[index], limit, imagepath))
    npy_data = np.array(images)
    npy_labels = np.array(labels)
    np.savez(data_path, x=npy_data, y=npy_labels)
    print("Save images by npz")
else:
    npy_data = np.load(data_path)["x"]
    npy_labels = np.load(data_path)["y"]
    print("Load images by npz")
image_data = npy_data
labels_data = npy_labels
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")
train_images, test_images, train_labels, test_labels = \
    train_test_split(image_data, labels_data, test_size=0.2, random_state=6)
print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

# 归一化
# 我们将这些值缩小到 0 到 1 之间,然后将其馈送到神经网络模型。为此,将图像组件的数据类型从整数转换为浮点数,然后除以 255。以下是预处理图像的函数:
# 务必要以相同的方式对训练集和测试集进行预处理:
print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0

Atas ialah kandungan terperinci Tensorflow分类器项目自定义数据读入的方法介绍(代码示例). Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan
Artikel ini dikembalikan pada:segmentfault. Jika ada pelanggaran, sila hubungi admin@php.cn Padam
Python vs C: Lengkung pembelajaran dan kemudahan penggunaanPython vs C: Lengkung pembelajaran dan kemudahan penggunaanApr 19, 2025 am 12:20 AM

Python lebih mudah dipelajari dan digunakan, manakala C lebih kuat tetapi kompleks. 1. Sintaks Python adalah ringkas dan sesuai untuk pemula. Penaipan dinamik dan pengurusan memori automatik menjadikannya mudah digunakan, tetapi boleh menyebabkan kesilapan runtime. 2.C menyediakan kawalan peringkat rendah dan ciri-ciri canggih, sesuai untuk aplikasi berprestasi tinggi, tetapi mempunyai ambang pembelajaran yang tinggi dan memerlukan memori manual dan pengurusan keselamatan jenis.

Python vs C: Pengurusan dan Kawalan MemoriPython vs C: Pengurusan dan Kawalan MemoriApr 19, 2025 am 12:17 AM

Python dan C mempunyai perbezaan yang signifikan dalam pengurusan dan kawalan memori. 1. Python menggunakan pengurusan memori automatik, berdasarkan pengiraan rujukan dan pengumpulan sampah, memudahkan kerja pengaturcara. 2.C memerlukan pengurusan memori manual, memberikan lebih banyak kawalan tetapi meningkatkan risiko kerumitan dan kesilapan. Bahasa mana yang harus dipilih harus berdasarkan keperluan projek dan timbunan teknologi pasukan.

Python untuk pengkomputeran saintifik: rupa terperinciPython untuk pengkomputeran saintifik: rupa terperinciApr 19, 2025 am 12:15 AM

Aplikasi Python dalam pengkomputeran saintifik termasuk analisis data, pembelajaran mesin, simulasi berangka dan visualisasi. 1.Numpy menyediakan susunan pelbagai dimensi yang cekap dan fungsi matematik. 2. Scipy memanjangkan fungsi numpy dan menyediakan pengoptimuman dan alat algebra linear. 3. Pandas digunakan untuk pemprosesan dan analisis data. 4.Matplotlib digunakan untuk menghasilkan pelbagai graf dan hasil visual.

Python dan C: Mencari alat yang betulPython dan C: Mencari alat yang betulApr 19, 2025 am 12:04 AM

Sama ada untuk memilih Python atau C bergantung kepada keperluan projek: 1) Python sesuai untuk pembangunan pesat, sains data, dan skrip kerana sintaks ringkas dan perpustakaan yang kaya; 2) C sesuai untuk senario yang memerlukan prestasi tinggi dan kawalan asas, seperti pengaturcaraan sistem dan pembangunan permainan, kerana kompilasi dan pengurusan memori manualnya.

Python untuk sains data dan pembelajaran mesinPython untuk sains data dan pembelajaran mesinApr 19, 2025 am 12:02 AM

Python digunakan secara meluas dalam sains data dan pembelajaran mesin, terutamanya bergantung pada kesederhanaannya dan ekosistem perpustakaan yang kuat. 1) PANDAS digunakan untuk pemprosesan dan analisis data, 2) Numpy menyediakan pengiraan berangka yang cekap, dan 3) SCIKIT-Learn digunakan untuk pembinaan dan pengoptimuman model pembelajaran mesin, perpustakaan ini menjadikan Python alat yang ideal untuk sains data dan pembelajaran mesin.

Pembelajaran Python: Adakah 2 jam kajian harian mencukupi?Pembelajaran Python: Adakah 2 jam kajian harian mencukupi?Apr 18, 2025 am 12:22 AM

Adakah cukup untuk belajar Python selama dua jam sehari? Ia bergantung pada matlamat dan kaedah pembelajaran anda. 1) Membangunkan pelan pembelajaran yang jelas, 2) Pilih sumber dan kaedah pembelajaran yang sesuai, 3) mengamalkan dan mengkaji semula dan menyatukan amalan tangan dan mengkaji semula dan menyatukan, dan anda secara beransur-ansur boleh menguasai pengetahuan asas dan fungsi lanjutan Python dalam tempoh ini.

Python untuk Pembangunan Web: Aplikasi UtamaPython untuk Pembangunan Web: Aplikasi UtamaApr 18, 2025 am 12:20 AM

Aplikasi utama Python dalam pembangunan web termasuk penggunaan kerangka Django dan Flask, pembangunan API, analisis data dan visualisasi, pembelajaran mesin dan AI, dan pengoptimuman prestasi. 1. Rangka Kerja Django dan Flask: Django sesuai untuk perkembangan pesat aplikasi kompleks, dan Flask sesuai untuk projek kecil atau sangat disesuaikan. 2. Pembangunan API: Gunakan Flask atau DjangorestFramework untuk membina Restfulapi. 3. Analisis Data dan Visualisasi: Gunakan Python untuk memproses data dan memaparkannya melalui antara muka web. 4. Pembelajaran Mesin dan AI: Python digunakan untuk membina aplikasi web pintar. 5. Pengoptimuman Prestasi: Dioptimumkan melalui pengaturcaraan, caching dan kod tak segerak

Python vs C: Meneroka Prestasi dan KecekapanPython vs C: Meneroka Prestasi dan KecekapanApr 18, 2025 am 12:20 AM

Python lebih baik daripada C dalam kecekapan pembangunan, tetapi C lebih tinggi dalam prestasi pelaksanaan. 1. Sintaks ringkas Python dan perpustakaan yang kaya meningkatkan kecekapan pembangunan. 2. Ciri-ciri jenis kompilasi dan kawalan perkakasan meningkatkan prestasi pelaksanaan. Apabila membuat pilihan, anda perlu menimbang kelajuan pembangunan dan kecekapan pelaksanaan berdasarkan keperluan projek.

See all articles

Alat AI Hot

Undresser.AI Undress

Undresser.AI Undress

Apl berkuasa AI untuk mencipta foto bogel yang realistik

AI Clothes Remover

AI Clothes Remover

Alat AI dalam talian untuk mengeluarkan pakaian daripada foto.

Undress AI Tool

Undress AI Tool

Gambar buka pakaian secara percuma

Clothoff.io

Clothoff.io

Penyingkiran pakaian AI

AI Hentai Generator

AI Hentai Generator

Menjana ai hentai secara percuma.

Alat panas

MantisBT

MantisBT

Mantis ialah alat pengesan kecacatan berasaskan web yang mudah digunakan yang direka untuk membantu dalam pengesanan kecacatan produk. Ia memerlukan PHP, MySQL dan pelayan web. Lihat perkhidmatan demo dan pengehosan kami.

SublimeText3 Linux versi baharu

SublimeText3 Linux versi baharu

SublimeText3 Linux versi terkini

SublimeText3 versi Cina

SublimeText3 versi Cina

Versi Cina, sangat mudah digunakan

Muat turun versi mac editor Atom

Muat turun versi mac editor Atom

Editor sumber terbuka yang paling popular

SublimeText3 versi Mac

SublimeText3 versi Mac

Perisian penyuntingan kod peringkat Tuhan (SublimeText3)