Python class parameter definition and data expansion methods unsqueeze/expand
Activate the conda environment named ai: conda activate ai
Motivation:
In a PyTorch implementation of the YOLOv1 loss function, the loss class inherits from nn.Module, and two of the parameters have no declared type, as they would in C. So where do their types come from?
That is what we explore here.
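Operation logic:
N = box1.size(0) and M = box2.size(0)
These calls show that box1 and box2 are matrix-like tensors: .size(0) returns the length of the first dimension. The corresponding definition of box1 is torch.rand(10, 4), a 10 x 4 tensor. Python does not declare parameter types; a parameter takes the type of whatever object the caller passes in, and the way it is used inside the method shows what is expected.

    import torch
    import torch.nn as nn


    # Explore where the attributes S and B come from, and how box1 and box2
    # are created and passed in.
    class yoloLoss(nn.Module):
        def __init__(self, S, B):
            super().__init__()  # initialize nn.Module before setting attributes
            self.S = S
            self.B = B

        def compute_iot(self, box1, box2):
            # The way a variable is used shows its type: box1 is a tensor
            # whose rows are themselves tensors, so it has shape N x 4.
            N = box1.size(0)
            M = box2.size(0)
            print(M, N)


    yoloLoss1 = yoloLoss(10, 11)
    yoloLoss1.compute_iot(torch.rand(10, 4), torch.rand(11, 4))  # prints: 11 10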
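To make the point about untyped parameters concrete, here is a minimal sketch that uses only the standard library. The names ShapeOnly and count_pairs are hypothetical and exist only for this illustration; ShapeOnly is a stand-in that imitates just the .size(dim) call of a tensor. It suggests that the function accepts any object that offers .size(0), so the type is decided by the caller at call time rather than by a declaration.

```python
from typing import Any


class ShapeOnly:
    """Hypothetical stand-in that mimics only the .size(dim) call of a tensor."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape

    def size(self, dim: int) -> int:
        return self.shape[dim]


def count_pairs(box1: Any, box2: Any) -> int:
    # No type is declared or enforced here: any object that provides
    # .size(0) is accepted, and its type is fixed only when the call happens.
    n = box1.size(0)
    m = box2.size(0)
    return n * m


pairs = count_pairs(ShapeOnly(10, 4), ShapeOnly(11, 4))
print(pairs)  # prints: 110

# An object without .size fails only when the method is actually called.
try:
    count_pairs(10, 11)
except AttributeError as err:
    print("AttributeError:", err)
```

A torch.Tensor works in the same position because it also provides .size(dim). Optional type hints such as box1: torch.Tensor document the intent, but Python does not check them at run time.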
Next, explore unsqueeze and expand. unsqueeze adds a dimension, but the size of the new dimension is always 1. expand then enlarges a size-1 dimension to size n by repeating the data along it; it returns a view of the same data rather than a real copy.
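    import torch

    # Initial value: tensor([[a1, a2, a3]]), shape (1, 3)
    nn1 = torch.rand(1, 3)
    print(nn1)

    # unsqueeze inserts a new dimension of size 1 at position i,
    # giving tensor([[[a1, a2, a3]]]), shape (1, 1, 3)
    nn1 = nn1.unsqueeze(0)
    print("*" * 100)
    print(nn1)

    # Use expand to repeat the data along the size-1 dimension, shape (1, 3, 3)
    nn1 = nn1.expand(1, 3, 3)
    print("*" * 100)
    print(nn1)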
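The shape changes above can be checked directly. The following is a minimal sketch, assuming PyTorch is installed; the variable names x, y, z and w are chosen only for this illustration. It compares expand with repeat to suggest that expand produces a view over the original storage, while repeat allocates new memory.

```python
import torch

x = torch.rand(1, 3)     # shape (1, 3)
y = x.unsqueeze(0)       # shape (1, 1, 3): new size-1 dimension at position 0
z = y.expand(1, 3, 3)    # shape (1, 3, 3): a view, no data is copied
w = y.repeat(1, 3, 1)    # shape (1, 3, 3): a real copy in new memory

print(tuple(y.shape), tuple(z.shape), tuple(w.shape))

# The expanded dimension has stride 0, so every "row" reads the same memory.
print(z.stride())

# expand shares storage with x, repeat does not.
print(z.data_ptr() == x.data_ptr(), w.data_ptr() == x.data_ptr())
```

Because an expanded tensor reuses the same memory for every repeated row, writing into it in place is unsafe; calling .clone() or .contiguous() first gives an independent copy.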