PyTorch 2.0 发布,一行代码将训练提速 76%!( 二 )


PyTorch的开发理念自始至终都是灵活性和hackability第一 , 性能则是第二 , 致力于:
1.高性能的eagerexecution
2.不断Python化内部结构
3.分布式、自动比较、数据加载、加速器等的良好抽象
PyTorch自2017年面世以来 , 硬件加速器(如GPU)的计算速度提高了约15倍 , 内存访问速度提高了约2倍 。
为了保持高性能的eagerexecution , PyTorch内部的大部分内容不得不转移到C++中 , 这使得PyTorchhackability下降 , 也增加了开发者参与代码贡献的门槛 。
从第一天起 , PyTorch官方就意识到了eagerexecution的性能局限 。 2017年7月 , 官方开始致力于为PyTorch开发一个编译器 。 该编译器需要在不牺牲PyTorch体验的前提下 , 加速PyTorch程序的运行 , 其关键标准是保持某种程度上的灵活性(flexibility):支持开发者广泛使用的dynamicshapes以及dynamicprograms 。
PyTorch 2.0 发布,一行代码将训练提速 76%!
文章图片
开发者SylvainGugger表示:“只需添加一行代码 , PyTorch2.0就能在训练Transformers模型时实现1.5倍到2.0倍的速度提升 。 这是自混合精度训练问世以来最令人兴奋的事情!”
技术概述
多年来 , 研究者们在PyTorch中建立过好几个编译器项目 , 这些编译器可以分为3类:
图结构的获取图结构的降低图结构的编译其中 , 在构建PyTorch编译器时 , 图结构的获取是更难的挑战 。
过去5年中 , 官方尝试了torch.jit.trace、TorchScript、FXtracing以及LazyTensors , 但它们有些够灵活但不够快 , 有些够快但不灵活 , 有些既不快也不灵活 , 有些用户体验不好 。
虽然TorchScript很有前途 , 但它需要大量修改代码和依赖 , 可行性并不高 。
PyTorch 2.0 发布,一行代码将训练提速 76%!
文章图片
PyTorch编译流程示意图TorchDynamo:可靠快速地获取图结构
TorchDynamo使用了PEP-0523中引入的CPython功能 , 称为框架评估API(FrameEvaluationAPI) 。 为此 , 官方采取了一种数据驱动的方法来验证其在GraphCapture上的有效性 , 使用7000多个用PyTorch编写的Github项目作为验证集 。
结果显示 , TorchDynamo在99%的时间里都能正确、安全地获取图结构 , 而且开销可以忽略不计 , 因为它无需对原始代码做任何修改 。
TorchInductor:用define-by-runIR进行更迅速的codegen
对于PyTorch2.0的新编译器后端 , 团队从用户如何编写高性能的自定义内核中得到了灵感:越来越多地使用Triton语言 。
此外 , 对于PyTorch2.0全新的编译器后端 , 官方还希望能够使用与PyTorcheager类似的抽象 , 并且具有足够的通用性能支持PyTorch中广泛的功能 。
TorchInductor使用Pythonicdefine-by-runlooplevelIR , 自动将PyTorch模型映射到GPU上生成的Triton代码以及CPU上的C++/OpenMP 。
TorchInductor的corelooplevelIR只包含大约50个算子 , 而且是用Python实现的 , 这使得它具有很强的hackability和扩展性 。
AOTAutograd:对于ahead-of-timegraph , 重用Autograd
PyTorch2.0要想加速训练 , 不仅要捕获用户级代码 , 而且要捕获反向传播算法(backpropagation) 。
AOTAutograd利用PyTorchtorch_dispatch扩展机制来追踪Autograd引擎 , 使开发者得以提前捕获反向传播(backwardspas) , 从而使开发者得以使用TorchInductor加速前向和后向通道 。
PrimTorch:稳定的Primitiveoperator
PyTorch有1200多个运算符 , 如果考虑到每个运算符的各种重载 , 数量高达2000+ 。
PyTorch 2.0 发布,一行代码将训练提速 76%!
文章图片
2000+PyTorch算子的分类概况
因此 , 编写后端或交叉功能(cross-cuttingfeature)成为一项耗费精力的工作 。 PrimTorch致力于定义更小更稳定的运算符集 。 PyTorch程序可以持续降级到这些运算符集 。