Home >Backend Development >Python Tutorial >Flatten in PyTorch
Buy Me a Coffee☕
*Memos:
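nn.Flatten() collapses a contiguous range of a tensor's dimensions (from start_dim through end_dim, inclusive) into a single dimension. It takes a 0D or higher-dimensional tensor with zero or more elements and returns a tensor with one or more dimensions containing the same elements. A 0D tensor becomes 1D, and a range that covers only one dimension leaves the tensor unchanged, as shown below: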
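*Memos:
- The 1st argument for initialization is `start_dim` (optional; default `1`; type `int`): the first dimension to flatten.
- The 2nd argument for initialization is `end_dim` (optional; default `-1`; type `int`): the last dimension to flatten (inclusive).
- The 1st argument when calling the module is `input` (required; type tensor). The examples below use `int`, `float`, `complex` and `bool` tensors.
- Negative dimensions count from the end, so `-1` is the last dimension. Once both are resolved, `start_dim` must not come after `end_dim`.
- The default `start_dim=1` does not exist on 0D and 1D tensors, so those need explicit dimensions.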
import torch
from torch import nn


def flatten_all(tensor, dim_pairs):
    """Flatten *tensor* with every (start_dim, end_dim) pair and return the shared result.

    Each pair is passed to ``nn.Flatten(start_dim=..., end_dim=...)``. The pairs are
    meant to be equivalent spellings of the same flatten (negative dims count from
    the last dimension), so every pair must produce an identical tensor. This turns
    the "all of these give the same output" claims below into executed checks.

    Args:
        tensor: The input tensor.
        dim_pairs: Iterable of ``(start_dim, end_dim)`` tuples; must not be empty.

    Returns:
        The flattened tensor (identical for every pair).

    Raises:
        ValueError: If *dim_pairs* is empty or two pairs give different results.
    """
    pairs = list(dim_pairs)  # accept any iterable, including generators
    if not pairs:
        raise ValueError("dim_pairs must contain at least one (start_dim, end_dim) pair")
    expected = None
    for start_dim, end_dim in pairs:
        result = nn.Flatten(start_dim=start_dim, end_dim=end_dim)(input=tensor)
        if expected is None:
            expected = result
        elif not torch.equal(result, expected):
            raise ValueError(
                f"Flatten(start_dim={start_dim}, end_dim={end_dim}) gave {result}, "
                f"expected {expected}"
            )
    return expected


# nn.Flatten() defaults to start_dim=1 and end_dim=-1.
flatten = nn.Flatten()
print(flatten)            # Flatten(start_dim=1, end_dim=-1)
print(flatten.start_dim)  # 1
print(flatten.end_dim)    # -1

# The default start_dim=1 is out of range for 0D and 1D tensors, so the
# following two examples pass explicit dimensions.

# 0D tensor: flattening returns a 1D tensor with one element.
my_tensor = torch.tensor(7)
print(flatten_all(my_tensor, [(0, 0), (0, -1), (-1, 0), (-1, -1)]))
# tensor([7])

# 1D tensor: every pair refers to dim 0, so the tensor is returned unchanged.
my_tensor = torch.tensor([7, 1, -8, 3, -6, 0])
print(flatten_all(my_tensor, [(0, 0), (0, -1), (-1, 0), (-1, -1)]))
# tensor([7, 1, -8, 3, -6, 0])

# 2D tensor of shape (2, 3).
my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]])

# Merging dims 0..1 gives a 1D tensor.
print(flatten_all(my_tensor, [(0, 1), (0, -1), (-2, 1), (-2, -1)]))
# tensor([7, 1, -8, 3, -6, 0])

# A range covering a single dim leaves the tensor unchanged. (1, -1) is the
# nn.Flatten() default, which on a 2D tensor covers only dim 1.
print(flatten_all(my_tensor, [(1, -1), (0, 0), (-1, -1), (0, -2), (1, 1),
                               (-1, 1), (-2, 0), (-2, -2)]))
# tensor([[7, 1, -8], [3, -6, 0]])

# 3D tensor of shape (2, 3, 1).
my_tensor = torch.tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

# Merging all dims 0..2 gives a 1D tensor.
print(flatten_all(my_tensor, [(0, 2), (0, -1), (-3, 2), (-3, -1)]))
# tensor([7, 1, -8, 3, -6, 0])

# A range covering a single dim leaves the tensor unchanged.
print(flatten_all(my_tensor, [(0, 0), (0, -3), (1, 1), (1, -2), (2, 2), (2, -1),
                               (-1, 2), (-1, -1), (-2, 1), (-2, -2), (-3, 0),
                               (-3, -3)]))
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

# Merging dims 0..1 gives shape (6, 1).
print(flatten_all(my_tensor, [(0, 1), (0, -2), (-3, 1), (-3, -2)]))
# tensor([[7], [1], [-8], [3], [-6], [0]])

# Merging dims 1..2 (the nn.Flatten() default) gives shape (2, 3).
print(flatten_all(my_tensor, [(1, 2), (1, -1), (-2, 2), (-2, -1)]))
# tensor([[7, 1, -8], [3, -6, 0]])

# Flatten works the same way for float, complex and bool tensors.
flatten = nn.Flatten()

my_tensor = torch.tensor([[[7.], [1.], [-8.]], [[3.], [-6.], [0.]]])
print(flatten(input=my_tensor))
# tensor([[7., 1., -8.], [3., -6., 0.]])

my_tensor = torch.tensor([[[7.+0.j], [1.+0.j], [-8.+0.j]],
                          [[3.+0.j], [-6.+0.j], [0.+0.j]]])
print(flatten(input=my_tensor))
# tensor([[7.+0.j, 1.+0.j, -8.+0.j],
#         [3.+0.j, -6.+0.j, 0.+0.j]])

my_tensor = torch.tensor([[[True], [False], [True]], [[False], [True], [False]]])
print(flatten(input=my_tensor))
# tensor([[True, False, True],
#         [False, True, False]])
The above is the detailed content of Flatten in PyTorch. For more information, please follow other related articles on the PHP Chinese website!