2022 年了 , 我该用 JAX 吗?
很遗憾 , 这个问题的答案还是「视情况而定」 。 是否迁移到 JAX 取决于你的情况和目标 。 为具体分析是否应该(或不应该)在 2022 年使用 JAX , 这里将建议汇总到下面的流程图中 , 并针对不同的兴趣领域提供不同的图表 。
科学计算
文章图片
如果你对 JAX 在通用计算感兴趣 , 首先要问的问题就是——是否只尝试在加速器上运行 NumPy?如果答案是肯定的 , 那么你显然应该开始迁移到 JAX 。
如果你不只处理数字而是参与动态计算建模 , 那么是否应该使用 JAX 将取决于具体用例 。 如果大部分工作是在 Python 中使用大量自定义代码完成的 , 那么开始学习 JAX 以增强工作流程是值得的 。
如果大部分工作不在 Python 中 , 但你想构建的是某种基于模型 / 神经网络的混合系统 , 那么使用 JAX 可能是值得的 。
如果大部分工作不使用 Python , 或者你正在使用一些专门的软件进行研究(热力学、半导体等) , 那么 JAX 可能是不合适的工具 , 除非你想从这些程序中导出数据 , 用来做自定义计算 。 如果你感兴趣的领域更接近物理 / 数学并包含计算方法(动力系统、微分几何、统计物理)并且大部分工作都在例如 Mathematica 上 , 那么坚持使用目前的工具才是值得的 , 特别是在已有大型自定义代码库的情形下 。
深度学习
文章图片
虽然我们已经强调过 , JAX 不是专为深度学习构建的通用框架 , 但 JAX 速度很快且具有自动微分功能 , 你肯定想知道使用 JAX 进行深度学习是什么样的 。
若想在 TPU 上进行训练 , 那么你应该开始使用 JAX , 尤其是如果当前正在使用的是 PyTorch 。 虽然有 PyTorch-XLA 存在 , 但使用 JAX 进行 TPU 训练绝对是更好的体验 。 如果你正在研究的是「非标准」架构 / 建模 , 例如 SDE-Nets , 那么也绝对应该尝试一下 JAX 。 此外 , 如果你想利用高阶优化技术 , JAX 也是要尝试的东西 。
如果你不是在构建特殊的架构 , 只是在 GPU 上训练常见的架构 , 那么你现在可能应该坚持使用 PyTorch 或 TensorFlow 。 然而 , 这个建议可能会在未来一两年内快速发生变化 。 虽然 PyTorch 仍然在研究领域占据主导地位 , 但使用 JAX 的论文数量一直在稳步增长 。 随着 DeepMind 和谷歌重量级玩家不断开发用于 JAX 的高级深度学习 API , 在几年内 JAX 可能会出现爆炸性的增长率 。
这意味着你至少应该稍微熟悉一下 JAX , 如果你是研究人员的话更应如此 。
深度学习初学者
文章图片
但如果我只是个初学者呢?情况会有些不一样 。
如果你有兴趣了解深度学习并实现一些想法 , 你应该使用 JAX 或 PyTorch 。 如果你想自上而下学习深度学习 , 或有一些 Python 软件的经验 , 则应该从 PyTorch 入手 。 如果你想自下而上地学习深度学习 , 或具有数学背景 , 你可能会发现 JAX 很直观 。 在这种情况下 , 在进行任何大型项目之前 , 请确保了解如何使用 JAX 。
如果你对深度学习感兴趣 , 又想转行相关的职位 , 那么你需要使用 PyTorch 或 TensorFlow 。 尽管最好是同时熟悉两个框架 , 但你必须知道 TensorFlow 被普遍认为是「行业」框架 , 不同框架的职位发布数量证明了这一点:
特别声明:本站内容均来自网友提供或互联网,仅供参考,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
