首页 >后端开发 >Python教程 >在 PyTorch 中展平

在 PyTorch 中展平

Patricia Arquette
Patricia Arquette原创
2024-11-06 05:58:03345浏览

Flatten in PyTorch

请我喝杯咖啡☕

*备忘录:

  • 我的帖子解释了 flatten() 和 ravel()。
  • 我的帖子解释了 unflatten()。

Flatten() 可以通过从零个或多个元素的 0D 或多个 D 张量中选择维度来移除零个或多个维度,得到零个或多个元素的 1D 或多个 D 张量,如下所示:

*备忘录:

  • 初始化的第一个参数是 start_dim(Optional-Default:1-Type:int)。
  • 初始化的第二个参数是 end_dim(可选-默认:-1-类型:int)。
  • 第一个参数是输入(必需类型:int、float、complex 或 bool 的张量)。
  • Flatten() 可以将 0D 张量更改为 1D 张量。
  • Flatten() 对于一维张量没有任何作用。
  • Flatten() 和 flatten() 的区别是:
    • Flatten() 的 start_dim 默认值为 1,而 flatten() 的 start_dim 默认值为 0。
    • 基本上,Flatten() 用于定义模型,而 flatten() 不用于定义模型。
import torch
from torch import nn

flatten = nn.Flatten()
flatten
# Flatten(start_dim=1, end_dim=-1)

flatten.start_dim
# 1

flatten.end_dim
# -1

my_tensor = torch.tensor(7)

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7])

my_tensor = torch.tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]])

flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=1)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=0)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

my_tensor = torch.tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

flatten = nn.Flatten(start_dim=0, end_dim=2)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-3, end_dim=2)
flatten = nn.Flatten(start_dim=-3, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-3)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-2)
flatten = nn.Flatten(start_dim=2, end_dim=2)
flatten = nn.Flatten(start_dim=2, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=2)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=0)
flatten = nn.Flatten(start_dim=-3, end_dim=-3)
flatten(input=my_tensor)
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=1)
flatten = nn.Flatten(start_dim=-3, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7], [1], [-8], [3], [-6], [0]])

flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=1, end_dim=2)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=2)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

my_tensor = torch.tensor([[[7.], [1.], [-8.]], [[3.], [-6.], [0.]]])

flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7., 1., -8.], [3., -6., 0.]])

my_tensor = torch.tensor([[[7.+0.j], [1.+0.j], [-8.+0.j]],
                          [[3.+0.j], [-6.+0.j], [0.+0.j]]])
flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7.+0.j, 1.+0.j, -8.+0.j],
#         [3.+0.j, -6.+0.j, 0.+0.j]])

my_tensor = torch.tensor([[[True], [False], [True]],
                          [[False], [True], [False]]])
flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[True, False, True],
#         [False, True, False]])

以上是在 PyTorch 中展平的详细内容。更多信息请关注PHP中文网其他相关文章!

声明:
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn