Heim >Backend-Entwicklung >Python-Tutorial >Das Prinzip zum Aufbau regulärer äquivarianter CNNs

Das Prinzip zum Aufbau regulärer äquivarianter CNNs

王林
王林Original
2024-07-18 11:29:181119Durchsuche

Das eine Prinzip wird einfach als „Lassen Sie den Kernel rotieren“ ausgedrückt und wir werden uns in diesem Artikel darauf konzentrieren, wie Sie es in Ihren Architekturen anwenden können.

Äquivariante Architekturen ermöglichen es uns, Modelle zu trainieren, die gegenüber bestimmten Gruppenaktionen gleichgültig sind.

Um zu verstehen, was das genau bedeutet, trainieren wir dieses einfache CNN-Modell auf dem MNIST-Datensatz (einem Datensatz handgeschriebener Ziffern von 0-9).

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

Tabelle 1: Testgenauigkeit des SimpleCNN-Modells

Wie erwartet erreichen wir beim Testdatensatz eine Genauigkeit von über 95 %, aber was ist, wenn wir das Bild um 90 Grad drehen? Ohne Gegenmaßnahmen sinken die Ergebnisse auf ein knapp besseres Ergebnis als geschätzt. Dieses Modell wäre für allgemeine Anwendungen unbrauchbar.

Im Gegensatz dazu trainieren wir eine ähnliche äquivariante Architektur mit der gleichen Anzahl von Parametern, bei der die Gruppenaktionen genau die 90-Grad-Rotationen sind.

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

Tabelle 2: Testgenauigkeit des EqCNN-Modells mit der gleichen Anzahl an Parametern wie das SimpleCNN-Modell

Die Genauigkeit bleibt gleich und wir haben uns nicht einmal für eine Datenerweiterung entschieden.

Diese Modelle werden mit 3D-Daten noch beeindruckender, aber wir bleiben bei diesem Beispiel, um die Kernidee zu untersuchen.

Falls Sie es selbst testen möchten, können Sie unter Github-Repo kostenlos auf den gesamten in PyTorch und JAX geschriebenen Code zugreifen, und das Training mit Docker oder Podman ist mit nur zwei Befehlen möglich.

Viel Spaß!

Was ist also Äquivarianz?

Äquivariante Architekturen garantieren die Stabilität von Funktionen unter bestimmten Gruppenaktionen. Gruppen sind einfache Strukturen, in denen Gruppenelemente kombiniert, umgekehrt oder gar nichts bewirken können.

Bei Interesse können Sie die formale Definition auf Wikipedia nachschlagen.

Für unsere Zwecke können Sie sich eine Gruppe von 90-Grad-Rotationen vorstellen, die auf quadratische Bilder wirken. Wir können ein Bild um 90, 180, 270 oder 360 Grad drehen. Um die Aktion umzukehren, wenden wir eine Drehung um 270, 180, 90 bzw. 0 Grad an. Es ist leicht zu erkennen, dass wir die als bezeichnete Gruppe kombinieren, umkehren oder nichts tun können C4C_4C4 . Das Bild visualisiert alle Aktionen auf einem Bild.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
Abbildung 1: Gedrehtes MNIST-Bild um 90°, 180°, 270° bzw. 360°

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
Abbildung 3: Feature-Maps für alle vier Rotationen, nachdem das Eingabebild gedreht wurde

Ich habe die entsprechenden Karten farblich markiert. Jede Feature-Map wird um eins verschoben. Da der endgültige Max-Operator das gleiche Ergebnis für diese verschobenen Feature-Maps berechnet, erhalten wir die gleichen Ergebnisse.

In meinem Code habe ich nach der letzten Faltung nicht zurückrotiert, da meine Kernel das Bild zu einem eindimensionalen Array verdichten. Wenn Sie dieses Beispiel näher erläutern möchten, müssen Sie diese Tatsache berücksichtigen.

Die Berücksichtigung von Gruppenaktionen oder „Kernelrotationen“ spielt eine entscheidende Rolle beim Entwurf anspruchsvollerer Architekturen.

Ist es ein kostenloses Mittagessen?

Nein, wir bezahlen mit Rechengeschwindigkeit, induktivem Bias und einer komplexeren Implementierung.

Der letzte Punkt lässt sich einigermaßen mit Bibliotheken wie E3NN lösen, in denen der Großteil der schweren Mathematik abstrahiert wird. Dennoch muss man beim Architekturentwurf einiges berücksichtigen.

Eine oberflächliche Schwäche ist der vierfache Rechenaufwand für die Berechnung aller gedrehten Feature-Layer. Moderne Hardware mit Massenparallelisierung kann dieser Belastung jedoch problemlos entgegenwirken. Im Gegensatz dazu würde das Training eines einfachen CNN mit Datenerweiterung die Trainingszeit leicht um das Zehnfache überschreiten. Noch schlimmer wird es bei 3D-Rotationen, bei denen die Datenerweiterung etwa das 500-fache des Trainingsaufwands erfordern würde, um alle möglichen Rotationen zu kompensieren.

Insgesamt ist der Entwurf eines Äquivarianzmodells oft ein lohnenswerter Preis, wenn man stabile Funktionen wünscht.

Was kommt als nächstes?

Äquivariante Modelldesigns haben in den letzten Jahren explosionsartig zugenommen, und in diesem Artikel haben wir kaum an der Oberfläche gekratzt. Tatsächlich haben wir nicht einmal das volle Potenzial ausgeschöpft C4C_4C4 Gruppe noch. Wir hätten vollständige 3D-Kernel verwenden können. Allerdings erreicht unser Modell bereits eine Genauigkeit von über 95 %, sodass es kaum einen Grund gibt, mit diesem Beispiel noch weiter zu gehen.

Außer CNNs haben Forscher diese Prinzipien erfolgreich auf kontinuierliche Gruppen übertragen, darunter SO(2) SO(2)SO(2) (die Gruppe aller Drehungen in der Ebene) und SE(3) SE(3)SE(3) (die Gruppe aller Translationen und Rotationen im 3D-Raum).

Meiner Erfahrung nach sind diese Modelle absolut umwerfend und erreichen eine Leistung, wenn sie von Grund auf trainiert werden, vergleichbar mit der Leistung von Basismodellen, die auf mehrfach größeren Datensätzen trainiert werden.

Lassen Sie mich wissen, wenn Sie möchten, dass ich mehr zu diesem Thema schreibe.

Weitere Referenzen

Falls Sie eine formelle Einführung in dieses Thema wünschen, finden Sie hier eine hervorragende Zusammenstellung von Artikeln, die die gesamte Geschichte der Äquivarianz im maschinellen Lernen abdecken.
AEN

Ich habe tatsächlich vor, ein ausführliches, praktisches Tutorial zu diesem Thema zu erstellen. Sie können sich bereits jetzt für meine Mailingliste anmelden und ich werde Ihnen im Laufe der Zeit kostenlose Versionen zur Verfügung stellen, zusammen mit einem direkten Kanal für Feedback und Fragen und Antworten.

Wir sehen uns :)

Das obige ist der detaillierte Inhalt vonDas Prinzip zum Aufbau regulärer äquivarianter CNNs. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn