Maison >Périphériques technologiques >IA >Comprenez en profondeur les fonctions principales de Pytorch : dérivation automatique !
Salut, je m'appelle Xiaozhuang !
À propos de l'opération de dérivation automatique dans pytorch, présentez le concept de dérivation automatique dans pytorch
La dérivation automatique est une fonction importante du cadre d'apprentissage en profondeur, utilisée pour calculer les gradients et réaliser la mise à jour et l'optimisation des paramètres.
PyTorch est un framework d'apprentissage en profondeur couramment utilisé qui utilise des graphiques de calcul dynamiques et des mécanismes de dérivation automatique pour simplifier le processus de calcul du gradient.
La dérivation automatique est une fonction importante du cadre d'apprentissage automatique. Elle peut calculer automatiquement la dérivée (gradient) d'une fonction, simplifiant ainsi le processus de formation de modèles d'apprentissage profond. En apprentissage profond, les modèles contiennent souvent un grand nombre de paramètres, et le calcul manuel des gradients peut devenir complexe et sujet aux erreurs. PyTorch fournit une fonction de dérivation automatique, permettant aux utilisateurs de calculer facilement les gradients et d'effectuer une rétropropagation pour mettre à jour les paramètres du modèle. L’introduction de cette fonctionnalité améliore considérablement l’efficacité et la facilité d’utilisation du deep learning.
La fonction de dérivation automatique de PyTorch est basée sur des graphiques de calcul dynamique. Un graphe de calcul est une structure graphique utilisée pour représenter le processus de calcul de fonction, dans laquelle les nœuds représentent les opérations et les arêtes représentent le flux de données. Contrairement aux graphiques de calcul statiques, la structure des graphiques de calcul dynamiques peut être générée dynamiquement sur la base du processus d'exécution réel, plutôt que d'être définie à l'avance. Cette conception rend PyTorch flexible et évolutif pour s'adapter aux différents besoins informatiques. Grâce à des graphiques de calcul dynamiques, PyTorch peut enregistrer l'historique des opérations, effectuer une rétropropagation et calculer les gradients selon les besoins. Cela fait de PyTorch l’un des frameworks les plus utilisés dans le domaine du deep learning.
Dans PyTorch, chaque opération de l'utilisateur est enregistrée pour construire le graphe de calcul. De cette façon, lorsque le gradient doit être calculé, PyTorch peut effectuer une rétropropagation en fonction du graphique de calcul et calculer automatiquement le gradient de chaque paramètre par rapport à la fonction de perte. Ce mécanisme de dérivation automatique basé sur des graphiques de calcul dynamiques rend PyTorch flexible et évolutif, le rendant adapté à diverses structures de réseaux neuronaux complexes.
Dans PyTorch, le tensor est la structure de données de base pour la dérivation automatique. Les tenseurs sont similaires aux tableaux multidimensionnels dans NumPy, mais disposent de fonctionnalités supplémentaires telles que la dérivation automatique. Grâce à la classe torch.Tensor, les utilisateurs peuvent créer des tenseurs et effectuer diverses opérations sur eux.
import torch# 创建张量x = torch.tensor([2.0], requires_grad=True)
Dans l'exemple ci-dessus, require_grad=True signifie que nous voulons différencier automatiquement ce tenseur.
Chaque opération effectuée créera un nœud dans le graphe informatique. PyTorch propose diverses opérations tensorielles, telles que des fonctions d'addition, de multiplication, d'activation, etc., qui laisseront des traces dans le graphe de calcul.
# 张量操作y = x ** 2z = 2 * y + 3
Dans l'exemple ci-dessus, les processus de calcul de y et z sont enregistrés dans le graphique de calcul.
Une fois le graphique de calcul construit, la rétropropagation peut être effectuée en appelant la méthode .backward() pour calculer automatiquement le gradient.
# 反向传播z.backward()
À ce stade, le dégradé de x peut être obtenu en accédant à x.grad.
# 获取梯度print(x.grad)
Parfois, nous souhaitons désactiver le suivi des dégradés pour certaines opérations, nous pouvons utiliser le gestionnaire de contexte torch.no_grad().
with torch.no_grad():# 在这个区域内的操作不会被记录在计算图中w = x + 1
Dans la boucle d'entraînement, il est généralement nécessaire d'effacer les dégradés avant chaque rétropropagation pour éviter l'accumulation de dégradés.
# 清零梯度x.grad.zero_()
Pour démontrer plus spécifiquement le processus de dérivation automatique, considérons un simple problème de régression linéaire. Nous définissons un modèle linéaire et une fonction de perte d'erreur quadratique moyenne et utilisons la dérivation automatique pour optimiser les paramètres du modèle.
import torch# 数据准备X = torch.tensor([[1.0], [2.0], [3.0]])y = torch.tensor([[2.0], [4.0], [6.0]])# 模型参数w = torch.tensor([[0.0]], requires_grad=True)b = torch.tensor([[0.0]], requires_grad=True)# 模型和损失函数def linear_model(X, w, b):return X @ w + bdef mean_squared_error(y_pred, y_true):return ((y_pred - y_true) ** 2).mean()# 训练循环learning_rate = 0.01epochs = 100for epoch in range(epochs):# 前向传播y_pred = linear_model(X, w, b)loss = mean_squared_error(y_pred, y)# 反向传播loss.backward()# 更新参数with torch.no_grad():w -= learning_rate * w.gradb -= learning_rate * b.grad# 清零梯度w.grad.zero_()b.grad.zero_()# 打印最终参数print("训练后的参数:")print("权重 w:", w)print("偏置 b:", b)
Dans cet exemple, nous définissons un modèle linéaire simple et une fonction de perte d'erreur quadratique moyenne. Grâce à plusieurs boucles d'entraînement itératives, les paramètres w et b du modèle seront optimisés pour minimiser la fonction de perte.
Enfin
Grâce à des graphiques de calcul dynamique et à des calculs de gradient, les utilisateurs peuvent facilement définir des structures de réseaux neuronaux complexes et mettre en œuvre des algorithmes d'optimisation tels que la descente de gradient par dérivation automatique.
Cela permet aux chercheurs et ingénieurs en apprentissage profond de se concentrer davantage sur la conception de modèles et les expériences sans avoir à se soucier des détails des calculs de gradient.
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!