首頁 >後端開發 >Python教學 >建構常規等變 CNN 的原則

建構常規等變 CNN 的原則

王林
王林原創
2024-07-18 11:29:181119瀏覽

其中一個原則簡單地表述為“讓核心旋轉”,我們將在本文中重點介紹如何將其應用到您的架構中。

等變架構使我們能夠訓練對某些群體行為無關的模型。

為了理解這到底意味著什麼,讓我們在 MNIST 資料集(0-9 的手寫數字資料集)上訓練這個簡單的 CNN 模型。

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

表1:SimpleCNN模型的測試精確度

正如預期的那樣,我們在測試資料集上獲得了超過 95% 的準確率,但是如果我們將影像旋轉 90 度呢?如果不採取任何對策,結果只會比猜測的好一點點。這個模型對於一般應用來說是沒有用的。

相較之下,讓我們訓練一個具有相同數量參數的類似等變架構,其中組動作恰好是 90 度旋轉。

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

表2:使用與SimpleCNN模型相同數量的參數來測試EqCNN模型的準確度

準確性保持不變,我們甚至沒有選擇資料增強。

這些模型在 3D 數據的幫助下變得更加令人印象深刻,但我們將繼續使用這個範例來探索核心思想。

如果您想親自測試一下,您可以在 Github-Repo 下免費存取用 PyTorch 和 JAX 編寫的所有程式碼,並且只需兩個命令即可使用 Docker 或 Podman 進行訓練。

玩得開心!

那什麼是等方差呢?

等變架構保證了某些群體行為下特徵的穩定性。群組是簡單的結構,其中群組元素可以組合、反轉或不執行任何操作。

有興趣的話可以去維基百科查一下正式的定義。

出於我們的目的,您可以想像一組作用於方形影像的 90 度旋轉。我們可以將影像旋轉 90、180、270 或 360 度。為了反轉該動作,我們分別套用 270、180、90 或 0 度旋轉。很容易看出,我們可以對錶示為 的群組進行組合、反轉或不執行任何操作 C4C_4C4

。該影像將影像上的所有操作視覺化。

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
圖 1:將 MNIST 影像分別旋轉 90°、180°、270°、360°

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
圖 3:輸入影像旋轉後所有四個旋轉的特徵圖

我對對應的地圖進行了顏色編碼。每個特徵圖移動一位。由於最終的 max 運算子對這些移位的特徵圖計算相同的結果,因此我們得到了相同的結果。

在我的程式碼中,我在最終卷積後沒有旋轉回去,因為我的核心將圖像壓縮為一維數組。如果您想擴展此範例,則需要考慮這一事實。

考慮群體行為或「核心輪換」在更複雜的架構設計中起著至關重要的作用。

這是免費的午餐嗎?

不,我們付出的是計算速度、歸納偏差和更複雜的實現。

後一點在 E3NN 等庫中得到了一定程度的解決,其中大部分繁重的數學都被抽象化了。不過,架構設計時需要考慮很多。

一個表面上的弱點是計算所有旋轉特徵層的 4 倍計算成本。然而,具有大規模並行化的現代硬體可以輕鬆抵消這種負載。相較之下,透過資料增強訓練一個簡單的 CNN 訓練時間很容易就會超過 10 倍。對於 3D 旋轉,情況會變得更糟,其中資料增強需要大約 500 倍的訓練量才能補償所有可能的旋轉。

整體而言,如果想要穩定的特徵,等方差模型設計通常是值得付出的代價。

接下來是什麼?

等變模型設計近年來呈現爆炸性成長,在本文中,我們僅僅觸及了表面。事實上,我們甚至沒有充分利用 C4C_4C4

還沒有組。我們可以使用完整的 3D 核心。然而,我們的模型已經達到了 95% 以上的準確率,因此沒有理由進一步使用這個範例。

除了 CNN 之外,研究人員還成功地將這些原理轉化為連續組,包括 SO(2(2SO(2)SO(2) (平面上所有旋轉的組)和 SE(3(3SE(3)

SE(3)


(3D 空間中所有平移和旋轉的集合)。

根據我的經驗,這些模型絕對令人興奮,並且在從頭開始訓練時所達到的性能,可與在數倍大的數據集上訓練的基礎模型的性能相媲美。

如果您希望我就這個主題寫更多內容,請告訴我。

進一步參考 如果您想正式介紹該主題,這裡有一份優秀的論文彙編,涵蓋了機器學習中等變性的完整歷史。 埃恩 我實際上計劃創建一個關於這個主題的深入的實踐教程。您已經可以註冊我的郵件列表,隨著時間的推移,我將向您提供免費版本,以及反饋和問答的直接管道。 再見:)

以上是建構常規等變 CNN 的原則的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn