Maison > Développement back-end > Tutoriel Python > ImageNet dans PyTorch
Achetez-moi un café ☕
*Mon article explique ImageNet.
ImageNet() permet d'utiliser l'ensemble de données ImageNet, comme indiqué ci-dessous :
*Mémos :
# Explore torchvision's ImageNet dataset and plot a few samples of each split.
# NOTE(review): assumes the ImageNet files are already available under
# "data"; nothing in this script downloads them.
import math

import matplotlib.pyplot as plt
from torchvision.datasets import ImageNet
from torchvision.datasets.folder import default_loader

# ImageNet(root="data") is equivalent to the explicit call below: these are
# the default values of split, transform, target_transform and loader.
# Build each split only once -- constructing the dataset scans its image
# directory, so a second, immediately overwritten construction wastes time.
train_data = ImageNet(
    root="data",
    split="train",
    transform=None,
    target_transform=None,
    loader=default_loader,
)
val_data = ImageNet(root="data", split="val")

print((len(train_data), len(val_data)))  # (1281167, 50000)

print(train_data)
# Dataset ImageNet
#     Number of datapoints: 1281167
#     Root location: data
#     Split: train

print(train_data.root)              # data
print(train_data.split)             # train
print(train_data.transform)         # None
print(train_data.target_transform)  # None
print(train_data.loader)            # <function default_loader at 0x...>

print(len(train_data.classes))      # 1000
print(train_data.classes)
# [('tench', 'Tinca tinca'), ('goldfish', 'Carassius auratus'),
#  ('great white shark', 'white shark', 'man-eater', 'man-eating shark',
#   'Carcharodon carcharias'), ('tiger shark', 'Galeocerdo cuvieri'),
#  ('hammerhead', 'hammerhead shark'), ('electric ray', 'crampfish',
#   'numbfish', 'torpedo'), ('stingray',), ('cock',), ('hen',),
#  ('ostrich', 'Struthio camelus'), ..., ('bolete',), ('ear', 'spike',
#   'capitulum'), ('toilet tissue', 'toilet paper', 'bathroom tissue')]

# Each item is an (image, class index) pair.
print(train_data[0])     # (<PIL.Image.Image image mode=RGB size=250x250>, 0)
print(train_data[1])     # (<PIL.Image.Image image mode=RGB size=200x150>, 0)
print(train_data[2])     # (<PIL.Image.Image image mode=RGB size=500x375>, 0)
print(train_data[1300])  # (<PIL.Image.Image image mode=RGB size=640x480>, 1)
print(train_data[2600])  # (<PIL.Image.Image image mode=RGB size=500x375>, 2)

print(val_data[0])       # (<PIL.Image.Image image mode=RGB size=500x375>, 0)
print(val_data[1])       # (<PIL.Image.Image image mode=RGB size=500x375>, 0)
print(val_data[2])       # (<PIL.Image.Image image mode=RGB size=500x375>, 0)
print(val_data[50])      # (<PIL.Image.Image image mode=RGB size=500x500>, 1)
print(val_data[100])     # (<PIL.Image.Image image mode=RGB size=679x444>, 2)


def show_images(data, ims, main_title=None, *, nrows=2, ncols=5):
    """Plot ``data[j]`` for every index ``j`` in ``ims`` on a subplot grid.

    Each subplot shows one image with its label as the subplot title.

    Args:
        data: Indexable dataset whose items unpack to ``(image, label)``.
        ims: Indices into ``data`` to display, in grid (row-major) order.
        main_title: Optional title for the whole figure.
        nrows: Minimum number of grid rows. Extra rows are added when
            ``ims`` holds more than ``nrows * ncols`` indices.
        ncols: Number of grid columns.
    """
    # The original fixed 2x5 grid made plt.subplot raise for more than 10
    # indices; grow the row count instead (unchanged for <= 10 images).
    nrows = max(nrows, math.ceil(len(ims) / ncols))
    plt.figure(figsize=[12, 3 * nrows])  # 12x6 for the default two rows
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for pos, idx in enumerate(ims, start=1):
        plt.subplot(nrows, ncols, pos)
        im, lab = data[idx]
        plt.imshow(X=im)
        plt.title(label=lab)
    plt.tight_layout(h_pad=3.0)
    plt.show()


train_ims = [0, 1, 2, 1300, 2600, 3900, 5200, 6500, 7800, 9100]
val_ims = [0, 1, 2, 50, 100, 150, 200, 250, 300, 350]

show_images(data=train_data, ims=train_ims, main_title="train_data")
show_images(data=val_data, ims=val_ims, main_title="val_data")
Ce qui précède est le contenu détaillé de « ImageNet dans PyTorch ». Pour plus d'informations, consultez les autres articles connexes sur le site Web PHP chinois !