搜索
首页科技周边人工智能改变几行代码,PyTorch炼丹速度狂飙、模型优化时间大减

如何提升 PyTorch「炼丹」速度?

最近,知名机器学习与 AI 研究者 Sebastian Raschka 向我们展示了他的绝招。据他表示,他的方法在不影响模型准确率的情况下,仅仅通过改变几行代码,将 BERT 优化时间从 22.63 分钟缩减到 3.15 分钟,训练速度足足提升了 7 倍。

图片

作者更是表示,如果你有 8 个 GPU 可用,整个训练过程只需要 2 分钟,实现 11.5 倍的性能加速。

图片

下面我们来看看他到底是如何实现的。

让 PyTorch 模型训练更快

首先是模型,作者采用 DistilBERT 模型进行研究,它是 BERT 的精简版,与 BERT 相比规模缩小了 40%,但性能几乎没有损失。其次是数据集,训练数据集为大型电影评论数据集 IMDB Large Movie Review,该数据集总共包含 50000 条电影评论。作者将使用下图中的 c 方法来预测数据集中的影评情绪。

图片

基本任务交代清楚后,下面就是 PyTorch 的训练过程。为了让大家更好地理解这项任务,作者还贴心地介绍了一下热身练习,即如何在 IMDB 电影评论数据集上训练 DistilBERT 模型。如果你想自己运行代码,可以使用相关的 Python 库设置一个虚拟环境,如下所示:

相关软件的版本如下:

图片

现在省略掉枯燥的数据加载介绍,只需要了解本文将数据集划分为 35000 个训练示例、5000 个验证示例和 10000 个测试示例。需要的代码如下:

图片

代码部分截图

完整代码地址:

​https://github.com/rasbt/faster-pytorch-blog/blob/main/1_pytorch-distilbert.py​

然后在 A100 GPU 上运行代码,得到如下结果:

图片

部分结果截图

正如上述代码所示,模型从第 2 轮到第 3 轮开始有一点过拟合,验证准确率从 92.89% 下降到了 92.09%。在模型运行了 22.63 分钟后进行微调,最终的测试准确率为 91.43%。

使用 Trainer 类

接下来是改进上述代码,改进部分主要是把 PyTorch 模型包装在 LightningModule 中,这样就可以使用来自 Lightning 的 Trainer 类。部分代码截图如下:

图片

完整代码地址:https://github.com/rasbt/faster-pytorch-blog/blob/main/2_pytorch-with-trainer.py

上述代码建立了一个 LightningModule,它定义了如何执行训练、验证和测试。相比于前面给出的代码,主要变化是在第 5 部分(即 ### 5 Finetuning),即微调模型。与以前不同的是,微调部分在 LightningModel 类中包装了 PyTorch 模型,并使用 Trainer 类来拟合模型。

图片

之前的代码显示验证准确率从第 2 轮到第 3 轮有所下降,但改进后的代码使用了 ModelCheckpoint 以加载最佳模型。在同一台机器上,这个模型在 23.09 分钟内达到了 92% 的测试准确率。

图片

需要注意,如果禁用 checkpointing 并允许 PyTorch 以非确定性模式运行,本次运行最终将获得与普通 PyTorch 相同的运行时间(时间为 22.63 分而不是 23.09 分)。

自动混合精度训练

进一步,如果 GPU 支持混合精度训练,可以开启 GPU 以提高计算效率。作者使用自动混合精度训练,在 32 位和 16 位浮点之间切换而不会牺牲准确率。

图片

在这一优化下,使用 Trainer 类,即能通过一行代码实现自动混合精度训练:

图片

上述操作可以将训练时间从 23.09 分钟缩短到 8.75 分钟,这几乎快了 3 倍。测试集的准确率为 92.2%,甚至比之前的 92.0% 还略有提高。

图片

使用 Torch.Compile 静态图

最近 PyTorch 2.0 公告显示,PyTorch 团队引入了新的 toch.compile 函数。该函数可以通过生成优化的静态图来加速 PyTorch 代码执行,而不是使用动态图运行 PyTorch 代码。

图片

由于 PyTorch 2.0 尚未正式发布,因而必须先要安装 torchtriton,并更新到 PyTorch 最新版本才能使用此功能。

图片


然后通过添加这一行对代码进行修改:

图片

在 4 块 GPU 上进行分布式数据并行

上文介绍了在单 GPU 上加速代码的混合精度训练,接下来介绍多 GPU 训练策略。下图总结了几种不同的多 GPU 训练技术。

图片

想要实现分布式数据并行,可以通过 DistributedDataParallel 来实现,只需修改一行代码就能使用 Trainer。

图片

经过这一步优化,在 4 个 A100 GPU 上,这段代码运行了 3.52 分钟就达到了 93.1% 的测试准确率。

图片

图片

DeepSpeed

最后,作者探索了在 Trainer 中使用深度学习优化库 DeepSpeed 以及多 GPU 策略的结果。首先必须安装 DeepSpeed 库:

图片

接着只需更改一行代码即可启用该库:

图片

这一波下来,用时 3.15 分钟就达到了 92.6% 的测试准确率。不过 PyTorch 也有 DeepSpeed 的替代方案:fully-sharded DataParallel,通过 strategy="fsdp" 调用,最后花费 3.62 分钟完成。

图片

以上就是作者提高 PyTorch 模型训练速度的方法,感兴趣的小伙伴可以跟着原博客尝试一下,相信你会得到想要的结果。

以上是改变几行代码,PyTorch炼丹速度狂飙、模型优化时间大减的详细内容。更多信息请关注PHP中文网其他相关文章!

声明
本文转载于:51CTO.COM。如有侵权,请联系admin@php.cn删除
MISTRAL大2:足够强大,可以挑战Llama 3.1 405b?MISTRAL大2:足够强大,可以挑战Llama 3.1 405b?Apr 18, 2025 am 10:16 AM

MISTRAL大2:深入了解Mistral AI强大的开源LLM Meta AI最近发布的Llama 3.1模型系列很快被Mistral AI揭幕了其迄今为止最大的模型:Mistral flow 2。这个1230亿参数

稳定扩散中的噪声时间表是什么? - 分析Vidhya稳定扩散中的噪声时间表是什么? - 分析VidhyaApr 18, 2025 am 10:15 AM

了解扩散模型中的噪声时间表:综合指南 您是否曾经被AI产生的令人惊叹的数字艺术视觉效果所吸引,并想知道基础机制? 关键要素是“噪声时间表,&quo

如何使用GPT-4O构建对话聊天机器人? - 分析Vidhya如何使用GPT-4O构建对话聊天机器人? - 分析VidhyaApr 18, 2025 am 10:06 AM

使用GPT-4O构建上下文聊天机器人:综合指南 在AI和NLP迅速发展的景观中,聊天机器人已成为开发人员和组织必不可少的工具。 创建真正引人入胜且聪明的聊天的关键方面

2025年建造AI代理的前7个框架2025年建造AI代理的前7个框架Apr 18, 2025 am 10:00 AM

本文探讨了建立AI代理的七个领先框架 - 自主软件实体,这些软件实体可以感知,决定和采取行动实现目标。 这些代理人超越了传统的强化学习,利用高级计划和推理

I型和II型错误有什么区别? - 分析VidhyaI型和II型错误有什么区别? - 分析VidhyaApr 18, 2025 am 09:48 AM

了解统计假设检验中的I型和II型错误 想象一下一项临床试验测试一种新的血压药物。 该试验的结论大大降低了血压,但实际上并非如此。这是一种类型

使用Sumy库的自动文本摘要使用Sumy库的自动文本摘要Apr 18, 2025 am 09:37 AM

Sumy:您的AI驱动摘要助理 厌倦了筛选无尽的文件? 强大的Python库Sumy提供了一种简化的解决方案,用于自动文本摘要。 本文探讨了Sumy的功能,指导您通过

SQL案例语句:从基础到高级技术SQL案例语句:从基础到高级技术Apr 18, 2025 am 09:31 AM

数据挑战:掌握SQL的案例声明以进行准确的见解 当您拥有数据爱好者时,谁需要律师? 数据分析师,科学家和广阔数据世界中的每个人都面临着自己的复杂挑战,确保系统功能FLA

及时工程中知识链的力量是什么?及时工程中知识链的力量是什么?Apr 18, 2025 am 09:30 AM

利用AI中的知识链的力量:深入研究迅速工程 您是否知道人工智能(AI)不仅可以理解您的问题,而且还可以编织大量知识来提供有见地的答案?

See all articles

热AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Clothoff.io

Clothoff.io

AI脱衣机

AI Hentai Generator

AI Hentai Generator

免费生成ai无尽的。

热门文章

R.E.P.O.能量晶体解释及其做什么(黄色晶体)
1 个月前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳图形设置
1 个月前By尊渡假赌尊渡假赌尊渡假赌
威尔R.E.P.O.有交叉游戏吗?
1 个月前By尊渡假赌尊渡假赌尊渡假赌

热工具

MinGW - 适用于 Windows 的极简 GNU

MinGW - 适用于 Windows 的极简 GNU

这个项目正在迁移到osdn.net/projects/mingw的过程中,你可以继续在那里关注我们。MinGW:GNU编译器集合(GCC)的本地Windows移植版本,可自由分发的导入库和用于构建本地Windows应用程序的头文件;包括对MSVC运行时的扩展,以支持C99功能。MinGW的所有软件都可以在64位Windows平台上运行。

SublimeText3 英文版

SublimeText3 英文版

推荐:为Win版本,支持代码提示!

SublimeText3汉化版

SublimeText3汉化版

中文版,非常好用

适用于 Eclipse 的 SAP NetWeaver 服务器适配器

适用于 Eclipse 的 SAP NetWeaver 服务器适配器

将Eclipse与SAP NetWeaver应用服务器集成。

PhpStorm Mac 版本

PhpStorm Mac 版本

最新(2018.2.1 )专业的PHP集成开发工具