MNIST dans PyTorch

Susan Sarandon
Susan Sarandonoriginal
2024-12-23 05:04:31510parcourir

Achetez-moi un café☕

*Mon message explique MNIST.

MNIST() peut utiliser l'ensemble de données MNIST comme indiqué ci-dessous :

*Mémos :

  • Le 1er argument est root (Required-Type:str ou pathlib.Path). *Un chemin absolu ou relatif est possible.
  • Le 2ème argument est train(Optional-Default:False-Type:float). *Si c'est vrai, les données d'entraînement (60 000 échantillons) sont utilisées tandis que si c'est faux, les données de test (60 000 échantillons) sont utilisées.
  • Le 3ème argument est transform(Optional-Default:None-Type:callable).
  • Le 4ème argument est target_transform(Optional-Default:None-Type:callable).
  • Le 5ème argument est download(Optional-Default:False-Type:bool) : *Mémos :
    • Si c'est vrai, l'ensemble de données est téléchargé depuis Internet et extrait (décompressé) vers root.
    • Si c'est Vrai et que l'ensemble de données est déjà téléchargé, il est extrait.
    • Si c'est vrai et que l'ensemble de données est déjà téléchargé et extrait, rien ne se passe.
    • Il devrait être faux si l'ensemble de données est déjà téléchargé et extrait car il est plus rapide.
    • Vous pouvez télécharger et extraire manuellement l'ensemble de données à partir d'ici, par exemple. data/MNIST/brut/.
from torchvision.datasets import MNIST

train_data = MNIST(
    root="data"
)

train_data = MNIST(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

train_data
# Dataset MNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method MNIST.download of Dataset MNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train>

train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 5)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 4)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 1)

train_data.classes
# ['0 - zero',
#  '1 - one',
#  '2 - two',
#  '3 - three',
#  '4 - four',
#  '5 - five',
#  '6 - six',
#  '7 - seven',
#  '8 - eight',
#  '9 - nine']
from torchvision.datasets import MNIST

train_data = MNIST(
    root="data"
)

test_data = MNIST(
    root="data",
    train=False
)

import matplotlib.pyplot as plt

def show_images(data):
    plt.figure(figsize=(10, 2))
    col = 4
    for i, (image, label) in enumerate(data, 1):
        plt.subplot(1, col, i)
        plt.title(label)
        plt.imshow(image)
        if i == col:
            break
    plt.show()

show_images(data=train_data)
show_images(data=test_data)

MNIST in PyTorch

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn