Rumah  >  Artikel  >  pembangunan bahagian belakang  >  Bagaimana untuk melaksanakan perambatan balik softmax dalam Python.

Bagaimana untuk melaksanakan perambatan balik softmax dalam Python.

WBOY
WBOYke hadapan
2023-05-09 08:05:531174semak imbas

Terbitan rambatan balik

Seperti yang anda lihat, softmax mengira input berbilang neuron Apabila merambat balik terbitan, anda perlu mempertimbangkan untuk mendapatkan parameter neuron yang berbeza.

Pertimbangkan dua situasi:

  • Apabila parameter untuk terbitan terletak dalam pengangka

  • Apabila parameter untuk terbitan ialah terletak di Apabila penyebutnya ialah

Bagaimana untuk melaksanakan perambatan balik softmax dalam Python.

Apabila parameter untuk terbitan berada dalam pengangka:

Bagaimana untuk melaksanakan perambatan balik softmax dalam Python.

Apabila derivasi dilakukan Apabila parameter terletak dalam penyebut (ez2 atau ez3 kedua-dua ini adalah simetri, hasil terbitan adalah sama):

Bagaimana untuk melaksanakan perambatan balik softmax dalam Python.

Bagaimana untuk melaksanakan perambatan balik softmax dalam Python.

Kod

import torch
import math

def my_softmax(features):
    _sum = 0
    for i in features:
        _sum += math.e ** i
    return torch.Tensor([ math.e ** i / _sum for i in features ])

def my_softmax_grad(outputs):    
    n = len(outputs)
    grad = []
    for i in range(n):
        temp = []
        for j in range(n):
            if i == j:
                temp.append(outputs[i] * (1- outputs[i]))
            else:
                temp.append(-outputs[j] * outputs[i])
        grad.append(torch.Tensor(temp))
    return grad

if __name__ == '__main__':

    features = torch.randn(10)
    features.requires_grad_()

    torch_softmax = torch.nn.functional.softmax
    p1 = torch_softmax(features,dim=0)
    p2 = my_softmax(features)
    print(torch.allclose(p1,p2))
    
    n = len(p1)
    p2_grad = my_softmax_grad(p2)
    for i in range(n):
        p1_grad = torch.autograd.grad(p1[i],features, retain_graph=True)
        print(torch.allclose(p1_grad[0], p2_grad[i]))

Atas ialah kandungan terperinci Bagaimana untuk melaksanakan perambatan balik softmax dalam Python.. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan:
Artikel ini dikembalikan pada:yisu.com. Jika ada pelanggaran, sila hubungi admin@php.cn Padam