Heim >Backend-Entwicklung >Python-Tutorial >FashionMNIST in PyTorch

FashionMNIST in PyTorch

Patricia Arquette
Patricia ArquetteOriginal
2024-12-11 15:24:16763Durchsuche

Kauf mir einen Kaffee☕

*Mein Beitrag erklärt Fashion-MNIST.

FashionMNIST() kann den Fashion-MNIST-Datensatz wie unten gezeigt verwenden:

*Memos:

  • Das 1. Argument ist root(Required-Type:str oder pathlib.Path). *Ein absoluter oder relativer Pfad ist möglich.
  • Das 2. Argument ist train(Optional-Default:True-Type:bool). *Wenn es wahr ist, werden Trainingsdaten (60.000 Bilder) verwendet, während wenn es falsch ist, Testdaten (10.000 Bilder) verwendet werden.
  • Das dritte Argument ist transform(Optional-Default:None-Type:callable).
  • Das 4. Argument ist target_transform(Optional-Default:None-Type:callable).
  • Das 5. Argument ist download(Optional-Default:False-Type:bool): *Memos:
    • Wenn es wahr ist, wird der Datensatz aus dem Internet heruntergeladen und in das Stammverzeichnis extrahiert (entpackt).
    • Wenn es „True“ ist und der Datensatz bereits heruntergeladen wurde, wird er extrahiert.
    • Wenn es „True“ ist und der Datensatz bereits heruntergeladen und extrahiert wurde, passiert nichts.
    • Es sollte „False“ sein, wenn der Datensatz bereits heruntergeladen und extrahiert wurde, da es schneller ist.
    • Sie können den Datensatz (t10k-images-idx3-ubyte.gz, t10k-labels-idx1-ubyte.gz, train-images-idx3-ubyte.gz und train-labels-idx1-ubyte) manuell herunterladen und extrahieren. gz) von hier nach data/FashionMNIST/raw/.
from torchvision.datasets import FashionMNIST

train_data = FashionMNIST(
    root="data"
)

train_data = FashionMNIST(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

test_data = FashionMNIST(
    root="data",
    train=False
)

len(train_data), len(test_data)
# (60000, 10000)

train_data
# Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method MNIST.download of Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train>

len(train_data.classes)
# 10

train_data.classes
# ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
#  'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 9)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    plt.figure(figsize=(8, 4))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, (image, label) in enumerate(data, 1):
        plt.subplot(2, 5, i)
        plt.tight_layout()
        plt.title(label)
        plt.imshow(image)
        if i == 10:
            break
    plt.show()

show_images(data=train_data, main_title="train_data")
show_images(data=test_data, main_title="test_data")

FashionMNIST in PyTorch

Das obige ist der detaillierte Inhalt vonFashionMNIST in PyTorch. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn