2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美

机器之心报道
编辑:杜伟、陈萍

近年来 , 谷歌于 2018 年推出的 JAX 迎来了迅猛发展 , 很多研究者对其寄予厚望 , 希望它可以取代 TensorFlow 等众多深度学习框架 。 但 JAX 是否真的适合所有人使用呢?这篇文章对 JAX 的方方面面展开了深入探讨 , 希望可以给研究者选择深度学习框架时提供有益的参考 。
自 2018 年底推出以来 , JAX 的受欢迎程度一直在稳步提升 。 2020 年 , DeepMind 宣布使用 JAX 来加速其研究 。 越来越多来自谷歌大脑(Google Brain)和其他机构的项目也都在使用 JAX 。
目前 , 在 JAX 的 GitHub 项目主页 , Star 量已经达到了 16.3k 。
2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

项目地址:https://github.com/google/jax
JAX 是一个非常有前途的项目 , 并且用户一直在稳步增长 。 JAX 已经在深度学习、机器人 / 控制系统、贝叶斯方法和科学模拟等诸多领域得到了广泛应用 。
2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

如此 , 是否意味着 JAX 也将成为下一个大型深度学习框架?近日 , 发表在 AssemblyAI 博客上的文章《Why You Should (or Shouldn't) Be Using JAX in 2022》中 , 作者 Ryan O'Connor 为我们深入解读了 JAX 的概念、使用 JAX 的理由以及是否应该使用 JAX 等 。
JAX 简介
JAX 不是一个深度学习框架或库 , 其设计初衷也不是成为一个深度学习框架或库 。 简而言之 , JAX 是一个包含可组合函数转换的数值计算库 。 正如我们所看到的 , 深度学习只是 JAX 功能的一小部分:
2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

JAX 的定位科学计算(Scientific Computing)和函数转换(Function Transformations)的交叉融合 , 具有除训练深度学习模型以外的一系列能力 , 包括如下:
  • 即时编译(Just-in-Time Compilation)
  • 自动并行化(Automatic Parallelization)
  • 自动向量化(Automatic Vectorization)
  • 自动微分(Automatic Differentiation)
使用 JAX 的原因有哪些?
简而言之 , 是速度 。 这是 JAX 与任何用例相关的一种通用能力 。 让我们使用 NumPy 和 JAX 对矩阵的前三个幂求和(按元素) 。
首先是 NumPy 实现 。 我们发现 , 该计算大约需要 851 毫秒 。
2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

然后使用 JAX 实现该计算:
2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

JAX 仅在 5.54 毫秒内执行完成该计算 , 速度是 NumPy 的 150 倍以上 。
2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

JAX 的速度比 NumPy 快了 N 个数量级 。 需要注意 , JAX 使用的是 TPU , NumPy 使用了 CPU , 以此强调 JAX 的速度上限远高于 NumPy 。
作者列出了以下六条可能想要使用 JAX 的理由:
  • NumPy 加速器 。 NumPy 是使用 Python 进行科学计算的基础包之一 , 但它仅与 CPU 兼容 。 JAX 提供了 NumPy 的实现(具有几乎相同的 API) , 可以非常轻松地在 GPU 和 TPU 上运行 。 对于许多用户而言 , 仅此一项功能就足以证明使用 JAX 的合理性;
  • XLA 。 XLA(Accelerated Linear Algebra)是专为线性代数设计的全程序优化编译器 。 JAX 建立在 XLA 之上 , 显著提高了计算速度上限;

    特别声明:本站内容均来自网友提供或互联网,仅供参考,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。