Heim  >  Artikel  >  Web-Frontend  >  Ausführen eines JAX-Programms von Dart mit C FFI

Ausführen eines JAX-Programms von Dart mit C FFI

Barbara Streisand
Barbara StreisandOriginal
2024-11-23 13:50:10984Durchsuche

? Warum Dart und JAX für maschinelles Lernen kombinieren?

Beim Erstellen von Anwendungen ist die Auswahl der richtigen Tools von entscheidender Bedeutung. Sie wünschen sich hohe Leistung, einfache Entwicklung und nahtlose plattformübergreifende Bereitstellung. Beliebte Frameworks bieten Kompromisse:

  • C sorgt für Geschwindigkeit, kann aber die Entwicklung verlangsamen.
  • Dart (mit Flutter) ist langsamer, vereinfacht aber die Speicherverwaltung und plattformübergreifende Entwicklung.

Aber hier ist der Haken: Den meisten Frameworks mangelt es an robuster Unterstützung für natives maschinelles Lernen (ML). Diese Lücke besteht, weil diese Frameworks vor dem KI-Boom entstanden sind. Die Frage lautet:

Wie können wir ML effizient in Anwendungen integrieren?

Gängige Lösungen wie ONNX Runtime ermöglichen den Export von ML-Modellen für die Anwendungsintegration, sind jedoch nicht für CPUs optimiert oder flexibel genug für verallgemeinerte Algorithmen.

Geben Sie JAX ein, eine Python-Bibliothek, die:

  • Ermöglicht das Schreiben optimierter ML- und Allzweckalgorithmen.
  • Bietet plattformunabhängige Ausführung auf CPUs, GPUs und TPUs.
  • Unterstützt hochmoderne Funktionen wie Autograd und JIT-Kompilierung.

In diesem Artikel zeigen wir Ihnen, wie Sie:

  1. Schreiben Sie JAX-Programme in Python.
  2. XLA-Spezifikationen generieren.
  3. Stellen Sie optimierten JAX-Code in Dart mit C FFI bereit.

? Was ist JAX?

JAX ist wie NumPy auf Steroiden. Es handelt sich um eine von Google entwickelte Low-Level-Hochleistungsbibliothek, die ML zugänglich und dennoch leistungsstark macht.

  • Plattformunabhängig: Code läuft ohne Änderung auf CPUs, GPUs und TPUs.
  • Geschwindigkeit: Mithilfe des XLA-Compilers optimiert und beschleunigt JAX die Ausführung.
  • Flexibilität: Perfekt für ML-Modelle und allgemeine Algorithmen gleichermaßen.

Hier ist ein Beispiel zum Vergleich von NumPy und JAX:

# NumPy version
import numpy as np  
def assign_numpy():  
  a = np.empty(1000000)  
  a[:] = 1  
  return a  

# JAX version
import jax.numpy as jnp  
import jax  

@jax.jit  
def assign_jax():  
  a = jnp.empty(1000000)  
  return a.at[:].set(1)  

Benchmarking in Google Colab zeigt den Leistungsvorsprung von JAX:

  • CPU & GPU: JAX ist schneller als NumPy.
  • TPU: Geschwindigkeitssteigerungen machen sich bei großen Modellen aufgrund der Datenübertragungskosten bemerkbar.

Diese Flexibilität und Geschwindigkeit machen JAX ideal für Produktionsumgebungen, in denen Leistung entscheidend ist.


Running a JAX Program from Dart Using C   FFI


?️ JAX in die Produktion bringen

Cloud-Microservices vs. lokale Bereitstellung

  • Cloud: Containerisierte Python-Microservices eignen sich hervorragend für cloudbasiertes Computing.
  • Lokal: Der Versand eines Python-Interpreters ist für lokale Apps nicht ideal.

Lösung: Nutzen Sie die XLA-Kompilierung von JAX

JAX übersetzt Python-Code in HLO-Spezifikationen (High-Level Optimizer), die mit C XLA-Bibliotheken kompiliert und ausgeführt werden können. Dies ermöglicht:

  1. Algorithmen in Python schreiben.
  2. Sie werden nativ über eine C-Bibliothek ausgeführt.
  3. Integration mit Dart über FFI (Foreign Function Interface).

✍️ Schritt-für-Schritt-Integration

1. Generieren Sie ein HLO-Proto

Schreiben Sie Ihre JAX-Funktion und exportieren Sie ihre HLO-Darstellung. Zum Beispiel:

# NumPy version
import numpy as np  
def assign_numpy():  
  a = np.empty(1000000)  
  a[:] = 1  
  return a  

# JAX version
import jax.numpy as jnp  
import jax  

@jax.jit  
def assign_jax():  
  a = jnp.empty(1000000)  
  return a.at[:].set(1)  

Um das HLO zu generieren, verwenden Sie das Skript jax_to_ir.py aus dem JAX-Repository:

import jax.numpy as jnp  

def fn(x, y, z):  
  return jnp.dot(x, y) / z  

Platzieren Sie die resultierenden Dateien (fn_hlo.txt und fn_hlo.pb) im Assets-Verzeichnis Ihrer App.


2. Erstellen Sie eine dynamische C-Bibliothek

Ändern Sie den C-Beispielcode von JAX

Klonen Sie das JAX-Repository und navigieren Sie zu jax/examples/jax_cpp.

  • Fügen Sie eine main.h-Header-Datei hinzu:
python jax_to_ir.py \
  --fn jax_example.prog.fn \
  --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2")]' \
  --constants '{"z": 2.0}' \
  --ir_format HLO \
  --ir_human_dest /tmp/fn_hlo.txt \
  --ir_dest /tmp/fn_hlo.pb
  • Aktualisieren Sie die Datei BUILD, um eine gemeinsam genutzte Bibliothek zu erstellen:
#ifndef MAIN_H  
#define MAIN_H  

extern "C" {  
  int bar(int foo);  
}  

#endif  

Mit Bazel kompilieren:

cc_shared_library(  
   name = "jax",  
   deps = [":main"],  
   visibility = ["//visibility:public"],  
)  

Sie finden die kompilierte libjax.dylib im Ausgabeverzeichnis.


3. Verbinden Sie Dart über FFI mit C

Verwenden Sie das FFI-Paket von Dart, um mit der C-Bibliothek zu kommunizieren. Erstellen Sie eine jax.dart-Datei:

bazel build examples/jax_cpp:jax  

Fügen Sie die dynamische Bibliothek in Ihr Projektverzeichnis ein. Testen Sie es mit:

import 'dart:ffi';  
import 'package:dynamic_library/dynamic_library.dart';  

typedef FooCFunc = Int32 Function(Int32 bar);  
typedef FooDartFunc = int Function(int bar);  

class JAX {  
  late final DynamicLibrary dylib;  

  JAX() {  
    dylib = loadDynamicLibrary(libraryName: 'jax');  
  }  

  Function get _bar => dylib.lookupFunction<FooCFunc, FooDartFunc>('bar');  

  int bar(int foo) {  
    return _bar(foo);  
  }  
}  

Sie sehen die Ausgabe der C-Bibliothek in Ihrer Konsole.


? Nächste Schritte

Mit diesem Setup können Sie:

  • Optimieren Sie ML-Modelle mit JAX und XLA.
  • Führen Sie leistungsstarke Algorithmen lokal aus.

Mögliche Anwendungsfälle sind:

  • Suchalgorithmen (z. B. A*).
  • Kombinatorische Optimierung (z. B. Scheduling).
  • Bildverarbeitung (z. B. Kantenerkennung).

JAX schließt die Lücke zwischen Python-basierter Entwicklung und Leistung auf Produktionsebene und ermöglicht es ML-Ingenieuren, sich auf Algorithmen zu konzentrieren, ohne sich um Low-Level-C-Code kümmern zu müssen.


Wir bauen eine hochmoderne KI-Plattform mit unbegrenzten Chat-Tokens und Langzeitspeicher auf, um nahtlose, kontextbezogene Interaktionen zu gewährleisten, die sich im Laufe der Zeit weiterentwickeln.

Es ist völlig kostenlos und Sie können es auch in Ihrer aktuellen IDE ausprobieren.


Running a JAX Program from Dart Using C   FFI

Das obige ist der detaillierte Inhalt vonAusführen eines JAX-Programms von Dart mit C FFI. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn