CIFAR dans PyTorch

Susan Sarandon
Susan Sarandonoriginal
2024-12-16 17:15:15873parcourir

Achetez-moi un café☕

*Mon message explique CIFAR-100.

CIFAR100() peut utiliser l'ensemble de données CIFAR-100 comme indiqué ci-dessous :

*Mémos :

  • Le 1er argument est root (Required-Type:str ou pathlib.Path). *Un chemin absolu ou relatif est possible.
  • Le 2ème argument est train(Optional-Default:True-Type:bool). *Si c'est vrai, les données du train (50 000 images) sont utilisées tandis que si c'est faux, les données de test (10 000 images) sont utilisées.
  • Le 3ème argument est transform(Optional-Default:None-Type:callable).
  • Le 4ème argument est target_transform(Optional-Default:None-Type:callable).
  • Le 5ème argument est download(Optional-Default:False-Type:bool) : *Mémos :
    • Si c'est vrai, l'ensemble de données est téléchargé depuis Internet et extrait (décompressé) à la racine.
    • Si c'est Vrai et que l'ensemble de données est déjà téléchargé, il est extrait.
    • Si c'est vrai et que l'ensemble de données est déjà téléchargé et extrait, rien ne se passe.
    • Il devrait être faux si l'ensemble de données est déjà téléchargé et extrait car il est plus rapide.
    • Vous pouvez télécharger et extraire manuellement l'ensemble de données (cifar-100-python.tar.gz) d'ici vers data/cifar-100-python/.
from torchvision.datasets import CIFAR100

train_data = CIFAR100(
    root="data"
)

train_data = CIFAR100(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

test_data = CIFAR100(
    root="data",
    train=False
)

len(train_data), len(test_data)
# (50000, 10000)

train_data
# Dataset CIFAR100
#     Number of datapoints: 50000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method CIFAR10.download of Dataset CIFAR100
#    Number of datapoints: 50000
#    Root location: data
#    Split: Train>

len(train_data.classes), train_data.classes
# (100,
#  ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed',
#   'bicycle', 'bottle', 'bowl', ..., 'wolf', 'woman', 'worm']

train_data[0]
# (<PIL.Image.Image image mode=RGB size=32x32>, 19)

train_data[1]
# (<PIL.Image.Image image mode=RGB size=32x32>, 29)

train_data[2]
# (<PIL.Image.Image image mode=RGB size=32x32>, 0)

train_data[3]
# (<PIL.Image.Image image mode=RGB size=32x32>, 11)

train_data[4]
# (<PIL.Image.Image image mode=RGB size=32x32>, 1)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, (im, lab) in enumerate(data, start=1):
        plt.subplot(2, 5, i)
        plt.title(label=lab)
        plt.imshow(X=im)
        if i == 10:
            break
    plt.tight_layout()
    plt.show()

show_images(data=train_data, main_title="train_data")
show_images(data=test_data, main_title="test_data")

CIFAR in PyTorch

CIFAR in PyTorch

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn