从年初到现在,生成式 AI 发展迅猛。但很多时候,我们又不得不面临一个难题:如何加快生成式 AI 的训练、推理等,尤其是在使用 PyTorch 的情况下。
本文 PyTorch 团队的研究者为我们提供了一个解决方案。文章重点介绍了如何使用纯原生 PyTorch 加速生成式 AI 模型,此外,文章还介绍了 PyTorch 新功能,以及如何组合这些功能的实际示例。
结果如何呢?PyTorch 团队表示,他们重写了 Meta 的「分割一切」 (SAM) 模型,从而使代码比原始实现快 8 倍,并且没有损失准确率,所有这些都是使用原生 PyTorch 进行优化的。
博客地址:https://pytorch.org/blog/accelerating-generative-ai/
在阅读本文后,你将会获得以下的了解:
PyTorch 原生特性所带来的吞吐量增加以及减少的内存开销。
有关此研究的更多信息,请参考Meta提出的SAM。详细文章可在「CV不存在了?Meta发布「分割一切」AI模型,CV或迎来GPT-3时刻」中找到
接下来,我们将介绍SAM的优化过程,包括性能分析、瓶颈识别,以及如何将这些新功能整合进PyTorch以解决SAM所面临的问题。此外,我们还会介绍PyTorch的一些新特性,包括torch.compile、SDPA、Triton kernels、Nested Tensor以及semi-structured sparsity(半结构化稀疏)
内容的逐层深入,本文最后将介绍快速版 SAM。对于感兴趣的读者,可以前往 GitHub 下载。此外,通过使用 Perfetto UI 对这些数据进行了可视化,以展示 PyTorch 各项特性的应用价值
GitHub 地址:https://github.com/pytorch-labs/segment-anything-fast 可以找到这个项目的源代码
该研究指出,本文使用的SAM基线数据类型为float32 dtype,批处理大小为1,并使用PyTorch Profiler来查看核心追踪的结果如下:
本文发现 SAM 有两个地方可以优化:
第一个是对 aten::index 的长调用,这是由张量索引操作(例如 [])产生的底层调用导致的。然而实际上 GPU 花费在 aten::index 上的时间相对较低,原因在于 aten::index 在启动两个内核的过程中,两者之间发生了阻塞 cudaStreamSynchronize。这意味着 CPU 会等待 GPU 完成处理,直到启动第二个内核。因而为了优化 SAM,本文认为应该致力于消除导致空闲时间的阻塞 GPU 同步。
第二个问题是在矩阵乘法中,SAM花费了大量的GPU时间(如图所示的深绿色部分),这在Transformers模型中非常普遍。如果我们能够减少SAM模型在矩阵乘法上的GPU时间,那么我们就能够显着提高SAM的速度
接下来,我们将以SAM的吞吐量(img/s)和内存开销(GiB)来建立基准。然后就是优化过程
需要进行改写的句子是:Bfloat16 半精度(加上GPU 同步和批处理)
为了解决上述问题,即减少矩阵乘法所需的时间,本文转向bfloat16。 bfloat16是常用的半精度类型,通过降低每个参数和激活的精度,能够节省大量的计算时间和内存
将填充类型替换为bfloat16
此外,本文发现有两个位置可以进行优化,以移除GPU 同步
具体来说,根据上图更容易理解,该研究发现在SAM的图像编码器中,有两个变量q_coords和k_coords充当坐标缩放器,这些变量都在CPU上进行分配和处理。然而,一旦这些变量用于在rel_pos_resized中建立索引,索引操作会自动将这些变量移动到GPU上,从而导致GPU同步的问题。为了解决这个问题,该研究指出可以使用torch.where函数重写这部分内容来解决问题,具体如上所示
核心追踪
在对这些更改进行应用之后,我们注意到单个内核调用之间存在明显的时间间隔,特别是在小批量(这里为1)的情况下更为明显。为了更深入地了解这一现象,我们开始对批大小为8的SAM推理进行性能分析
在分析每个内核所花费的时间时,我们注意到SAM 的大部分GPU时间都用于逐元素内核和softmax 操作
现在可以看到矩阵乘法的相对开销小了很多。
将 GPU 同步和 bfloat16 优化结合在一起,SAM 性能提高了 3 倍。
在研究SAM的过程中发现了许多细小的操作。研究人员认为使用编译器来整合这些操作非常有益,因此PyTorch对torch.compile进行了以下优化
通过这些优化,该研究减少了 GPU 全局内存往返次数(roundtrips),从而加快了推理速度。我们现在可以在 SAM 的图像编码器上尝试 torch.compile。为了最大限度地提高性能,本文使用了一些高级编译技术:
核心追踪
根据结果显示,torch.compile 的表现非常出色
可以观察到 softmax 占了很大一部分时间,然后是各种 GEMM 变体。以下测量的是批大小为 8 及以上的变化。
接下来,本文又对 SDPA(scaled_dot_product_attention)进行了实验,研究的重点是注意力机制。一般来讲,原生注意力机制在时间和内存上随序列长度呈二次方扩展。PyTorch 的 SDPA 操作基于 Flash Attention、FlashAttentionV2 和 xFormer 的内存高效注意力原理构建,可以显着加快 GPU 注意力。与 torch.compile 相结合,这个操作允许在 MultiheadAttention 的变体中表达和融合一个共同的模式。经过一小部分更改后,现在模型可以使用 scaled_dot_product_attention。
核心追踪
现在可以看到内存高效的注意力内核占用了 GPU 上大量的计算时间:
使用 PyTorch 的原生 scaled_dot_product_attention,可以显著增加批处理大小。下图为批大小为 32 及以上的变化。
接下来,该研究进行了对 Triton、NestedTensor、批处理 Predict_torch、int8 量化、半结构化 (2:4) 稀疏性等操作的实验
例如本文使用自定义 positional Triton 内核,观察到批大小为 32 的测量结果。
采用 Nested Tensor 技术,并调整批大小为 32 及以上
添加量化后,批大小为 32 及以上变化的测量结果。
文章的最后是半结构化稀疏性。该研究表示,矩阵乘法仍然是需要面对的一个瓶颈。解决的办法是使用稀疏化来近似矩阵乘法。通过稀疏矩阵(即将值归零)可以使用更少的位来存储权重和激活张量。该研究将张量中哪些权重设置为零的过程称为剪枝。剪枝掉较小的权重可以潜在地减小模型大小,而不会显着损失准确率。
剪枝的方法有很多种,从完全非结构化到高度结构化都有。虽然理论上来说非结构化剪枝对精度的影响最小,但是在稀疏情况下,GPU可能会遇到显着的性能下降,尽管在进行大型密集矩阵乘法时非常高效。最近PyTorch支持的一种剪枝方法是半结构化(或2:4)稀疏性,旨在寻求平衡。这种稀疏存储方式将原始张量减少了50%,同时产生密集张量的输出。请参考下图进行说明
为了使用这种稀疏存储格式和相关的快速内核,接下来要做的是剪枝权重。本文在 2:4 的稀疏度下选择最小的两个权重进行剪枝,将权重从默认的 PyTorch(“strided”)布局更改为这种新的半结构化稀疏布局很容易。要实现apply_sparse (model),只需要32 行Python 代码:
在稀疏度为2:4的情况下,我们观察到vit_b和批大小为32时的SAM峰值性能
最终,对这篇文章的概括如下:本文介绍了截至目前在PyTorch上实现Segment Anything的最快方法,借助官方发布的一系列新功能,本文在纯PyTorch中重新编写了原始的SAM,并且没有损失准确度
对于感兴趣的读者,可以查看原博客以获取更多信息
以上是PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍的详细内容。更多信息请关注PHP中文网其他相关文章!