Maison >développement back-end >Tutoriel Python >Le principe pour créer des CNN équivariants réguliers

Le principe pour créer des CNN équivariants réguliers

王林
王林original
2024-07-18 11:29:181126parcourir

Le principe est simplement énoncé comme « Laissez le noyau tourner » et nous nous concentrerons dans cet article sur la façon dont vous pouvez l'appliquer dans vos architectures.

Les architectures équivariantes permettent de former des modèles indifférents à certaines actions de groupe.

Pour comprendre ce que cela signifie exactement, entraînons ce modèle CNN simple sur l'ensemble de données MNIST (un ensemble de données de chiffres manuscrits de 0 à 9).

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

Tableau 1 : Précision des tests du modèle SimpleCNN

Comme prévu, nous obtenons une précision de plus de 95 % sur l'ensemble de données de test, mais que se passe-t-il si nous faisons pivoter l'image de 90 degrés ? Sans aucune contre-mesure appliquée, les résultats tombent à peine meilleurs que ce que l’on aurait pu deviner. Ce modèle serait inutile pour les applications générales.

En revanche, formons une architecture équivariante similaire avec le même nombre de paramètres, où les actions de groupe sont exactement les rotations de 90 degrés.

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

Tableau 2 : Test de précision du modèle EqCNN avec le même nombre de paramètres que le modèle SimpleCNN

La précision reste la même, et nous n'avons même pas opté pour l'augmentation des données.

Ces modèles deviennent encore plus impressionnants avec des données 3D, mais nous nous en tiendrons à cet exemple pour explorer l'idée de base.

Si vous souhaitez le tester par vous-même, vous pouvez accéder gratuitement à tout le code écrit en PyTorch et JAX sous Github-Repo, et la formation avec Docker ou Podman est possible avec seulement deux commandes.

Amusez-vous !

Alors, qu’est-ce que l’équivariance ?

Les architectures équivariantes garantissent la stabilité des fonctionnalités sous certaines actions de groupe. Les groupes sont des structures simples où les éléments du groupe peuvent être combinés, inversés ou ne rien faire.

Vous pouvez rechercher la définition formelle sur Wikipédia si vous êtes intéressé.

Pour nos besoins, vous pouvez penser à un groupe de rotations de 90 degrés agissant sur des images carrées. Nous pouvons faire pivoter une image de 90, 180, 270 ou 360 degrés. Pour inverser l'action, nous appliquons respectivement une rotation de 270, 180, 90 ou 0 degrés. Il est simple de voir que nous pouvons combiner, inverser ou ne rien faire avec le groupe noté C4C_4C4 . L'image visualise toutes les actions sur une image.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
Figure 1 : Image MNIST pivotée de 90°, 180°, 270°, 360°, respectivement

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
Figure 3 : Cartes de caractéristiques pour les quatre rotations après la rotation de l'image d'entrée

J'ai codé par couleur les cartes correspondantes. Chaque carte de fonctionnalités est décalée de un. Comme l'opérateur max final calcule le même résultat pour ces cartes de caractéristiques décalées, nous obtenons les mêmes résultats.

Dans mon code, je n'ai pas effectué de rotation après la convolution finale, car mes noyaux condensent l'image en un tableau unidimensionnel. Si vous souhaitez développer cet exemple, vous devrez tenir compte de ce fait.

La comptabilisation des actions de groupe ou des « rotations du noyau » joue un rôle essentiel dans la conception d'architectures plus sophistiquées.

Est-ce un déjeuner gratuit ?

Non, nous payons en vitesse de calcul, en biais inductif et en une mise en œuvre plus complexe.

Ce dernier point est quelque peu résolu avec des bibliothèques telles que E3NN, où la plupart des mathématiques lourdes sont abstraites. Néanmoins, il faut tenir compte de beaucoup de choses lors de la conception de l'architecture.

Une faiblesse superficielle est le coût de calcul 4x pour le calcul de toutes les couches d'entités pivotées. Cependant, le matériel moderne doté de parallélisation de masse peut facilement contrecarrer cette charge. En revanche, la formation d’un simple CNN avec augmentation des données dépasserait facilement 10 fois le temps de formation. Cela est encore pire pour les rotations 3D où l'augmentation des données nécessiterait environ 500 fois la quantité d'entraînement pour compenser toutes les rotations possibles.

Dans l'ensemble, la conception d'un modèle d'équivariance est le plus souvent un prix qui vaut la peine d'être payé si l'on veut des fonctionnalités stables.

Quelle est la prochaine étape ?

Les conceptions de modèles équivariants ont explosé ces dernières années, et dans cet article, nous avons à peine effleuré la surface. En fait, nous n'avons même pas exploité pleinement C4C_4C4 groupe encore. Nous aurions pu utiliser des noyaux entièrement 3D. Cependant, notre modèle atteint déjà une précision de plus de 95 %, il n'y a donc aucune raison d'aller plus loin avec cet exemple.

Outre les CNN, les chercheurs ont réussi à traduire ces principes en groupes continus, notamment SO(2) SO(2)SO(2) (le groupe de toutes les rotations dans le plan) et SE(3) SE(3)SE(3) (le groupe de toutes les traductions et rotations dans l'espace 3D).

D'après mon expérience, ces modèles sont absolument époustouflants et atteignent des performances, lorsqu'ils sont entraînés à partir de zéro, comparables aux performances des modèles de base entraînés sur des ensembles de données plusieurs fois plus grands.

Faites-moi savoir si vous souhaitez que j'écrive davantage sur ce sujet.

Autres références

Au cas où vous souhaiteriez une introduction formelle à ce sujet, voici une excellente compilation d'articles, couvrant l'histoire complète de l'équivariance dans l'apprentissage automatique.
AEN

Je prévois en fait de créer un tutoriel pratique et approfondi sur ce sujet. Vous pouvez déjà vous inscrire à ma liste de diffusion et je vous fournirai des versions gratuites au fil du temps, ainsi qu'un canal direct pour les commentaires et les questions-réponses.

À bientôt :)

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn