CIFARin PyTorch

Linda Hamilton
Linda Hamiltonオリジナル
2024-12-16 12:57:15461ブラウズ

コーヒー買ってきて☕

*私の投稿では CIFAR-10 について説明しています。

CIFAR10() は、以下に示すように CIFAR-10 データセットを使用できます。

*メモ:

  • 最初の引数は root(Required-Type:str または pathlib.Path) です。 *絶対パスまたは相対パスが可能です。
  • 2 番目の引数は train(Optional-Default:True-Type:bool) です。 ※Trueの場合はトレーニングデータ(50,000枚)、Falseの場合はテストデータ(10,000枚)を使用します。
  • 3 番目の引数は、transform(Optional-Default:None-Type:callable) です。
  • 4 番目の引数は target_transform(Optional-Default:None-Type:callable) です。
  • 5 番目の引数は download(Optional-Default:False-Type:bool) です。 *メモ:
    • True の場合、データセットはインターネットからダウンロードされ、ルートに抽出 (解凍) されます。
    • これが True で、データセットが既にダウンロードされている場合、データセットは抽出されます。
    • これが True で、データセットがすでにダウンロードされ抽出されている場合は、何も起こりません。
    • データセットがすでにダウンロードされ抽出されている場合は、その方が高速であるため、False にする必要があります。
    • データセット (cifar-10-python.tar.gz) をここから data/cifar-10-batches-py/ に手動でダウンロードして抽出できます。
from torchvision.datasets import CIFAR10

train_data = CIFAR10(
    root="data"
)

train_data = CIFAR10(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

test_data = CIFAR10(
    root="data",
    train=False
)

len(train_data), len(test_data)
# (50000, 10000)

train_data
# Dataset CIFAR10
#     Number of datapoints: 50000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# bound method CIFAR10.download of Dataset CIFAR10
#     Number of datapoints: 50000
#     Root location: data
#     Split: Train>

len(train_data.classes)
# 10

train_data.classes
# ['airplane', 'automobile', 'bird', 'cat', 'deer',
#  'dog', 'frog', 'horse', 'ship', 'truck']

train_data[0]
# (<PIL.Image.Image image mode=RGB size=32x32>, 6)

train_data[1]
# (<PIL.Image.Image image mode=RGB size=32x32>, 9)

train_data[2]
# (<PIL.Image.Image image mode=RGB size=32x32>, 9)

train_data[3]
# (<PIL.Image.Image image mode=RGB size=32x32>, 4)

train_data[4]
# (<PIL.Image.Image image mode=RGB size=32x32>, 1)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, (im, lab) in enumerate(data, start=1):
        plt.subplot(2, 5, i)
        plt.title(label=lab)
        plt.imshow(X=im)
        if i == 10:
            break
    plt.tight_layout()
    plt.show()

show_images(data=train_data, main_title="train_data")
show_images(data=test_data, main_title="test_data")

CIFARin PyTorch

CIFARin PyTorch

以上がCIFARin PyTorchの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。