MNIST in PyTorch

Susan Sarandon
Susan SarandonOriginal
2024-12-23 05:04:31511Durchsuche

Kauf mir einen Kaffee☕

*Mein Beitrag erklärt MNIST.

MNIST() kann den MNIST-Datensatz wie unten gezeigt verwenden:

*Memos:

  • Das 1. Argument ist root(Required-Type:str oder pathlib.Path). *Ein absoluter oder relativer Pfad ist möglich.
  • Das 2. Argument ist train(Optional-Default:False-Type:float). *Wenn es wahr ist, werden Trainingsdaten (60.000 Proben) verwendet, während wenn es falsch ist, Testdaten (60.000 Proben) verwendet werden.
  • Das dritte Argument ist transform(Optional-Default:None-Type:callable).
  • Das 4. Argument ist target_transform(Optional-Default:None-Type:callable).
  • Das 5. Argument ist download(Optional-Default:False-Type:bool): *Memos:
    • Wenn es wahr ist, wird der Datensatz aus dem Internet heruntergeladen und in das Stammverzeichnis extrahiert (entpackt).
    • Wenn es „True“ ist und der Datensatz bereits heruntergeladen wurde, wird er extrahiert.
    • Wenn es „True“ ist und der Datensatz bereits heruntergeladen und extrahiert wurde, passiert nichts.
    • Es sollte „False“ sein, wenn der Datensatz bereits heruntergeladen und extrahiert wurde, da es schneller ist.
    • Sie können den Datensatz hier manuell herunterladen und extrahieren, um ihn z. data/MNIST/raw/.
from torchvision.datasets import MNIST

train_data = MNIST(
    root="data"
)

train_data = MNIST(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

train_data
# Dataset MNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method MNIST.download of Dataset MNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train>

train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 5)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 4)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 1)

train_data.classes
# ['0 - zero',
#  '1 - one',
#  '2 - two',
#  '3 - three',
#  '4 - four',
#  '5 - five',
#  '6 - six',
#  '7 - seven',
#  '8 - eight',
#  '9 - nine']
from torchvision.datasets import MNIST

train_data = MNIST(
    root="data"
)

test_data = MNIST(
    root="data",
    train=False
)

import matplotlib.pyplot as plt

def show_images(data):
    plt.figure(figsize=(10, 2))
    col = 4
    for i, (image, label) in enumerate(data, 1):
        plt.subplot(1, col, i)
        plt.title(label)
        plt.imshow(image)
        if i == col:
            break
    plt.show()

show_images(data=train_data)
show_images(data=test_data)

MNIST in PyTorch

Das obige ist der detaillierte Inhalt vonMNIST in PyTorch. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn