首頁  >  文章  >  web前端  >  使用 C FFI 從 Dart 執行 JAX 程序

使用 C FFI 從 Dart 執行 JAX 程序

Barbara Streisand
Barbara Streisand原創
2024-11-23 13:50:10982瀏覽

?為什麼要將 Dart 和 JAX 結合起來進行機器學習?

建立應用程式時,選擇正確的工具至關重要。您需要高效能、輕鬆開發和無縫跨平台部署。流行的框架提供了權衡:

  • C 提供速度,但會減慢開發速度。
  • Dart(與 Flutter 一起)速度較慢,但簡化了記憶體管理和跨平台開發。

但問題是:大多數框架缺乏強大的原生機器學習 (ML) 支援。這種差距的存在是因為這些框架早於人工智慧繁榮。問題是:

我們如何有效地將機器學習整合到應用程式中?

諸如ONNX Runtime之類的常見解決方案允許導出ML模型以進行應用程式集成,但它們沒有針對CPU進行最佳化,也沒有足夠靈活以適應通用演算法。

輸入JAX,一個Python庫,它:

  • 允許編寫最佳化的機器學習和通用演算法。
  • 在 CPU、GPU 和 TPU 上提供與平台無關的執行。
  • 支援autogradJIT編譯等尖端功能。

在本文中,我們將向您展示如何:

  1. 用 Python 寫 JAX 程式。
  2. 產生 XLA 規格。
  3. 使用 C FFI 在 Dart 中部署最佳化的 JAX 程式碼。

?什麼是JAX?

JAX 就像是類固醇上的 NumPy。它由 Google 開發,是一個低階高效能函式庫,使 ML 易於存取且功能強大。

  • 與平台無關:程式碼無需修改即可在 CPU、GPU 和 TPU 上運作。
  • 速度:JAX 由 XLA 編譯器提供支持,最佳化並加速執行。
  • 靈活性:非常適合機器學習模型和通用演算法。

這是一個比較 NumPy 和 JAX 的例子:

# NumPy version
import numpy as np  
def assign_numpy():  
  a = np.empty(1000000)  
  a[:] = 1  
  return a  

# JAX version
import jax.numpy as jnp  
import jax  

@jax.jit  
def assign_jax():  
  a = jnp.empty(1000000)  
  return a.at[:].set(1)  

Google Colab 中的基準測試揭示了 JAX 的效能優勢:

  • CPU 和 GPU:JAX 比 NumPy 更快。
  • TPU:由於資料傳輸成本,大型模型的加速變得明顯。

這種靈活性和速度使 JAX 成為性能至關重要的生產環境的理想選擇。


Running a JAX Program from Dart Using C   FFI


?️ 將 JAX 投入生產

雲端微服務與本地部署

  • 雲端:容器化 Python 微服務非常適合基於雲端的運算。
  • 本地:發布 Python 解釋器對於本機應用程式來說並不理想。

解:利用 JAX 的 XLA 編譯

JAX 將 Python 程式碼翻譯成 HLO(高階最佳化器)規格,可以使用 C XLA 函式庫編譯和執行。這使得:

  1. 用 Python 寫演算法。
  2. 透過 C 庫本地運行它們。
  3. 透過FFI(外部函數介面) 與 Dart 整合。

✍️ 逐步集成

1. 產生 HLO 原型

編寫您的 JAX 函數並匯出其 HLO 表示。例如:

# NumPy version
import numpy as np  
def assign_numpy():  
  a = np.empty(1000000)  
  a[:] = 1  
  return a  

# JAX version
import jax.numpy as jnp  
import jax  

@jax.jit  
def assign_jax():  
  a = jnp.empty(1000000)  
  return a.at[:].set(1)  

要產生 HLO,請使用 JAX 儲存庫中的 jax_to_ir.py 腳本:

import jax.numpy as jnp  

def fn(x, y, z):  
  return jnp.dot(x, y) / z  

將產生的檔案(fn_hlo.txt 和 fn_hlo.pb)放入應用程式的資產目錄中。


2. 建置C動態庫

修改JAX的C範例程式碼

複製 JAX 儲存庫並導航至 jax/examples/jax_cpp

  • 新增main.h頭檔:
python jax_to_ir.py \
  --fn jax_example.prog.fn \
  --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2")]' \
  --constants '{"z": 2.0}' \
  --ir_format HLO \
  --ir_human_dest /tmp/fn_hlo.txt \
  --ir_dest /tmp/fn_hlo.pb
  • 更新BUILD檔案以建立共享庫:
#ifndef MAIN_H  
#define MAIN_H  

extern "C" {  
  int bar(int foo);  
}  

#endif  

使用 Bazel 編譯:

cc_shared_library(  
   name = "jax",  
   deps = [":main"],  
   visibility = ["//visibility:public"],  
)  

您將在輸出目錄中找到編譯後的 libjax.dylib。


3. 使用 FFI 將 Dart 與 C 連接

使用 Dart 的 FFI 套件 與 C 函式庫通訊。建立一個 jax.dart 檔案:

bazel build examples/jax_cpp:jax  

將動態程式庫包含在您的專案目錄中。測試一下:

import 'dart:ffi';  
import 'package:dynamic_library/dynamic_library.dart';  

typedef FooCFunc = Int32 Function(Int32 bar);  
typedef FooDartFunc = int Function(int bar);  

class JAX {  
  late final DynamicLibrary dylib;  

  JAX() {  
    dylib = loadDynamicLibrary(libraryName: 'jax');  
  }  

  Function get _bar => dylib.lookupFunction<FooCFunc, FooDartFunc>('bar');  

  int bar(int foo) {  
    return _bar(foo);  
  }  
}  

您將在控制台中看到 C 庫的輸出。


?下一步

透過此設置,您可以:

  • 使用 JAX 和 XLA 最佳化 ML 模型。
  • 在本地運行強大的演算法。

潛在用例包括:

  • 搜尋演算法(例如,A*)。
  • 組合最佳化(例如調度)。
  • 影像處理(例如邊緣偵測)。

JAX 彌合了基於 Python 的開發和生產級效能之間的差距,讓 ML 工程師能夠專注於演算法,而無需擔心低階 C 程式碼。


我們正在建立一個尖端的人工智慧平台,具有無限的聊天令牌和長期記憶,確保隨著時間的推移而不斷發展的無縫、上下文感知互動。

它是完全免費的,您也可以在當前的 IDE 中嘗試它。


Running a JAX Program from Dart Using C   FFI

以上是使用 C FFI 從 Dart 執行 JAX 程序的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn