首页  >  文章  >  科技周边  >  PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

王林
王林转载
2023-11-22 14:38:31698浏览

从年初到现在,生成式 AI 发展迅猛。但很多时候,我们又不得不面临一个难题:如何加快生成式 AI 的训练、推理等,尤其是在使用 PyTorch 的情况下。

本文 PyTorch 团队的研究者为我们提供了一个解决方案。文章重点介绍了如何使用纯原生 PyTorch 加速生成式 AI 模型,此外,文章还介绍了 PyTorch 新功能,以及如何组合这些功能的实际示例。

结果如何呢?PyTorch 团队表示,他们重写了 Meta 的「分割一切」 (SAM) 模型,从而使代码比原始实现快 8 倍,并且没有损失准确率,所有这些都是使用原生 PyTorch 进行优化的。 

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

博客地址:https://pytorch.org/blog/accelerating-generative-ai/

在阅读本文后,你将会获得以下的了解:

  • Torch.compile:PyTorch 模型编译器, PyTorch 2.0 加入了一个新的函数,叫做 torch.compile (),能够通过一行代码对已有的模型进行加速;
  • GPU 量化:通过降低运算精度来加速模型;
  • SDPA(Scaled Dot Product Attention ):内存高效的注意力实现方式;
  • 半结构化 (2:4) 稀疏性:一种针对 GPU 优化的稀疏内存格式;
  • Nested Tensor:Nested Tensor 把 {tensor, mask} 打包在一起,将非均匀大小的数据批处理到单个张量中,例如不同大小的图像;
  • Triton 自定义操作:使用 Triton Python DSL 编写 GPU 操作,并通过自定义操作符注册轻松将其集成到 PyTorch 的各种组件中。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

PyTorch 原生特性所带来的吞吐量增加以及减少的内存开销。

有关此研究的更多信息,请参考Meta提出的SAM。详细文章可在「CV不存在了?Meta发布「分割一切」AI模型,CV或迎来GPT-3时刻」中找到

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

接下来,我们将介绍SAM的优化过程,包括性能分析、瓶颈识别,以及如何将这些新功能整合进PyTorch以解决SAM所面临的问题。此外,我们还会介绍PyTorch的一些新特性,包括torch.compile、SDPA、Triton kernels、Nested Tensor以及semi-structured sparsity(半结构化稀疏)

内容的逐层深入,本文最后将介绍快速版 SAM。对于感兴趣的读者,可以前往 GitHub 下载。此外,通过使用 Perfetto UI 对这些数据进行了可视化,以展示 PyTorch 各项特性的应用价值

GitHub 地址:https://github.com/pytorch-labs/segment-anything-fast 可以找到这个项目的源代码

对分割一切模型 SAM 的重写

该研究指出,本文使用的SAM基线数据类型为float32 dtype,批处理大小为1,并使用PyTorch Profiler来查看核心追踪的结果如下:

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

本文发现 SAM 有两个地方可以优化:

第一个是对 aten::index 的长调用,这是由张量索引操作(例如 [])产生的底层调用导致的。然而实际上 GPU 花费在 aten::index 上的时间相对较低,原因在于 aten::index 在启动两个内核的过程中,两者之间发生了阻塞 cudaStreamSynchronize。这意味着 CPU 会等待 GPU 完成处理,直到启动第二个内核。因而为了优化 SAM,本文认为应该致力于消除导致空闲时间的阻塞 GPU 同步。

第二个问题是在矩阵乘法中,SAM花费了大量的GPU时间(如图所示的深绿色部分),这在Transformers模型中非常普遍。如果我们能够减少SAM模型在矩阵乘法上的GPU时间,那么我们就能够显着提高SAM的速度

接下来,我们将以SAM的吞吐量(img/s)和内存开销(GiB)来建立基准。然后就是优化过程

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

需要进行改写的句子是:Bfloat16 半精度(加上GPU 同步和批处理)

为了解决上述问题,即减少矩阵乘法所需的时间,本文转向bfloat16。 bfloat16是常用的半精度类型,通过降低每个参数和激活的精度,能够节省大量的计算时间和内存

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍


将填充类型替换为bfloat16

此外,本文发现有两个位置可以进行优化,以移除GPU 同步

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍


PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

具体来说,根据上图更容易理解,该研究发现在SAM的图像编码器中,有两个变量q_coords和k_coords充当坐标缩放器,这些变量都在CPU上进行分配和处理。然而,一旦这些变量用于在rel_pos_resized中建立索引,索引操作会自动将这些变量移动到GPU上,从而导致GPU同步的问题。为了解决这个问题,该研究指出可以使用torch.where函数重写这部分内容来解决问题,具体如上所示

核心追踪

在对这些更改进行应用之后,我们注意到单个内核调用之间存在明显的时间间隔,特别是在小批量(这里为1)的情况下更为明显。为了更深入地了解这一现象,我们开始对批大小为8的SAM推理进行性能分析

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

在分析每个内核所花费的时间时,我们注意到SAM 的大部分GPU时间都用于逐元素内核和softmax 操作

现在可以看到矩阵乘法的相对开销小了很多。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

将 GPU 同步和 bfloat16 优化结合在一起,SAM 性能提高了 3 倍。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

Torch.compile(+graph breaks 和 CUDA graphs)

在研究SAM的过程中发现了许多细小的操作。研究人员认为使用编译器来整合这些操作非常有益,因此PyTorch对torch.compile进行了以下优化

  • 将nn.LayerNorm 或nn.GELU 等操作序列融合成一个单一的GPU 内核;
  • 融合紧跟在矩阵乘法内核之后的操作,以减少GPU 内核调用的数量。

通过这些优化,该研究减少了 GPU 全局内存往返次数(roundtrips),从而加快了推理速度。我们现在可以在 SAM 的图像编码器上尝试 torch.compile。为了最大限度地提高性能,本文使用了一些高级编译技术:

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

核心追踪

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

根据结果显示,torch.compile 的表现非常出色

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

可以观察到 softmax 占了很大一部分时间,然后是各种 GEMM 变体。以下测量的是批大小为 8 及以上的变化。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

SDPA: scaled_dot_product_attention

接下来,本文又对 SDPA(scaled_dot_product_attention)进行了实验,研究的重点是注意力机制。一般来讲,原生注意力机制在时间和内存上随序列长度呈二次方扩展。PyTorch 的 SDPA 操作基于 Flash Attention、FlashAttentionV2 和 xFormer 的内存高效注意力原理构建,可以显着加快 GPU 注意力。与 torch.compile 相结合,这个操作允许在 MultiheadAttention 的变体中表达和融合一个共同的模式。经过一小部分更改后,现在模型可以使用 scaled_dot_product_attention。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

核心追踪

现在可以看到内存高效的注意力内核占用了 GPU 上大量的计算时间:

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

使用 PyTorch 的原生 scaled_dot_product_attention,可以显著增加批处理大小。下图为批大小为 32 及以上的变化。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

接下来,该研究进行了对 Triton、NestedTensor、批处理 Predict_torch、int8 量化、半结构化 (2:4) 稀疏性等操作的实验

例如本文使用自定义 positional Triton 内核,观察到批大小为 32 的测量结果。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

采用 Nested Tensor 技术,并调整批大小为 32 及以上

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

添加量化后,批大小为 32 及以上变化的测量结果。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

文章的最后是半结构化稀疏性。该研究表示,矩阵乘法仍然是需要面对的一个瓶颈。解决的办法是使用稀疏化来近似矩阵乘法。通过稀疏矩阵(即将值归零)可以使用更少的位来存储权重和激活张量。该研究将张量中哪些权重设置为零的过程称为剪枝。剪枝掉较小的权重可以潜在地减小模型大小,而不会显着损失准确率。

剪枝的方法有很多种,从完全非结构化到高度结构化都有。虽然理论上来说非结构化剪枝对精度的影响最小,但是在稀疏情况下,GPU可能会遇到显着的性能下降,尽管在进行大型密集矩阵乘法时非常高效。最近PyTorch支持的一种剪枝方法是半结构化(或2:4)稀疏性,旨在寻求平衡。这种稀疏存储方式将原始张量减少了50%,同时产生密集张量的输出。请参考下图进行说明

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

为了使用这种稀疏存储格式和相关的快速内核,接下来要做的是剪枝权重。本文在 2:4 的稀疏度下选择最小的两个权重进行剪枝,将权重从默认的 PyTorch(“strided”)布局更改为这种新的半结构化稀疏布局很容易。要实现apply_sparse (model),只需要32 行Python 代码:

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

在稀疏度为2:4的情况下,我们观察到vit_b和批大小为32时的SAM峰值性能

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

最终,对这篇文章的概括如下:本文介绍了截至目前在PyTorch上实现Segment Anything的最快方法,借助官方发布的一系列新功能,本文在纯PyTorch中重新编写了原始的SAM,并且没有损失准确度

对于感兴趣的读者,可以查看原博客以获取更多信息

以上是PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍的详细内容。更多信息请关注PHP中文网其他相关文章!

声明:
本文转载于:51cto.com。如有侵权,请联系admin@php.cn删除