Running a JAX Program from Dart Using C FFI
When building applications, selecting the right tools is crucial. You want high performance, easy development, and seamless cross-platform deployment. Popular frameworks each strike a different trade-off among these goals.
But here's the catch: most frameworks lack robust native machine learning (ML) support. This gap exists largely because many of these frameworks predate the AI boom. The question is:
How can we efficiently integrate ML into applications?
Common solutions such as ONNX Runtime let you export ML models and embed them in applications, but they are not always well optimized for CPU workloads, nor flexible enough for general-purpose algorithms that go beyond standard model inference.
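Enter JAX, a Python library for high-performance numerical computing and ML whose programs can be compiled and run outside Python.
In this article, we'll show you how to write a JAX function and export it to HLO, build a C/C++ shared library around it with Bazel, and call that library from Dart through FFI.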
JAX is like NumPy on steroids. Developed by Google, it’s a low-level, high-performance library that makes ML accessible yet powerful.
Here’s an example comparing NumPy and JAX:
```python
# NumPy version
import numpy as np


def assign_numpy():
    a = np.empty(1000000)
    a[:] = 1  # in-place assignment
    return a


# JAX version
import jax
import jax.numpy as jnp


@jax.jit
def assign_jax():
    a = jnp.empty(1000000)
    return a.at[:].set(1)  # functional update: returns a new array
```
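The key difference between the two versions is that JAX arrays are immutable: NumPy fills `a` in place, whereas `a.at[:].set(1)` returns a new array. The short sketch below (variable names are illustrative, not from the article) shows both sides of that rule.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.zeros(5)

# JAX arrays are immutable: in-place item assignment raises a TypeError.
try:
    a[0] = 1.0
    in_place_worked = True
except TypeError:
    in_place_worked = False

# The functional update returns a new array; `a` itself is unchanged.
b = a.at[1:3].set(7.0)

print(in_place_worked)         # False
print(np.asarray(a).tolist())  # [0.0, 0.0, 0.0, 0.0, 0.0]
print(np.asarray(b).tolist())  # [0.0, 7.0, 7.0, 0.0, 0.0]
```

Because updates never mutate their inputs, JAX can trace and compile functions like `assign_jax` safely.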
Benchmarking both versions in Google Colab reveals JAX's performance edge.
This flexibility and speed make JAX ideal for production environments where performance is key.
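The article's Colab numbers aren't reproduced here, but a minimal timing harness might look like the following. It assumes only the `jax` and `numpy` packages. The warm-up call keeps JIT compilation out of the measurement, and `block_until_ready()` is needed because JAX dispatches work asynchronously. Results will vary with hardware and backend.

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp


def assign_numpy():
    a = np.empty(1_000_000)
    a[:] = 1
    return a


@jax.jit
def assign_jax():
    a = jnp.empty(1_000_000)
    return a.at[:].set(1)


def bench(fn, number=100):
    """Average wall-clock seconds per call of fn over `number` runs."""
    return timeit.timeit(fn, number=number) / number


# Warm-up: the first call to a jitted function traces and compiles it.
assign_jax().block_until_ready()

numpy_time = bench(assign_numpy)
# Without block_until_ready() we would only time the asynchronous dispatch.
jax_time = bench(lambda: assign_jax().block_until_ready())

print(f"NumPy: {numpy_time * 1e6:.1f} us/call")
print(f"JAX:   {jax_time * 1e6:.1f} us/call")
```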
JAX translates Python code into HLO (High-Level Optimizer) programs, which XLA's C++ libraries can compile and execute. This makes it possible to run computations written in JAX from native code, without a Python interpreter at runtime.
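The pipeline described above maps onto JAX's ahead-of-time API: `jax.jit(f).lower(...)` traces and lowers a function, `.compile()` hands the result to XLA, and the compiled object can then be called. The sketch uses a made-up function, `scale_and_sum`, purely for illustration. Note that in recent JAX releases `as_text()` emits StableHLO (MLIR) text rather than the classic HLO text format that jax_to_ir.py writes.

```python
import jax
import jax.numpy as jnp


def scale_and_sum(x):
    # Hypothetical example function, not from the article.
    return jnp.sum(x * 2.0)


x = jnp.arange(4.0)  # [0., 1., 2., 3.]

# 1. Trace and lower to XLA's input IR for these argument shapes/dtypes.
lowered = jax.jit(scale_and_sum).lower(x)
ir_text = lowered.as_text()
print(ir_text)

# 2. Compile the lowered module with XLA for the current backend.
compiled = lowered.compile()

# 3. Execute the compiled executable.
result = compiled(x)
print(float(result))  # 12.0
```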
Clone the JAX repository and navigate to jax/examples/jax_cpp.
Write your JAX function and export its HLO representation. For example:

```python
import jax.numpy as jnp


def fn(x, y, z):
    return jnp.dot(x, y) / z
```

To generate the HLO, use the jax_to_ir.py script from the JAX repository:

```bash
python jax_to_ir.py \
  --fn jax_example.prog.fn \
  --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]' \
  --constants '{"z": 2.0}' \
  --ir_format HLO \
  --ir_human_dest /tmp/fn_hlo.txt \
  --ir_dest /tmp/fn_hlo.pb
```

Place the resulting files (fn_hlo.txt and fn_hlo.pb) in your app's assets directory.
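As a cross-check on the export, the same shapes and constant passed to jax_to_ir.py can be expressed directly in Python. This is a sketch, not a replacement for the script: `jax.ShapeDtypeStruct` stands in for `--input_shapes`, closing over `z` stands in for `--constants`, and the names `Z`, `spec`, `fn_xy`, and the sample inputs are illustrative values chosen here so the expected output is easy to verify by hand.

```python
import numpy as np
import jax
import jax.numpy as jnp


def fn(x, y, z):
    return jnp.dot(x, y) / z


Z = 2.0  # illustrative: mirrors --constants '{"z": 2.0}'

# Illustrative: mirrors --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]'
spec = jax.ShapeDtypeStruct((2, 2), jnp.float32)

# Close over the constant so only x and y remain as inputs, as in the export.
fn_xy = jax.jit(lambda x, y: fn(x, y, Z))
ir_text = fn_xy.lower(spec, spec).as_text()

# Sample inputs: y is the identity, so the result is simply x / 2.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32)
y = jnp.eye(2, dtype=jnp.float32)
expected = np.asarray(fn_xy(x, y))
print(expected.tolist())  # [[0.5, 1.0], [1.5, 2.0]]
```

If you later execute fn_hlo.pb through XLA with the same x and y, it should produce these values.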
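In jax/examples/jax_cpp, declare the function you want to call from Dart in a main.h header with C linkage, so the C++ compiler does not mangle its name:

```cpp
#ifndef MAIN_H
#define MAIN_H

// C linkage exports the symbol as plain "bar" for Dart's FFI lookup.
extern "C" {
int bar(int foo);
}

#endif  // MAIN_H
```

Then expose it as a shared library in the BUILD file:

```python
cc_shared_library(
    name = "jax",
    deps = [":main"],  # must be a cc_library containing the implementation
    visibility = ["//visibility:public"],
)
```

Compile with Bazel from the root of the JAX repository:

```bash
bazel build examples/jax_cpp:jax
```

You'll find the compiled libjax.dylib in the output directory (bazel-bin/examples/jax_cpp).
Use Dart's FFI package to communicate with the C library. Create a jax.dart file:

```dart
import 'dart:ffi';

import 'package:dynamic_library/dynamic_library.dart';

// Native signature of `int bar(int foo)` from main.h.
typedef FooCFunc = Int32 Function(Int32 bar);
// Matching Dart-side signature.
typedef FooDartFunc = int Function(int bar);

class JAX {
  JAX() : dylib = loadDynamicLibrary(libraryName: 'jax');

  final DynamicLibrary dylib;

  // Resolve the `bar` symbol once and keep the typed function.
  late final FooDartFunc _bar =
      dylib.lookupFunction<FooCFunc, FooDartFunc>('bar');

  int bar(int foo) => _bar(foo);
}
```

Include the dynamic library in your project directory, then test it by creating a JAX instance and calling bar from your program's main function.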
You’ll see the output from the C library in your console.
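If the Dart side fails to load the library or find `bar`, it can help to check the exported symbol independently of Dart. One option, not part of the original workflow, is Python's `ctypes`, declaring the same int32-to-int32 signature as the Dart typedefs. The helper name `bind_int32_fn` and the library path in the comment are hypothetical; the self-check binds libc's `abs`, which has the same shape, so the sketch runs without libjax (on Linux or macOS).

```python
import ctypes


def bind_int32_fn(library_path, symbol):
    """Hypothetical helper: bind an exported C function int32_t f(int32_t)."""
    # library_path=None loads symbols already present in this process (POSIX).
    lib = ctypes.CDLL(library_path)
    func = getattr(lib, symbol)  # raises AttributeError if the symbol is missing
    func.argtypes = [ctypes.c_int32]
    func.restype = ctypes.c_int32
    return func


# With the article's library (hypothetical path):
# bar = bind_int32_fn("bazel-bin/examples/jax_cpp/libjax.dylib", "bar")
# print(bar(1))

# Self-check against libc's abs(), which also takes and returns an int.
libc_abs = bind_int32_fn(None, "abs")
print(libc_abs(-5))  # 5
```

An `AttributeError` here for `bar` usually means the symbol was mangled or not exported, which points back at the `extern "C"` declaration.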
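With this setup, you can write algorithms in JAX, export them to HLO, and call the compiled native library from Dart, which opens the door to ML features inside Dart applications.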
JAX bridges the gap between Python-based development and production-level performance, letting ML engineers focus on algorithms without worrying about low-level C code.