unsqueeze in PyTorch
*My post explains squeeze().
unsqueeze() takes a 0D or more D tensor of zero or more elements and returns a 1D or more D tensor with an extra dimension of size 1 inserted at position dim. The elements themselves are unchanged, as shown below:
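*Memos:
- unsqueeze() can be called with torch or on a tensor.
- The 1st argument with torch is input (Required-Type: tensor).
- dim is the 2nd argument with torch or the 1st argument with a tensor (Required-Type: int). For an N-D input it must be in the range [-(N + 1), N], and a negative dim counts from the end.
- The returned tensor shares its data with the input, so no copy is made.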
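The description above says unsqueeze() also accepts a 0D tensor and a tensor with zero elements. The minimal sketch below checks both cases and then walks through every valid dim for a 2D tensor, showing that a negative dim behaves like dim + N + 1. The variable names scalar, empty and x are only illustrative, and the sketch assumes a reasonably recent PyTorch release.

```python
import torch

# A 0D tensor (a scalar) becomes a 1D tensor with one element.
scalar = torch.tensor(7)
print(scalar.unsqueeze(dim=0).shape)  # torch.Size([1])

# A tensor with zero elements still has zero elements; only the shape grows.
empty = torch.empty(0, 3)
print(empty.unsqueeze(dim=1).shape)  # torch.Size([0, 1, 3])

# For an N-D input, dim must lie in [-(N + 1), N].
# A negative dim counts from the end and acts like dim + N + 1.
x = torch.zeros(4, 3)
n = x.dim()
for d in range(-(n + 1), n + 1):
    shape = x.unsqueeze(dim=d).shape
    assert shape == x.unsqueeze(dim=(d + n + 1 if d < 0 else d)).shape
    print(d, tuple(shape))

# Anything outside that range raises an IndexError.
try:
    x.unsqueeze(dim=n + 1)
except IndexError as err:
    print("IndexError:", err)
```

For the 4x3 tensor, dim=-3 and dim=0 both give shape (1, 4, 3), and dim=-1 and dim=2 both give (4, 3, 1).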
```python
import torch

my_tensor = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [10, 11, 12]])
# my_tensor is 2D (shape [4, 3]), so dim can be from -3 to 2.

torch.unsqueeze(input=my_tensor, dim=0)
my_tensor.unsqueeze(dim=0)
torch.unsqueeze(input=my_tensor, dim=-3)
# tensor([[[0, 1, 2],
#          [3, 4, 5],
#          [6, 7, 8],
#          [10, 11, 12]]])   # shape [1, 4, 3]

torch.unsqueeze(input=my_tensor, dim=1)
torch.unsqueeze(input=my_tensor, dim=-2)
# tensor([[[0, 1, 2]],
#         [[3, 4, 5]],
#         [[6, 7, 8]],
#         [[10, 11, 12]]])   # shape [4, 1, 3]

torch.unsqueeze(input=my_tensor, dim=2)
torch.unsqueeze(input=my_tensor, dim=-1)
# tensor([[[0], [1], [2]],
#         [[3], [4], [5]],
#         [[6], [7], [8]],
#         [[10], [11], [12]]])   # shape [4, 3, 1]

# dim=3 is out of range for a 2D tensor and raises an IndexError:
# torch.unsqueeze(input=my_tensor, dim=3)

my_tensor = torch.tensor([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.], [10., 11., 12.]])
torch.unsqueeze(input=my_tensor, dim=0)
# tensor([[[0., 1., 2.],
#          [3., 4., 5.],
#          [6., 7., 8.],
#          [10., 11., 12.]]])

my_tensor = torch.tensor([[0.+0.j, 1.+0.j, 2.+0.j],
                          [3.+0.j, 4.+0.j, 5.+0.j],
                          [6.+0.j, 7.+0.j, 8.+0.j],
                          [10.+0.j, 11.+0.j, 12.+0.j]])
torch.unsqueeze(input=my_tensor, dim=0)
# tensor([[[0.+0.j, 1.+0.j, 2.+0.j],
#          [3.+0.j, 4.+0.j, 5.+0.j],
#          [6.+0.j, 7.+0.j, 8.+0.j],
#          [10.+0.j, 11.+0.j, 12.+0.j]]])

my_tensor = torch.tensor([[True, False, True],
                          [False, True, False],
                          [True, False, True],
                          [False, True, False]])
torch.unsqueeze(input=my_tensor, dim=0)
# tensor([[[True, False, True],
#          [False, True, False],
#          [True, False, True],
#          [False, True, False]]])
```
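The examples above only show printed results. The following sketch illustrates three related points: unsqueeze() returns a view that shares the input's data, indexing with None inserts the same size-1 dimension, and adding size-1 dimensions is a common way to set up broadcasting. The names unsqueezed, a, b and diff are illustrative only, and the pairwise-difference use case is just one example of where unsqueeze() is handy.

```python
import torch

my_tensor = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [10, 11, 12]])

# unsqueeze() returns a view: the result shares storage with the input.
unsqueezed = my_tensor.unsqueeze(dim=1)
print(unsqueezed.shape)  # torch.Size([4, 1, 3])
assert unsqueezed.data_ptr() == my_tensor.data_ptr()

# Writing through the view therefore also changes the original tensor.
unsqueezed[0, 0, 0] = 100
print(my_tensor[0, 0])  # tensor(100)

# Indexing with None inserts a size-1 dimension at the same position.
assert torch.equal(my_tensor.unsqueeze(dim=0), my_tensor[None])
assert torch.equal(my_tensor.unsqueeze(dim=1), my_tensor[:, None])
assert torch.equal(my_tensor.unsqueeze(dim=-1), my_tensor[..., None])

# Line up two 1D tensors for broadcasting:
# shapes (3, 1) and (1, 2) broadcast to (3, 2), so diff[i, j] = a[i] - b[j].
a = torch.tensor([1, 2, 3])
b = torch.tensor([10, 20])
diff = a.unsqueeze(dim=1) - b.unsqueeze(dim=0)
print(diff.tolist())  # [[-9, -19], [-8, -18], [-7, -17]]
```

Because the result is a view, call .clone() on it if you need an independent copy.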