FashionMNIST dalam PyTorch

Patricia Arquette
Patricia Arquetteasal
2024-12-11 15:24:16757semak imbas

Beli Saya Kopi☕

*Siaran saya menerangkan Fashion-MNIST.

FashionMNIST() boleh menggunakan dataset Fashion-MNIST seperti yang ditunjukkan di bawah:

*Memo:

  • Argumen pertama ialah root(Required-Type:str or pathlib.Path). *Laluan mutlak atau relatif boleh dilakukan.
  • Argumen ke-2 ialah train(Pilihan-Lalai:True-Type:bool). *Jika Benar, data kereta api(60,000 imej) digunakan manakala jika Salah, data ujian(10,000 imej) digunakan.
  • Argumen ke-3 ialah transform(Optional-Default:None-Type:callable).
  • Argumen ke-4 ialah target_transform(Optional-Default:None-Type:callable).
  • Argumen ke-5 ialah muat turun(Optional-Default:False-Type:bool): *Memo:
    • Jika Benar, set data dimuat turun dari internet dan diekstrak (dibuka zip) ke akar.
    • Jika ia Benar dan set data sudah dimuat turun, ia akan diekstrak.
    • Jika ia Benar dan set data sudah dimuat turun dan diekstrak, tiada apa yang berlaku.
    • Ia sepatutnya Palsu jika set data sudah dimuat turun dan diekstrak kerana ia lebih pantas.
    • Anda boleh memuat turun dan mengekstrak set data secara manual (t10k-images-idx3-ubyte.gz, t10k-labels-idx1-ubyte.gz, train-images-idx3-ubyte.gz dan train-labels-idx1-ubyte. gz) dari sini ke data/FashionMNIST/raw/.
from torchvision.datasets import FashionMNIST

train_data = FashionMNIST(
    root="data"
)

train_data = FashionMNIST(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

test_data = FashionMNIST(
    root="data",
    train=False
)

len(train_data), len(test_data)
# (60000, 10000)

train_data
# Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method MNIST.download of Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train>

len(train_data.classes)
# 10

train_data.classes
# ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
#  'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 9)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    plt.figure(figsize=(8, 4))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, (image, label) in enumerate(data, 1):
        plt.subplot(2, 5, i)
        plt.tight_layout()
        plt.title(label)
        plt.imshow(image)
        if i == 10:
            break
    plt.show()

show_images(data=train_data, main_title="train_data")
show_images(data=test_data, main_title="test_data")

FashionMNIST in PyTorch

Atas ialah kandungan terperinci FashionMNIST dalam PyTorch. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan:
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn