Maison >développement back-end >Tutoriel Python >Le principe pour créer des CNN équivariants réguliers
Le principe est simplement énoncé comme « Laissez le noyau tourner » et nous nous concentrerons dans cet article sur la façon dont vous pouvez l'appliquer dans vos architectures.
Les architectures équivariantes permettent de former des modèles indifférents à certaines actions de groupe.
Pour comprendre ce que cela signifie exactement, entraînons ce modèle CNN simple sur l'ensemble de données MNIST (un ensemble de données de chiffres manuscrits de 0 à 9).
class SimpleCNN(nn.Module): def __init__(self): super(SimpleCNN, self).__init__() self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1) self.max_1 = nn.MaxPool2d(kernel_size=2) self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1) self.max_2 = nn.MaxPool2d(kernel_size=2) self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7) self.dense = nn.Linear(in_features=16, out_features=10) def forward(self, x: torch.Tensor): x = nn.functional.silu(self.cl1(x)) x = self.max_1(x) x = nn.functional.silu(self.cl2(x)) x = self.max_2(x) x = nn.functional.silu(self.cl3(x)) x = x.view(len(x), -1) logits = self.dense(x) return logits
Accuracy on test | Accuracy on 90-degree rotated test |
---|---|
97.3% | 15.1% |
Tableau 1 : Précision des tests du modèle SimpleCNN
Comme prévu, nous obtenons une précision de plus de 95 % sur l'ensemble de données de test, mais que se passe-t-il si nous faisons pivoter l'image de 90 degrés ? Sans aucune contre-mesure appliquée, les résultats tombent à peine meilleurs que ce que l’on aurait pu deviner. Ce modèle serait inutile pour les applications générales.
En revanche, formons une architecture équivariante similaire avec le même nombre de paramètres, où les actions de groupe sont exactement les rotations de 90 degrés.
Accuracy on test | Accuracy on 90-degree rotated test |
---|---|
96.5% | 96.5% |
Tableau 2 : Test de précision du modèle EqCNN avec le même nombre de paramètres que le modèle SimpleCNN
La précision reste la même, et nous n'avons même pas opté pour l'augmentation des données.
Ces modèles deviennent encore plus impressionnants avec des données 3D, mais nous nous en tiendrons à cet exemple pour explorer l'idée de base.
Si vous souhaitez le tester par vous-même, vous pouvez accéder gratuitement à tout le code écrit en PyTorch et JAX sous Github-Repo, et la formation avec Docker ou Podman est possible avec seulement deux commandes.
Amusez-vous !
Les architectures équivariantes garantissent la stabilité des fonctionnalités sous certaines actions de groupe. Les groupes sont des structures simples où les éléments du groupe peuvent être combinés, inversés ou ne rien faire.
Vous pouvez rechercher la définition formelle sur Wikipédia si vous êtes intéressé.
Pour nos besoins, vous pouvez penser à un groupe de rotations de 90 degrés agissant sur des images carrées. Nous pouvons faire pivoter une image de 90, 180, 270 ou 360 degrés. Pour inverser l'action, nous appliquons respectivement une rotation de 270, 180, 90 ou 0 degrés. Il est simple de voir que nous pouvons combiner, inverser ou ne rien faire avec le groupe noté C4 . L'image visualise toutes les actions sur une image.
Figure 1 : Image MNIST pivotée de 90°, 180°, 270°, 360°, respectivement
Now, given an input image
x
, our CNN model classifier
fθ
, and an arbitrary 90-degree rotation
g
, the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x)
Generally speaking, we want our image-based model to have the same outputs when rotated.
As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.
The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.
So, in code, our CNN kernel
x = nn.functional.silu(self.cl1(x))
now acts on all four rotated images:
x_0 = x x_90 = torch.rot90(x, k=1, dims=(2, 3)) x_180 = torch.rot90(x, k=2, dims=(2, 3)) x_270 = torch.rot90(x, k=3, dims=(2, 3)) x_0 = nn.functional.silu(self.cl1(x_0)) x_90 = nn.functional.silu(self.cl1(x_90)) x_180 = nn.functional.silu(self.cl1(x_180)) x_270 = nn.functional.silu(self.cl1(x_270))
Or more compactly written as a 3D convolution:
self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3)) ... x = torch.stack([x_0, x_90, x_180, x_270], dim=-3) x = nn.functional.silu(self.cl1(x))
The resulting equivariant model has just a few lines more compared to the version above:
class EqCNN(nn.Module): def __init__(self): super(EqCNN, self).__init__() self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3)) self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2)) self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3)) self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2)) self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5)) self.dense = nn.Linear(in_features=16, out_features=10) def forward(self, x: torch.Tensor): x_0 = x x_90 = torch.rot90(x, k=1, dims=(2, 3)) x_180 = torch.rot90(x, k=2, dims=(2, 3)) x_270 = torch.rot90(x, k=3, dims=(2, 3)) x = torch.stack([x_0, x_90, x_180, x_270], dim=-3) x = nn.functional.silu(self.cl1(x)) x = self.max_1(x) x = nn.functional.silu(self.cl2(x)) x = self.max_2(x) x = nn.functional.silu(self.cl3(x)) x = x.squeeze() x = torch.max(x, dim=-1).values logits = self.dense(x) return logits
But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.
This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.
To understand what is happening, let us plot the feature maps after the first convolution stage.
Figure 2: Feature maps for all four rotations
And now the same features after we rotate the input by 90 degrees.
Figure 3 : Cartes de caractéristiques pour les quatre rotations après la rotation de l'image d'entrée
J'ai codé par couleur les cartes correspondantes. Chaque carte de fonctionnalités est décalée de un. Comme l'opérateur max final calcule le même résultat pour ces cartes de caractéristiques décalées, nous obtenons les mêmes résultats.
Dans mon code, je n'ai pas effectué de rotation après la convolution finale, car mes noyaux condensent l'image en un tableau unidimensionnel. Si vous souhaitez développer cet exemple, vous devrez tenir compte de ce fait.
La comptabilisation des actions de groupe ou des « rotations du noyau » joue un rôle essentiel dans la conception d'architectures plus sophistiquées.
Non, nous payons en vitesse de calcul, en biais inductif et en une mise en œuvre plus complexe.
Ce dernier point est quelque peu résolu avec des bibliothèques telles que E3NN, où la plupart des mathématiques lourdes sont abstraites. Néanmoins, il faut tenir compte de beaucoup de choses lors de la conception de l'architecture.
Une faiblesse superficielle est le coût de calcul 4x pour le calcul de toutes les couches d'entités pivotées. Cependant, le matériel moderne doté de parallélisation de masse peut facilement contrecarrer cette charge. En revanche, la formation d’un simple CNN avec augmentation des données dépasserait facilement 10 fois le temps de formation. Cela est encore pire pour les rotations 3D où l'augmentation des données nécessiterait environ 500 fois la quantité d'entraînement pour compenser toutes les rotations possibles.
Dans l'ensemble, la conception d'un modèle d'équivariance est le plus souvent un prix qui vaut la peine d'être payé si l'on veut des fonctionnalités stables.
Les conceptions de modèles équivariants ont explosé ces dernières années, et dans cet article, nous avons à peine effleuré la surface. En fait, nous n'avons même pas exploité pleinement C4 groupe encore. Nous aurions pu utiliser des noyaux entièrement 3D. Cependant, notre modèle atteint déjà une précision de plus de 95 %, il n'y a donc aucune raison d'aller plus loin avec cet exemple.
Outre les CNN, les chercheurs ont réussi à traduire ces principes en groupes continus, notamment SO(2) (le groupe de toutes les rotations dans le plan) et SE(3) (le groupe de toutes les traductions et rotations dans l'espace 3D).
D'après mon expérience, ces modèles sont absolument époustouflants et atteignent des performances, lorsqu'ils sont entraînés à partir de zéro, comparables aux performances des modèles de base entraînés sur des ensembles de données plusieurs fois plus grands.
Faites-moi savoir si vous souhaitez que j'écrive davantage sur ce sujet.
Au cas où vous souhaiteriez une introduction formelle à ce sujet, voici une excellente compilation d'articles, couvrant l'histoire complète de l'équivariance dans l'apprentissage automatique.
AEN
Je prévois en fait de créer un tutoriel pratique et approfondi sur ce sujet. Vous pouvez déjà vous inscrire à ma liste de diffusion et je vous fournirai des versions gratuites au fil du temps, ainsi qu'un canal direct pour les commentaires et les questions-réponses.
À bientôt :)
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!