EMNIST in PyTorch

Barbara Streisand
2024-12-10 00:33:10

*My post explains EMNIST.

EMNIST() can load the EMNIST dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Required-Type:str). *"byclass", "bymerge", "balanced", "letters", "digits" or "mnist" can be set to it.
  • There is train argument(Optional-Default:True-Type:bool): *Memos:
    • For split="byclass" and split="bymerge", if it's True, train data(697,932 images) is used while if it's False, test data(116,323 images) is used.
    • For split="balanced", if it's True, train data(112,800 images) is used while if it's False, test data(18,800 images) is used.
    • For split="letters", if it's True, train data(124,800 images) is used while if it's False, test data(20,800 images) is used.
    • For split="digits", if it's True, train data(240,000 images) is used while if it's False, test data(40,000 images) is used.
    • For split="mnist", if it's True, train data(60,000 images) is used while if it's False, test data(10,000 images) is used.
  • There is transform argument(Optional-Default:None-Type:callable).
  • There is target_transform argument(Optional-Default:None-Type:callable).
  • There is download argument(Optional-Default:False-Type:bool): *Memos:
    • If it's True, the dataset is downloaded from the internet and extracted(unzipped) to root.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • You can manually download and extract the dataset from here to e.g. data/EMNIST/raw/.
  • There is a known quirk where the images are flipped and rotated 90 degrees anticlockwise by default, so they should be transformed.
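
The train/test sizes listed in the memos can be collected into a small lookup table, which is handy as a sanity check after loading a split. This is only a sketch: `EXPECTED_SIZES` and `expected_size()` are hypothetical names, not part of torchvision, and the numbers are based on the memos above.

```python
# Hypothetical helper, not part of torchvision.
# The sizes are based on the memos above.
EXPECTED_SIZES = {
    "byclass":  {"train": 697_932, "test": 116_323},
    "bymerge":  {"train": 697_932, "test": 116_323},
    "balanced": {"train": 112_800, "test": 18_800},
    "letters":  {"train": 124_800, "test": 20_800},
    "digits":   {"train": 240_000, "test": 40_000},
    "mnist":    {"train": 60_000,  "test": 10_000},
}

def expected_size(split, train=True):
    """Return the expected number of images for a split."""
    return EXPECTED_SIZES[split]["train" if train else "test"]

print(expected_size("byclass", train=True))
print(expected_size("balanced", train=False))
```

After loading a split, comparing `len(train_data)` with `expected_size("byclass", train=True)` is one way to notice an incomplete download.
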

from torchvision.datasets import EMNIST

train_data = EMNIST(
    root="data",
    split="byclass"
)

train_data = EMNIST(
    root="data",
    split="byclass",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

test_data = EMNIST(
    root="data",
    split="byclass",
    train=False
)

len(train_data), len(test_data)
# (697932, 116323)

train_data
# Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.split
# 'byclass'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method EMNIST.download of Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train>

train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 35)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 36)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 6)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 22)

train_data.classes
# ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
#  'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
#  'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
#  'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
#  'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
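
Each label is an index into `classes`, so a label can be turned back into its character. The sketch below rebuilds the same 62-entry list with the standard `string` module so that it runs without the dataset; `BYCLASS_CLASSES` and `label_to_char()` are hypothetical names, and with a loaded dataset `train_data.classes[label]` should give the same result for split="byclass".

```python
import string

# Hypothetical stand-in for train_data.classes with split="byclass":
# digits, then uppercase letters, then lowercase letters (62 classes).
BYCLASS_CLASSES = list(
    string.digits + string.ascii_uppercase + string.ascii_lowercase
)

def label_to_char(label):
    """Map a byclass label index to its character."""
    return BYCLASS_CLASSES[label]

# Labels of the first five train samples shown above.
print([label_to_char(label) for label in (35, 36, 6, 3, 22)])
# ['Z', 'a', '6', '3', 'M']
```
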

# Show the first 5 train and test images in their default orientation.
from torchvision.datasets import EMNIST

train_data = EMNIST(
    root="data",
    split="byclass",
    train=True
)

test_data = EMNIST(
    root="data",
    split="byclass",
    train=False
)

import matplotlib.pyplot as plt

def show_images(data):
    # Plot the first `col` images in one row, titled with their labels.
    plt.figure(figsize=(12, 2))
    col = 5
    for i, (image, label) in enumerate(data, 1):
        plt.subplot(1, col, i)
        plt.title(label)
        plt.imshow(image)
        if i == col:
            break
    plt.show()

show_images(data=train_data)
show_images(data=test_data)

[Image: output of show_images() for the train and test data]

from torchvision.datasets import EMNIST
from torchvision.transforms import v2

train_data = EMNIST(
    root="data",
    split="byclass",
    train=True,
    transform=v2.Compose([
        # Undo the default orientation: flip horizontally,
        # then rotate by exactly 90 degrees.
        v2.RandomHorizontalFlip(p=1.0),
        v2.RandomRotation(degrees=(90, 90))
    ])
)

test_data = EMNIST(
    root="data",
    split="byclass",
    train=False,
    transform=v2.Compose([
        # Same orientation fix as for the train data.
        v2.RandomHorizontalFlip(p=1.0),
        v2.RandomRotation(degrees=(90, 90))
    ])
)

import matplotlib.pyplot as plt

def show_images(data):
    # Plot the first `col` images in one row, titled with their labels.
    plt.figure(figsize=(12, 2))
    col = 5
    for i, (image, label) in enumerate(data, 1):
        plt.subplot(1, col, i)
        plt.title(label)
        plt.imshow(image)
        if i == col:
            break
    plt.show()

show_images(data=train_data)
show_images(data=test_data)
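
Why a horizontal flip followed by a 90-degree anticlockwise rotation fixes the orientation can be checked on a tiny grid: the two steps together amount to a transpose, and transposing twice restores the original. The sketch below uses plain nested lists instead of images and assumes the rotation is anticlockwise, as the memo describes; `hflip()`, `rot90_anticlockwise()`, `transpose()` and `fix_orientation()` are hypothetical helpers written for illustration, not torchvision functions.

```python
def hflip(img):
    """Mirror each row left to right."""
    return [row[::-1] for row in img]

def rot90_anticlockwise(img):
    """Rotate the grid 90 degrees anticlockwise."""
    return [list(row) for row in zip(*img)][::-1]

def transpose(img):
    """Swap rows and columns."""
    return [list(col) for col in zip(*img)]

def fix_orientation(img):
    """Flip horizontally, then rotate 90 degrees anticlockwise."""
    return rot90_anticlockwise(hflip(img))

original = [[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]]

# A grid that is "flipped and rotated" is just the transposed original.
stored = fix_orientation(original)
# Applying the same two steps again brings the original back.
restored = fix_orientation(stored)

print(stored == transpose(original))  # True
print(restored == original)           # True
```
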

[Image: output of show_images() for the transformed train and test data]
