首頁  >  文章  >  後端開發  >  python類別參數定義及資料擴充方式unsqueeze/expand

python類別參數定義及資料擴充方式unsqueeze/expand

WBOY
WBOY轉載
2022-08-24 13:32:402426瀏覽

【相關推薦:Python3影片教學

類別的參數定義

將conda環境設定為ai,conda activate ai

這個檔案的由來:

由於在yolov1的pytorch實現的損失函數中,看到繼承了nn .Module,並且其中兩個參數不像c 那裡指定類型,那麼他們的類型是哪裡來的

這裡就是在探索這樣一件事

##操作邏輯:

    先在類別中定義了建構子以及一個自訂函數;
  • 建構函式定義了屬性S、B,自訂函數引進兩個參數,對兩個參數進行呼叫
    • 這裡就說明參數的結構是怎麼樣的,取決於參數被呼叫了什麼東西,例如這裡呼叫了
    • N = box1.size(0) M = box2.size (0)說明了它是類似一個矩陣的東西,對應的box1的定義就是`torch.rand(10,4)
  • import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable
    
    #探究属性S,B是如何产生的,以及box1、box2是如何产生的、如何调用
    class yoloLoss(nn.Module):
        def __init__(self,S,B):
            self.S=S
            self.B=B
        def compute_iot(self,box1,box2):
            N = box1.size(0)  #调用方式就表示了变量是什么类型,这里是一个张量,其中每个元素是一个tensor,所以是N*4的张量
            M = box2.size(0)
            print(M,N)
    
    yoloLoss1 =yoloLoss(10, 11)
    yoloLoss1.compute_iot(torch.rand(10,4),torch.rand(11,4))

資料擴展

探究unsqueeze以及expand的使用方法,unsqueeze可以增加一個緯度,但是維度的siz只是1而已,而expand就可以將資料複製,將資料變成n

# 获得一开始的初始化数值:tensor([[a1,a2,a3]])
nn1=torch.rand(1,3)
print(nn1)
# unsqueeze是解压的意思,在第i个维度上进行扩展,将其扩展为tensor([[[a1,a2,a3]]])
nn1=nn1.unsqueeze(0)
print("*"*100)
print(nn1)
#利用expand对数据进行扩展
nn1=nn1.expand(1,3,3)
print("*"*100)
print(nn1)

【相關推薦:

Python3影片教學

以上是python類別參數定義及資料擴充方式unsqueeze/expand的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:jb51.net。如有侵權,請聯絡admin@php.cn刪除