Blumen in PyTorch

Patricia Arquette
Patricia ArquetteOriginal
2024-12-16 16:40:11429Durchsuche

Kauf mir einen Kaffee☕

*Mein Beitrag erklärt Oxford 102 Flower.

Flowers102() kann den Oxford 102 Flower-Datensatz wie unten gezeigt verwenden:

*Memos:

  • Das 1. Argument ist root(Required-Type:str oder pathlib.Path). *Ein absoluter oder relativer Pfad ist möglich.
  • Das 2. Argument ist geteilt (Optional-Default:"train"-Type:str). *Es können „train“ (1.020 Bilder), „val“ (1.020 Bilder) oder „test“ (6.149 Bilder) eingestellt werden.
  • Das dritte Argument ist transform(Optional-Default:None-Type:callable).
  • Das 4. Argument ist target_transform(Optional-Default:None-Type:callable).
  • Das 5. Argument ist download(Optional-Default:False-Type:bool): *Memos:
    • Wenn es wahr ist, wird der Datensatz aus dem Internet heruntergeladen und in das Stammverzeichnis extrahiert (entpackt).
    • Wenn es „True“ ist und der Datensatz bereits heruntergeladen wurde, wird er extrahiert.
    • Wenn es „True“ ist und der Datensatz bereits heruntergeladen und extrahiert wurde, passiert nichts.
    • Es sollte False sein, wenn der Datensatz bereits heruntergeladen und extrahiert wurde, da es schneller ist.
    • Sie können den Datensatz (102flowers.tgz mit imagelabels.mat und setid.matff) manuell herunterladen und von hier nach data/flowers-102/ extrahieren.
  • Über die Beschriftung der Kategorien (Klassen) für die Zug- und Validierungsbildindizes: 0 ist 0–9, 1 ist 10–19, 2 ist 20–29, 3 ist 30–39, 4 ist 40–49, 5 ist 50~59, 6 ist 60~69, 7 ist 70~79, 8 ist 80~89, 9 ist 90~99 usw.
  • Über die Bezeichnung der Kategorien (Klassen) für die Testbildindizes: 0 ist 0–19, 1 ist 20–59, 2 ist 60–79, 3 ist 80–115, 4 ist 116–160, 5 ist 161~185, 6 ist 186~205, 7 ist 206~270, 8 ist 271~296, 9 ist 297~321 usw.
from torchvision.datasets import Flowers102

train_data = Flowers102(
    root="data"
)

train_data = Flowers102(
    root="data",
    split="train",
    transform=None,
    target_transform=None,
    download=False
)

val_data = Flowers102(
    root="data",
    split="val"
)

test_data = Flowers102(
    root="data",
    split="test"
)

len(train_data), len(val_data), len(test_data)
# (1020, 1020, 6149)

train_data
# Dataset Flowers102
#     Number of datapoints: 1020
#     Root location: data
#     split=train

train_data.root
# 'data'

train_data._split
# 'train'

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method Flowers102.download of Dataset Flowers102
#     Number of datapoints: 1020
#     Root location: data
#     split=train>

len(set(train_data._labels)), train_data._labels
# (102,
#  [0, 0, 0, ..., 1, ..., 2, ..., 3, ..., 4, ..., 5, ..., 6, ..., 101])

train_data[0]
# (<PIL.Image.Image image mode=RGB size=754x500>, 0)

train_data[1]
# (<PIL.Image.Image image mode=RGB size=624x500>, 0)

train_data[2]
# (<PIL.Image.Image image mode=RGB size=667x500>, 0)

train_data[10]
# (<PIL.Image.Image image mode=RGB size=500x682>, 1)

train_data[20]
# (<PIL.Image.Image image mode=RGB size=667x500>, 2)

val_data[0]
# (<PIL.Image.Image image mode=RGB size=606x500>, 0)

val_data[1]
# (<PIL.Image.Image image mode=RGB size=667x500>, 0)

val_data[2]
# (<PIL.Image.Image image mode=RGB size=500x628>, 0)

val_data[10]
# (<PIL.Image.Image image mode=RGB size=500x766>, 1)

val_data[20]
# (<PIL.Image.Image image mode=RGB size=624x500>, 2)

test_data[0]
# (<PIL.Image.Image image mode=RGB size=523x500>, 0)

test_data[1]
# (<PIL.Image.Image image mode=RGB size=666x500>, 0)

test_data[2]
# (<PIL.Image.Image image mode=RGB size=595x500>, 0)

test_data[20]
# (<PIL.Image.Image image mode=RGB size=500x578>, 1)

test_data[60]
# (<PIL.Image.Image image mode=RGB size=500x625>, 2)

import matplotlib.pyplot as plt

def show_images(data, ims, main_title=None):
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, j in enumerate(ims, start=1):
        plt.subplot(2, 5, i)
        im, lab = data[j]
        plt.imshow(X=im)
        plt.title(label=lab)
    plt.tight_layout()
    plt.show()

train_ims = (0, 1, 2, 10, 20, 30, 40, 50, 60, 70)
val_ims = (0, 1, 2, 10, 20, 30, 40, 50, 60, 70)
test_ims = (0, 1, 2, 20, 60, 80, 116, 161, 186, 206)

show_images(data=train_data, ims=train_ims, main_title="train_data")
show_images(data=train_data, ims=val_ims, main_title="val_data")
show_images(data=test_data, ims=test_ims, main_title="test_data")

Flowers in PyTorch

Flowers in PyTorch

Flowers in PyTorch

Das obige ist der detaillierte Inhalt vonBlumen in PyTorch. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn