CelebA in PyTorch

Susan Sarandon
2024-12-28 02:36:10

*My post explains CelebA.

CelebA() can use the CelebA dataset as shown below:

*Memo:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Optional-Default:"train"-Type:str). *"train"(162,770 images), "valid"(19,867 images), "test"(19,962 images) or "all"(202,599 images) can be set to it.
  • The 3rd argument is target_type(Optional-Default:"attr"-Type:str or list of str): *Memo:
    • "attr", "identity", "bbox" and/or "landmarks" can be set to it.
    • An empty list can also be set to it.
    • The same value can be set to it multiple times.
    • If the order of the values differs, the order of the target's elements differs too.
  • The 4th argument is transform(Optional-Default:None-Type:callable).
  • The 5th argument is target_transform(Optional-Default:None-Type:callable).
  • The 6th argument is download(Optional-Default:False-Type:bool): *Memo:
    • If it's True, the dataset is downloaded from the internet and extracted (unzipped) to root.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • gdown is required to download the dataset.
    • You can also manually download and extract the dataset (img_align_celeba.zip with identity_CelebA.txt, list_attr_celeba.txt, list_bbox_celeba.txt, list_eval_partition.txt and list_landmarks_align_celeba.txt) from here to data/celeba/.
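
Before constructing CelebA() with download=False, it can help to check that the manually extracted files are where the memo above says. The following is a minimal sketch, assuming the layout data/celeba/<text files> plus an img_align_celeba/ folder produced by unzipping img_align_celeba.zip. missing_celeba_files() is a hypothetical helper, not part of torchvision, and it only checks that the entries exist; torchvision's own integrity check, as far as I know, also verifies checksums.

```python
from pathlib import Path

# Text files listed in the memo above (assumed to sit directly in <root>/celeba/).
EXPECTED_FILES = [
    "identity_CelebA.txt",
    "list_attr_celeba.txt",
    "list_bbox_celeba.txt",
    "list_eval_partition.txt",
    "list_landmarks_align_celeba.txt",
]
# Folder assumed to be created by unzipping img_align_celeba.zip.
IMAGE_DIR = "img_align_celeba"


def missing_celeba_files(root):
    """Return the expected CelebA entries that are missing under <root>/celeba/."""
    base = Path(root) / "celeba"
    missing = [name for name in EXPECTED_FILES if not (base / name).is_file()]
    if not (base / IMAGE_DIR).is_dir():
        missing.append(IMAGE_DIR + "/")
    return missing


# An empty list means every expected entry was found.
print(missing_celeba_files("data"))
```
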
from torchvision.datasets import CelebA

train_attr_data = CelebA(
    root="data"
)

train_attr_data = CelebA(
    root="data",
    split="train",
    target_type="attr",
    transform=None,
    target_transform=None,
    download=False
)

valid_identity_data = CelebA(
    root="data",
    split="valid",
    target_type="identity"
)

test_bbox_data = CelebA(
    root="data",
    split="test",
    target_type="bbox"
)

all_landmarks_data = CelebA(
    root="data",
    split="all",
    target_type="landmarks"
)

all_empty_data = CelebA(
    root="data",
    split="all",
    target_type=[]
)

all_all_data = CelebA(
    root="data",
    split="all",
    target_type=["attr", "identity", "bbox", "landmarks"]
)

len(train_attr_data), len(valid_identity_data), len(test_bbox_data)
# (162770, 19867, 19962)

len(all_landmarks_data), len(all_empty_data), len(all_all_data)
# (202599, 202599, 202599)
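
The three split sizes add up exactly to the "all" size (162,770 + 19,867 + 19,962 = 202,599), so the splits are disjoint and together cover the whole dataset. Below is a minimal sketch of split selection, assuming list_eval_partition.txt labels each image 0 (train), 1 (valid) or 2 (test). select_split() is a hypothetical helper and the partition column is toy data, not real CelebA rows.

```python
# Split sizes quoted in the memo above.
SPLIT_SIZES = {"train": 162_770, "valid": 19_867, "test": 19_962}
ALL_SIZE = 202_599

# Assumed partition codes used in list_eval_partition.txt; "all" keeps every row.
SPLIT_CODES = {"train": 0, "valid": 1, "test": 2, "all": None}


def select_split(partition_codes, split):
    """Return the row indices that belong to `split`."""
    code = SPLIT_CODES[split]
    if code is None:
        return list(range(len(partition_codes)))
    return [i for i, c in enumerate(partition_codes) if c == code]


# The three disjoint splits together make up "all".
assert sum(SPLIT_SIZES.values()) == ALL_SIZE

# Toy partition column (not real CelebA data).
toy_codes = [0, 0, 1, 2, 0, 1]
print(select_split(toy_codes, "valid"))  # [2, 5]
print(select_split(toy_codes, "all"))    # [0, 1, 2, 3, 4, 5]
```
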

train_attr_data
# Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train

train_attr_data.root
# 'data'

train_attr_data.split
# 'train'

train_attr_data.target_type
# ['attr']

print(train_attr_data.transform)
# None

print(train_attr_data.target_transform)
# None

train_attr_data.download
# <bound method CelebA.download of Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train>

len(train_attr_data.attr), train_attr_data.attr
# (162770, tensor([[0, 1, 1, ..., 0, 0, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  ...,
#                  [1, 0, 1, ..., 0, 1, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  [0, 1, 1, ..., 1, 0, 1]]))

len(train_attr_data.attr_names), train_attr_data.attr_names
# (41, ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 
#       'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
#       'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair',
#       ...
#       'Wearing_Necklace', 'Wearing_Necktie', 'Young', ''])
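
attr_names has 41 entries while each attr row has 40 values. The trailing '' seems to come from a trailing separator in the header line of list_attr_celeba.txt and has no matching column. The raw file stores each attribute as -1 or 1, which (as far as I know) torchvision maps to the 0/1 values shown above. A sketch under those assumptions; to_binary() and active_attributes() are hypothetical helpers, and the 3-attribute example is toy data.

```python
def to_binary(raw_row):
    """Map raw -1/1 attribute labels to 0/1 via (v + 1) // 2."""
    return [(v + 1) // 2 for v in raw_row]


def active_attributes(attr_names, attr_row):
    """Return the names of the attributes set to 1 in a 0/1 row.

    Empty names (such as the trailing '' in attr_names) are dropped first,
    so the remaining names line up one-to-one with the row's values.
    """
    names = [n for n in attr_names if n]
    if len(names) != len(attr_row):
        raise ValueError(f"{len(names)} names vs {len(attr_row)} values")
    return [n for n, v in zip(names, attr_row) if v == 1]


# Toy example with 3 attributes (not the full CelebA list).
print(to_binary([-1, 1, 1]))                                         # [0, 1, 1]
print(active_attributes(["Bald", "Bangs", "Young", ""], [0, 1, 1]))  # ['Bangs', 'Young']
```

With the real dataset, something like active_attributes(train_attr_data.attr_names, train_attr_data[0][1].tolist()) should list the attributes of the first training image.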

len(train_attr_data.identity), train_attr_data.identity
# (162770, tensor([[2880], [2937], [8692], ..., [7391], [8610], [2304]]))

len(train_attr_data.bbox), train_attr_data.bbox
# (162770, tensor([[95, 71, 226, 313],
#                  [72, 94, 221, 306],
#                  [216, 59, 91, 126],
#                  ...,
#                  [103, 103, 143, 198],
#                  [30, 59, 216, 280],
#                  [376, 4, 372, 515]]))

len(train_attr_data.landmarks_align), train_attr_data.landmarks_align
# (162770, tensor([[69, 109, 106, ..., 152, 108, 154],
#                  [69, 110, 107, ..., 151, 108, 153],
#                  [76, 112, 104, ..., 156, 98, 158],
#                  ...,
#                  [69, 113, 109, ..., 151, 110, 151],
#                  [68, 112, 109, ..., 150, 108, 151],
#                  [70, 111, 107, ..., 153, 102, 152]]))

train_attr_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#          0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#          0, 1, 1, 0, 1, 0, 1, 0, 0, 1]))

train_attr_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#          0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1]))

train_attr_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#          1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#          0, 0, 0, 1, 0, 0, 0, 0, 0, 1]))

valid_identity_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2594))

valid_identity_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2795))

valid_identity_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(947))

test_bbox_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([147, 82, 120, 166]))

test_bbox_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([106, 34, 140, 194]))

test_bbox_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([107, 78, 109, 151]))
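
The bbox targets appear to follow the (x_1, y_1, width, height) layout of list_bbox_celeba.txt. Several boxes shown earlier, e.g. [376, 4, 372, 515], reach far beyond the 178x218 images. As far as I know, these boxes are annotated on the original in-the-wild photos rather than on the aligned crops that CelebA() loads, so the rectangles drawn by show_images() may not line up with the faces. A minimal sketch for checking this, assuming the xywh layout; xywh_to_xyxy() and fits_in_image() are hypothetical helpers, and the boxes are copied from test_bbox_data above.

```python
# Size of the aligned CelebA images, as shown in the outputs above.
ALIGNED_W, ALIGNED_H = 178, 218


def xywh_to_xyxy(box):
    """Convert (x_1, y_1, width, height) to (x_1, y_1, x_2, y_2)."""
    x, y, w, h = box
    return (x, y, x + w, y + h)


def fits_in_image(box, width=ALIGNED_W, height=ALIGNED_H):
    """Return True if the whole box lies inside a width x height image."""
    x1, y1, x2, y2 = xywh_to_xyxy(box)
    return 0 <= x1 and 0 <= y1 and x2 <= width and y2 <= height


# Boxes copied from test_bbox_data[0], [1] and [2] above.
boxes = [[147, 82, 120, 166], [106, 34, 140, 194], [107, 78, 109, 151]]
for b in boxes:
    print(b, xywh_to_xyxy(b), fits_in_image(b))
```
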

all_landmarks_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154]))

all_landmarks_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153]))

all_landmarks_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158]))
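
Each landmarks target holds 10 integers, i.e. 5 (x, y) points. Assuming the column order of list_landmarks_align_celeba.txt (left eye, right eye, nose, left mouth corner, right mouth corner), the even/odd split done by the loop in the landmarks branch of show_images() can be written with slicing. landmark_points() is a hypothetical helper; the row is copied from all_landmarks_data[0] above.

```python
# Assumed column order of list_landmarks_align_celeba.txt.
LANDMARK_NAMES = ["lefteye", "righteye", "nose", "leftmouth", "rightmouth"]


def landmark_points(lm):
    """Split a flat [x0, y0, x1, y1, ...] row into named (x, y) points."""
    xs, ys = lm[0::2], lm[1::2]  # even indices are x, odd indices are y
    return dict(zip(LANDMARK_NAMES, zip(xs, ys)))


# Row copied from all_landmarks_data[0] above.
row = [69, 109, 106, 113, 77, 142, 73, 152, 108, 154]
print(landmark_points(row))
```

The same slicing also works on a 1-D tensor, so px, py = lm[0::2], lm[1::2] is a shorter equivalent of that loop.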

all_empty_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_empty_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_empty_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_all_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#           0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#           0, 1, 1, 0, 1, 0, 1, 0, 0, 1]),
#   tensor(2880),
#   tensor([95, 71, 226, 313]),
#   tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154])))

all_all_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#           0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1]),
#   tensor(2937),
#   tensor([72, 94, 221, 306]),
#   tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153])))

all_all_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#           1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#           1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#           0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
#   tensor(8692),
#   tensor([216, 59, 91, 126]),
#   tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158])))
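
The outputs above show that the target is a bare value for a single target type, None for an empty list, and otherwise a tuple in target_type order (duplicates included, per the memo). Below is a minimal pure-Python sketch of that behaviour, assuming a hypothetical per-sample dict in place of the real tensors. It mirrors, as far as I know, what torchvision's CelebA does when indexing, but it is not torchvision's code; build_target() is a hypothetical helper and the shortened attr list is toy data.

```python
def build_target(row, target_type):
    """Assemble a target from one sample's annotations.

    `row` is a hypothetical dict holding the keys "attr", "identity",
    "bbox" and "landmarks".
    """
    if isinstance(target_type, str):  # a single str acts as a 1-item list
        target_type = [target_type]
    target = [row[t] for t in target_type]  # keeps order and duplicates
    if not target:
        return None  # empty target_type -> None
    return tuple(target) if len(target) > 1 else target[0]


# Toy annotations loosely based on all_all_data[0] (attr shortened).
sample = {"attr": [0, 1, 1], "identity": 2880,
          "bbox": [95, 71, 226, 313],
          "landmarks": [69, 109, 106, 113, 77, 142, 73, 152, 108, 154]}

print(build_target(sample, "identity"))                # 2880
print(build_target(sample, []))                        # None
print(build_target(sample, ["bbox", "identity"]))      # ([95, 71, 226, 313], 2880)
print(build_target(sample, ["identity", "identity"]))  # (2880, 2880)
```
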

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.patches import Circle

def show_images(data, main_title=None):
    if "attr" in data.target_type and len(data.target_type) == 1 \
        or not data.target_type:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, _) in enumerate(data, start=1):
            plt.subplot(2, 5, i)
            plt.imshow(X=im)
            if i == 10:
                break
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif "identity" in data.target_type and len(data.target_type) == 1:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, lab) in enumerate(data, start=1):
            plt.subplot(2, 5, i)
            plt.title(label=lab.item())
            plt.imshow(X=im)
            if i == 10:
                break
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif "bbox" in data.target_type and len(data.target_type) == 1:
        fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
        fig.suptitle(t=main_title, y=1.0, fontsize=14)
        for (i, (im, (x, y, w, h))), axis \
            in zip(enumerate(data, start=1), axes.ravel()):
            axis.imshow(X=im)
            rect = Rectangle(xy=(x, y), width=w, height=h,
                             linewidth=3, edgecolor='r',
                             facecolor='none')
            axis.add_patch(p=rect)
            if i == 10:
                break
        fig.tight_layout(h_pad=3.0)
        plt.show()
    elif "landmarks" in data.target_type and len(data.target_type) == 1:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, lm) in enumerate(data, start=1):
            px = []
            py = []
            for j, v in enumerate(lm):
                if j%2 == 0:
                    px.append(v)
                else:
                    py.append(v)
            plt.subplot(2, 5, i)
            plt.imshow(X=im)
            plt.scatter(x=px, y=py)
            if i == 10:
                break
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif len(data.target_type) == 4:
        fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
        fig.suptitle(t=main_title, y=1.0, fontsize=14)
        for (i, (im, (_, lab, (x, y, w, h), lm))), axis \
            in zip(enumerate(data, start=1), axes.ravel()):
            axis.set_title(label=lab.item())
            axis.imshow(X=im)
            rect = Rectangle(xy=(x, y), width=w, height=h,
                             linewidth=3, edgecolor='r',
                             facecolor='none', clip_on=True)
            axis.add_patch(p=rect)
            # `axis.scatter()` and `axis.plot()` of `plt.subplots()` don't work
            # properly here. They shrink images so use `axis.add_patch()` instead.
            for px, py in lm.split(2):
                axis.add_patch(p=Circle(xy=(px, py)))
            if i == 10:
                break
        fig.tight_layout(h_pad=3.0)
        plt.show()

show_images(data=train_attr_data, main_title="train_attr_data")
show_images(data=valid_identity_data, main_title="valid_identity_data")
show_images(data=test_bbox_data, main_title="test_bbox_data")
show_images(data=all_landmarks_data, main_title="all_landmarks_data")
show_images(data=all_empty_data, main_title="all_empty_data")
show_images(data=all_all_data, main_title="all_all_data")

