首页 >后端开发 >Python教程 >构建常规等变 CNN 的原则

构建常规等变 CNN 的原则

王林
王林原创
2024-07-18 11:29:181152浏览

其中一个原则简单地表述为“让内核旋转”,我们将在本文中重点介绍如何将其应用到您的架构中。

等变架构使我们能够训练对某些群体行为无关的模型。

为了理解这到底意味着什么,让我们在 MNIST 数据集(0-9 的手写数字数据集)上训练这个简单的 CNN 模型。

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

表1:SimpleCNN模型的测试精度

正如预期的那样,我们在测试数据集上获得了超过 95% 的准确率,但是如果我们将图像旋转 90 度呢?如果不采取任何对策,结果会下降到仅比猜测稍好一些。这个模型对于一般应用来说是没有用的。

相比之下,让我们训练一个具有相同数量参数的类似等变架构,其中组动作恰好是 90 度旋转。

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

表2:使用与SimpleCNN模型相同数量的参数来测试EqCNN模型的准确性

准确性保持不变,我们甚至没有选择数据增强。

这些模型在 3D 数据的帮助下变得更加令人印象深刻,但我们将继续使用这个示例来探索核心思想。

如果您想亲自测试一下,您可以在 Github-Repo 下免费访问用 PyTorch 和 JAX 编写的所有代码,并且只需两个命令即可使用 Docker 或 Podman 进行训练。

玩得开心!

那么什么是等方差?

等变架构保证了某些群体行为下特征的稳定性。组是简单的结构,其中组元素可以组合、反转或不执行任何操作。

有兴趣的话可以去维基百科查一下正式的定义。

出于我们的目的,您可以想象一组作用于方形图像的 90 度旋转。我们可以将图像旋转 90、180、270 或 360 度。为了反转该动作,我们分别应用 270、180、90 或 0 度旋转。很容易看出,我们可以对表示为 的组进行组合、反转或不执行任何操作 C4C_4C4 。该图像将图像上的所有操作可视化。

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
图 1:将 MNIST 图像分别旋转 90°、180°、270°、360°

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
图 3:输入图像旋转后所有四个旋转的特征图

我对相应的地图进行了颜色编码。每个特征图移动一位。由于最终的 max 运算符对这些移位的特征图计算相同的结果,因此我们获得了相同的结果。

在我的代码中,我在最终卷积后没有旋转回去,因为我的内核将图像压缩为一维数组。如果您想扩展此示例,则需要考虑这一事实。

考虑群体行为或“内核轮换”在更复杂的架构设计中起着至关重要的作用。

这是免费的午餐吗?

不,我们付出的是计算速度、归纳偏差和更复杂的实现。

后一点在 E3NN 等库中得到了一定程度的解决,其中大部分繁重的数学都被抽象化了。不过,架构设计时需要考虑很多。

一个表面上的弱点是计算所有旋转特征层的 4 倍计算成本。然而,具有大规模并行化的现代硬件可以轻松抵消这种负载。相比之下,通过数据增强训练一个简单的 CNN 训练时间很容易就会超过 10 倍。对于 3D 旋转,情况会变得更糟,其中数据增强需要大约 500 倍的训练量才能补偿所有可能的旋转。

总体而言,如果想要稳定的特征,等方差模型设计通常是值得付出的代价。

接下来是什么?

等变模型设计近年来呈爆炸式增长,在本文中,我们仅仅触及了表面。事实上,我们甚至没有充分利用 C4C_4C4 还没有组。我们可以使用完整的 3D 内核。然而,我们的模型已经达到了 95% 以上的准确率,因此没有理由进一步使用这个示例。

除了 CNN 之外,研究人员还成功地将这些原理转化为连续组,包括 SO(2) SO(2)SO(2) (平面上所有旋转的组)和 SE(3) SE(3)SE(3) (3D 空间中所有平移和旋转的集合)。

根据我的经验,这些模型绝对令人兴奋,并且在从头开始训练时所达到的性能,可与在数倍大的数据集上训练的基础模型的性能相媲美。

如果您希望我就这个主题写更多内容,请告诉我。

进一步参考

如果您想正式介绍该主题,这里有一份优秀的论文汇编,涵盖了机器学习中等变性的完整历史。
埃恩

我实际上计划创建一个关于这个主题的深入的实践教程。您已经可以注册我的邮件列表,随着时间的推移,我将向您提供免费版本,以及反馈和问答的直接渠道。

再见:)

以上是构建常规等变 CNN 的原则的详细内容。更多信息请关注PHP中文网其他相关文章!

声明:
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn