PyTorch의 EMNIST

Barbara Streisand
Barbara Streisand원래의
2024-12-10 00:33:10872검색

커피 한잔 사주세요😄

*내 게시물은 EMNIST를 설명합니다.

EMNIST()는 아래와 같이 EMNIST 데이터세트를 사용할 수 있습니다.

*메모:

  • 첫 번째 인수는 루트(필수 유형:str 또는 pathlib.Path)입니다. *절대경로, 상대경로 모두 가능합니다.
  • 두 번째 인수는 분할(필수 유형:str)입니다. *"byclass", "bymerge", "balanced", "letters", "digits" 또는 "mnist"를 설정할 수 있습니다.
  • 열차 인수가 있습니다(Optional-Default:False-Type:float): *메모:
    • split="byclass", Split="byclass"의 경우, True이면 train 데이터(697,932개 이미지)가 사용되고, False이면 테스트 데이터(116,323개 이미지)가 사용됩니다.
    • split="balanced"의 경우 True이면 학습 데이터(112,800개 이미지)를 사용하고, False이면 테스트 데이터(188,00개 이미지)를 사용합니다.
    • split="letters"의 경우, True이면 학습 데이터(124,800개 이미지)가 사용되고, False이면 테스트 데이터(20,800개 이미지)가 사용됩니다.
    • split="digits"의 경우, True이면 학습 데이터(240,000개 이미지)가 사용되고, False이면 테스트 데이터(40,000개 이미지)가 사용됩니다.
    • split="mnist"의 경우 True이면 학습 데이터(60,000개 이미지)를 사용하고, False이면 테스트 데이터(10,000개 이미지)를 사용합니다.
  • Transform 인수(Optional-Default:None-Type:callable)가 있습니다.
  • target_transform 인수(Optional-Default:None-Type:callable)가 있습니다.
  • 다운로드 인수가 있습니다(Optional-Default:False-Type:bool): *메모:
    • True인 경우 데이터 세트가 인터넷에서 다운로드되어 루트에 추출(압축 해제)됩니다.
    • True이고 데이터세트가 이미 다운로드된 경우 추출됩니다.
    • True이고 데이터 세트가 이미 다운로드되어 추출된 경우 아무 일도 일어나지 않습니다.
    • 데이터 세트가 이미 다운로드되어 추출된 경우 더 빠르므로 False여야 합니다.
    • 여기에서 데이터 세트를 수동으로 다운로드하고 추출할 수 있습니다. 데이터/EMNIST/raw/.
  • 기본적으로 이미지가 반시계 방향으로 90도 회전 및 반전되는 버그가 있으므로 변환해야 합니다.
from torchvision.datasets import EMNIST

train_data = EMNIST(
    root="data",
    split="byclass"
)

train_data = EMNIST(
    root="data",
    split="byclass",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

test_data = EMNIST(
    root="data",
    split="byclass",
    train=False
)

len(train_data), len(test_data)
# 697932 116323

train_data
# Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.split
# 'byclass'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method EMNIST.download of Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train>

train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 35)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 36)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 6)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 22)

train_data.classes
# ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
#  'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
#  'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
#  'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
#  'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
from torchvision.datasets import EMNIST

train_data = EMNIST(
    root="data",
    split="byclass",
    train=True
)

test_data = EMNIST(
    root="data",
    split="byclass",
    train=False
)

import matplotlib.pyplot as plt

def show_images(data):
    plt.figure(figsize=(12, 2))
    col = 5
    for i, (image, label) in enumerate(data, 1):
        plt.subplot(1, col, i)
        plt.title(label)
        plt.imshow(image)
        if i == col:
            break
    plt.show()

show_images(data=train_data)
show_images(data=test_data)

EMNIST in PyTorch

from torchvision.datasets import EMNIST
from torchvision.transforms import v2

train_data = EMNIST(
    root="data",
    split="byclass",
    train=True,
    transform=v2.Compose([
        v2.RandomHorizontalFlip(p=1.0),
        v2.RandomRotation(degrees=(90, 90))
    ])
)

test_data = EMNIST(
    root="data",
    split="byclass",
    train=False,
    transform=v2.Compose([
        v2.RandomHorizontalFlip(p=1.0),
        v2.RandomRotation(degrees=(90, 90))
    ])
)

import matplotlib.pyplot as plt

def show_images(data):
    plt.figure(figsize=(12, 2))
    col = 5
    for i, (image, label) in enumerate(data, 1):
        plt.subplot(1, col, i)
        plt.title(label)
        plt.imshow(image)
        if i == col:
            break
    plt.show()

show_images(data=train_data)
show_images(data=test_data)

EMNIST in PyTorch

위 내용은 PyTorch의 EMNIST의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.