首页 >web前端 >js教程 >使用 C FFI 从 Dart 运行 JAX 程序

使用 C FFI 从 Dart 运行 JAX 程序

Barbara Streisand
Barbara Streisand原创
2024-11-23 13:50:101051浏览

?为什么将 Dart 和 JAX 结合起来进行机器学习?

构建应用程序时,选择正确的工具至关重要。您需要高性能、轻松开发和无缝跨平台部署。流行的框架提供了权衡:

  • C 提供速度,但会减慢开发速度。
  • Dart(与 Flutter 一起)速度较慢,但​​简化了内存管理和跨平台开发。

但问题是:大多数框架缺乏强大的原生机器学习 (ML) 支持。这种差距的存在是因为这些框架早于人工智能繁荣。问题是:

我们如何有效地将机器学习集成到应用程序中? ​​

诸如ONNX Runtime之类的常见解决方案允许导出ML模型以进行应用程序集成,但它们没有针对CPU进行优化,也没有足够灵活以适应通用算法。

输入JAX,一个Python库,它:

  • 允许编写优化的机器学习和通用算法。
  • 在 CPU、GPU 和 TPU 上提供与平台无关的执行。
  • 支持诸如autogradJIT编译等尖端功能。

在本文中,我们将向您展示如何:

  1. 用 Python 编写 JAX 程序。
  2. 生成 XLA 规格。
  3. 使用 C FFI 在 Dart 中部署优化的 JAX 代码。

?什么是JAX?

JAX 就像类固醇上的 NumPy。它由 Google 开发,是一个低级高性能库,使 ML 易于访问且功能强大。

  • 与平台无关:代码无需修改即可在 CPU、GPU 和 TPU 上运行。
  • 速度:JAX 由 XLA 编译器提供支持,优化并加速执行。
  • 灵活性:非常适合机器学习模型和通用算法。

这是一个比较 NumPy 和 JAX 的示例:

# NumPy version
import numpy as np  
def assign_numpy():  
  a = np.empty(1000000)  
  a[:] = 1  
  return a  

# JAX version
import jax.numpy as jnp  
import jax  

@jax.jit  
def assign_jax():  
  a = jnp.empty(1000000)  
  return a.at[:].set(1)  

Google Colab 中的基准测试揭示了 JAX 的性能优势:

  • CPU 和 GPU:JAX 比 NumPy 更快。
  • TPU:由于数据传输成本,大型模型的加速变得明显。

这种灵活性和速度使 JAX 成为性能至关重要的生产环境的理想选择。


Running a JAX Program from Dart Using C   FFI


?️ 将 JAX 投入生产

云微服务与本地部署

  • :容器化 Python 微服务非常适合基于云的计算。
  • 本地:发布 Python 解释器对于本地应用程序来说并不理想。

解决方案:利用 JAX 的 XLA 编译

JAX 将 Python 代码翻译成 HLO(高级优化器)规范,可以使用 C XLA 库编译和执行。这使得:

  1. 用 Python 编写算法。
  2. 通过 C 库本地运行它们。
  3. 通过FFI(外部函数接口) 与 Dart 集成。

✍️ 逐步集成

1. 生成 HLO 原型

编写您的 JAX 函数并导出其 HLO 表示。例如:

# NumPy version
import numpy as np  
def assign_numpy():  
  a = np.empty(1000000)  
  a[:] = 1  
  return a  

# JAX version
import jax.numpy as jnp  
import jax  

@jax.jit  
def assign_jax():  
  a = jnp.empty(1000000)  
  return a.at[:].set(1)  

要生成 HLO,请使用 JAX 存储库中的 jax_to_ir.py 脚本:

import jax.numpy as jnp  

def fn(x, y, z):  
  return jnp.dot(x, y) / z  

将生成的文件(fn_hlo.txt 和 fn_hlo.pb)放入应用程序的资产目录中。


2. 构建C动态库

修改JAX的C示例代码

克隆 JAX 存储库并导航到 jax/examples/jax_cpp

  • 添加main.h头文件:
python jax_to_ir.py \
  --fn jax_example.prog.fn \
  --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2")]' \
  --constants '{"z": 2.0}' \
  --ir_format HLO \
  --ir_human_dest /tmp/fn_hlo.txt \
  --ir_dest /tmp/fn_hlo.pb
  • 更新BUILD文件以创建共享库:
#ifndef MAIN_H  
#define MAIN_H  

extern "C" {  
  int bar(int foo);  
}  

#endif  

使用 Bazel 编译:

cc_shared_library(  
   name = "jax",  
   deps = [":main"],  
   visibility = ["//visibility:public"],  
)  

您将在输出目录中找到编译后的 libjax.dylib。


3. 使用 FFI 将 Dart 与 C 连接

使用 Dart 的 FFI 包 与 C 库通信。创建一个 jax.dart 文件:

bazel build examples/jax_cpp:jax  

将动态库包含在您的项目目录中。测试一下:

import 'dart:ffi';  
import 'package:dynamic_library/dynamic_library.dart';  

typedef FooCFunc = Int32 Function(Int32 bar);  
typedef FooDartFunc = int Function(int bar);  

class JAX {  
  late final DynamicLibrary dylib;  

  JAX() {  
    dylib = loadDynamicLibrary(libraryName: 'jax');  
  }  

  Function get _bar => dylib.lookupFunction<FooCFunc, FooDartFunc>('bar');  

  int bar(int foo) {  
    return _bar(foo);  
  }  
}  

您将在控制台中看到 C 库的输出。


?下一步

通过此设置,您可以:

  • 使用 JAX 和 XLA 优化 ML 模型。
  • 在本地运行强大的算法。

潜在用例包括:

  • 搜索算法(例如,A*)。
  • 组合优化(例如调度)。
  • 图像处理(例如边缘检测)。

JAX 弥合了基于 Python 的开发和生产级性能之间的差距,让 ML 工程师能够专注于算法,而无需担心低级 C 代码。


我们正在构建一个尖端的人工智能平台,具有无限的聊天令牌和长期记忆,确保随着时间的推移而不断发展的无缝、上下文感知交互。

它是完全免费的,您也可以在当前的 IDE 中尝试它。


Running a JAX Program from Dart Using C   FFI

以上是使用 C FFI 从 Dart 运行 JAX 程序的详细内容。更多信息请关注PHP中文网其他相关文章!

声明:
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn