PyTorch でフラット化する

Patricia Arquette
Patricia Arquetteオリジナル
2024-11-06 05:58:03354ブラウズ

Flatten in PyTorch

コーヒー買ってきて☕

*メモ:

  • 私の投稿では flatten() と ravel() について説明しています。
  • 私の投稿では un flatten() について説明しています。

Flatten() は、以下に示すように、0 個以上の要素の 0D またはそれ以上の D テンソルから次元を選択し、0 個以上の要素の 1D またはそれ以上の D テンソルを取得することで、0 個以上の次元を削除できます。

*メモ:

  • 初期化の最初の引数は start_dim(Optional-Default:1-Type:int) です。
  • 初期化の 2 番目の引数は end_dim(Optional-Default:-1-Type:int) です。
  • 最初の引数は input(必須型: int、float、complex、または bool のテンソル) です。
  • Flatten() は 0D テンソルを 1D テンソルに変更できます。
  • Flatten() は 1D テンソルに対しては何も行いません。
  • Flatten() と flatten() の違いは次のとおりです。
    • Flatten() の start_dim のデフォルト値は 1 ですが、 flatten() の start_dim のデフォルト値は 0 です。
    • 基本的に、Flatten() はモデルの定義に使用されますが、 flatten() はモデルの定義には使用されません。
import torch
from torch import nn

flatten = nn.Flatten()
flatten
# Flatten(start_dim=1, end_dim=-1)

flatten.start_dim
# 1

flatten.end_dim
# -1

my_tensor = torch.tensor(7)

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7])

my_tensor = torch.tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]])

flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=1)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=0)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

my_tensor = torch.tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

flatten = nn.Flatten(start_dim=0, end_dim=2)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-3, end_dim=2)
flatten = nn.Flatten(start_dim=-3, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-3)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-2)
flatten = nn.Flatten(start_dim=2, end_dim=2)
flatten = nn.Flatten(start_dim=2, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=2)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=0)
flatten = nn.Flatten(start_dim=-3, end_dim=-3)
flatten(input=my_tensor)
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=1)
flatten = nn.Flatten(start_dim=-3, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7], [1], [-8], [3], [-6], [0]])

flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=1, end_dim=2)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=2)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

my_tensor = torch.tensor([[[7.], [1.], [-8.]], [[3.], [-6.], [0.]]])

flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7., 1., -8.], [3., -6., 0.]])

my_tensor = torch.tensor([[[7.+0.j], [1.+0.j], [-8.+0.j]],
                          [[3.+0.j], [-6.+0.j], [0.+0.j]]])
flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7.+0.j, 1.+0.j, -8.+0.j],
#         [3.+0.j, -6.+0.j, 0.+0.j]])

my_tensor = torch.tensor([[[True], [False], [True]],
                          [[False], [True], [False]]])
flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[True, False, True],
#         [False, True, False]])

以上がPyTorch でフラット化するの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。