unsqueeze in PyTorch

Susan Sarandon
2025-01-05 04:49:46

*My post explains squeeze().

unsqueeze() returns a tensor with one additional dimension of size 1 inserted at a given position. The input can be a 0D or higher-dimensional tensor with zero or more elements, so the result is a 1D or higher-dimensional tensor, as shown below:

*Memos:

  • unsqueeze() can be called as torch.unsqueeze() or as a tensor method.
  • The 1st argument with torch, or the tensor itself when called as a method, is input (Required-Type: tensor of int, float, complex or bool).
  • The 2nd argument with torch, or the 1st argument with a tensor, is dim (Required-Type: int). *It is the position at which the dimension of size 1 is inserted. For an N-D input, it must be in the range -(N+1) to N.
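
As a quick check of the memos above, here is a minimal sketch that records the shape produced by every valid dim for a 2D tensor. The names `x`, `shapes` and `out_of_range_raises` are illustrative and do not come from the original post. A negative dim counts from the end of the output shape, so each negative value matches one of the non-negative ones.

```python
import torch

# Illustrative 2D tensor of shape (4, 3); the name `x` is hypothetical.
x = torch.zeros(4, 3)

# For an N-D input, valid dims run from -(N + 1) to N, i.e. -3 to 2 here.
shapes = {dim: tuple(x.unsqueeze(dim).shape) for dim in range(-3, 3)}
for dim, shape in shapes.items():
    print(dim, shape)

# A dim outside that range raises IndexError.
try:
    x.unsqueeze(3)
    out_of_range_raises = False
except IndexError:
    out_of_range_raises = True
print(out_of_range_raises)
```

Looking at shapes rather than printed values makes it easier to see where the new dimension lands, especially for larger tensors.
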
import torch

my_tensor = torch.tensor([[0, 1, 2],
                          [3, 4, 5],
                          [6, 7, 8],
                          [10, 11, 12]])
torch.unsqueeze(input=my_tensor, dim=0)
my_tensor.unsqueeze(dim=0)
torch.unsqueeze(input=my_tensor, dim=-3)
# tensor([[[0, 1, 2],
#          [3, 4, 5],
#          [6, 7, 8],
#          [10, 11, 12]]])

torch.unsqueeze(input=my_tensor, dim=1)
torch.unsqueeze(input=my_tensor, dim=-2)
# tensor([[[0, 1, 2]],
#         [[3, 4, 5]],
#         [[6, 7, 8]],
#         [[10, 11, 12]]])

torch.unsqueeze(input=my_tensor, dim=2)
torch.unsqueeze(input=my_tensor, dim=-1)
# tensor([[[0], [1], [2]],
#         [[3], [4], [5]],
#         [[6], [7], [8]],
#         [[10], [11], [12]]])
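
The three positions used so far can also be reached by indexing with `None`, which some readers may find easier to scan. This is only a sketch of that alternative spelling, which the original post does not cover. The names `x`, `front`, `middle`, `back` and `same` are illustrative.

```python
import torch

# Illustrative 2D tensor of shape (4, 3); the name `x` is hypothetical.
x = torch.arange(12).reshape(4, 3)

front = x[None]        # same result as x.unsqueeze(0),  shape (1, 4, 3)
middle = x[:, None]    # same result as x.unsqueeze(1),  shape (4, 1, 3)
back = x[..., None]    # same result as x.unsqueeze(-1), shape (4, 3, 1)

# Compare each indexed result with its unsqueeze() counterpart.
same = (torch.equal(front, x.unsqueeze(0))
        and torch.equal(middle, x.unsqueeze(1))
        and torch.equal(back, x.unsqueeze(-1)))
print(same)
```

unsqueeze() is usually the clearer choice when dim is computed at run time, while `None` indexing is compact when the position is fixed.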

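# dim=3 is out of range for the 2D tensor above (valid: -3 to 2),
# so a 3D tensor of shape (3, 2, 4) is used here.
my_tensor = torch.arange(24).reshape(3, 2, 4)
torch.unsqueeze(input=my_tensor, dim=3)
torch.unsqueeze(input=my_tensor, dim=-1)
# tensor([[[[0], [1], [2], [3]], [[4], [5], [6], [7]]],
#         [[[8], [9], [10], [11]], [[12], [13], [14], [15]]],
#         [[[16], [17], [18], [19]], [[20], [21], [22], [23]]]])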

my_tensor = torch.tensor([[0., 1., 2.],
                          [3., 4., 5.],
                          [6., 7., 8.],
                          [10., 11., 12.]])
torch.unsqueeze(input=my_tensor, dim=0)
# tensor([[[0., 1., 2.],
#          [3., 4., 5.],
#          [6., 7., 8.],
#          [10., 11., 12.]]])

my_tensor = torch.tensor([[0.+0.j, 1.+0.j, 2.+0.j],
                          [3.+0.j, 4.+0.j, 5.+0.j],
                          [6.+0.j, 7.+0.j, 8.+0.j],
                          [10.+0.j, 11.+0.j, 12.+0.j]])
torch.unsqueeze(input=my_tensor, dim=0)
# tensor([[[0.+0.j, 1.+0.j, 2.+0.j],
#          [3.+0.j, 4.+0.j, 5.+0.j],
#          [6.+0.j, 7.+0.j, 8.+0.j],
#          [10.+0.j, 11.+0.j, 12.+0.j]]])

my_tensor = torch.tensor([[True, False, True],
                          [False, True, False],
                          [True, False, True],
                          [False, True, False]])
torch.unsqueeze(input=my_tensor, dim=0)
# tensor([[[True, False, True],
#          [False, True, False],
#          [True, False, True],
#          [False, True, False]]])
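
The examples above all start from a 2D or 3D tensor. The sketch below covers the two edge cases mentioned at the top of the post, a 0D tensor and a tensor with zero elements, and then undoes one unsqueeze() with squeeze(). The variable names are illustrative and not from the original post.

```python
import torch

scalar = torch.tensor(5)   # 0D tensor, shape ()
empty = torch.tensor([])   # 1D tensor with zero elements, shape (0,)

scalar_1d = scalar.unsqueeze(dim=0)   # shape (1,)
empty_row = empty.unsqueeze(dim=0)    # shape (1, 0)
empty_col = empty.unsqueeze(dim=1)    # shape (0, 1)
print(scalar_1d.shape, empty_row.shape, empty_col.shape)

# squeeze() with the same dim removes the added dimension again.
restored = scalar_1d.squeeze(dim=0)
print(restored.shape)
```

For the 0D input the only valid dims are -1 and 0, which follows from the same -(N+1) to N rule with N = 0.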
