Integrazione del plug-in PJRT

Contesto

PJRT è l'API uniforme del dispositivo che vogliamo aggiungere all'ecosistema ML. La visione a lungo termine è che: (1) i framework (JAX, TF, ecc.) chiameranno PJRT, che ha implementazioni specifiche del dispositivo opache ai framework; (2) ciascun dispositivo si concentra sull’implementazione delle API PJRT e può essere opaco rispetto ai framework.

Questo documento si concentra sulle raccomandazioni su come eseguire l'integrazione con PJRT e su come testare l'integrazione di PJRT con JAX.

Come eseguire l'integrazione con PJRT

Passaggio 1: implementa l'interfaccia dell'API PJRT C

Opzione A: puoi implementare direttamente l'API PJRT C.

Opzione B: se sei in grado di creare sulla base del codice C++ nel repository xla (tramite forking o bazel), puoi anche implementare l'API PJRT C++ e utilizzare il wrapper C→C++:

  1. Implementare un client PJRT C++ che eredita dal client PJRT di base (e dalle relative classi PJRT). Ecco alcuni esempi di client PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Implementa alcuni metodi dell'API C che non fanno parte del client PJRT C++:
    • PJRT_Client_Create. Di seguito è riportato un esempio di pseudocodice (supponendo che GetPluginPjRtClient restituisca un client PJRT C++ implementato sopra):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Nota: PJRT_Client_Create può utilizzare opzioni trasferite dal framework. Qui è riportato un esempio di come un client GPU utilizza questa funzionalità.

Con il wrapper, non è necessario implementare le API C rimanenti.

Passaggio 2: implementa GetPjRtApi

Devi implementare un metodo GetPjRtApi che restituisca un puntatore a funzione contenente PJRT_Api* alle implementazioni dell'API C PJRT. Di seguito è riportato un esempio in cui si presuppone l'implementazione tramite wrapper (simile a pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Passaggio 3: testa le implementazioni dell'API C

Puoi chiamare RegisterPjRtCApiTestFactory per eseguire un piccolo insieme di test sui comportamenti di base dell'API PJRT C.

Come utilizzare un plug-in PJRT di JAX

Passaggio 1: configura JAX

Puoi usare JAX di notte

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

oppure crea JAX dal codice sorgente.

Per ora, devi associare la versione jaxlib alla versione dell'API PJRT C. Generalmente è sufficiente utilizzare una versione notturna jaxlib dello stesso giorno del commit TF per cui stai creando il plug-in, ad esempio:

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Puoi anche creare un file jaxlib dal codice sorgente esattamente al commit XLA per cui stai creando (instructions).

Inizieremo a supportare la compatibilità con ABI.

Passaggio 2: utilizza lo spazio dei nomi jax_plugins o configura entry_point

Esistono due opzioni per consentire a JAX di trovare il tuo plug-in.

  1. Utilizzo di pacchetti dello spazio dei nomi (ref). Definisci un modulo univoco a livello globale nel pacchetto dello spazio dei nomi jax_plugins (ad esempio, crea una directory jax_plugins e definisci il modulo in basso). Ecco un esempio di struttura di directory:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Utilizzo dei metadati del pacchetto (ref). Se crei un pacchetto tramite pyproject.toml o setup.py, pubblicizza il nome del modulo del plug-in includendo un punto di ingresso sotto il gruppo jax_plugins che rimanda al nome completo del modulo. Ecco un esempio tramite pyproject.toml o setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Ecco alcuni esempi di come openxla-pjrt-plugin viene implementato utilizzando l'opzione 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

Passaggio 3: implementa un metodo startize()

Per registrare il plug-in, devi implementare un metodo startize() nel modulo Python, ad esempio:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Consulta questa pagina per informazioni sull'utilizzo di xla_bridge.register_plugin. Attualmente è un metodo privato. In futuro verrà rilasciata un'API pubblica.

Puoi eseguire la riga seguente per verificare che il plug-in sia registrato e segnalare un errore se non può essere caricato.

jax.config.update("jax_platforms", "my_plugin")

JAX potrebbe avere più backend/plugin. Esistono alcune opzioni per assicurarsi che il plug-in venga utilizzato come backend predefinito:

  • Opzione 1: esegui jax.config.update("jax_platforms", "my_plugin") all'inizio del programma.
  • Opzione 2: imposta ENV JAX_PLATFORMS=my_plugin.
  • Opzione 3: imposta una priorità sufficientemente elevata quando chiami xb.register_plugin (il valore predefinito è 400, maggiore di quello degli altri backend esistenti). Tieni presente che il backend con la massima priorità verrà utilizzato solo quando JAX_PLATFORMS=''. Il valore predefinito di JAX_PLATFORMS è '', ma a volte viene sovrascritto.

Come eseguire il test con JAX

Alcuni scenari di test di base da provare:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

Presto aggiungeremo le istruzioni per eseguire i test delle unità jax sul tuo plug-in.

Esempio: plug-in JAX CUDA

  1. Implementazione dell'API PJRT C tramite wrapper (pjrt_c_api_gpu.h).
  2. Configura il punto di ingresso per il pacchetto (setup.py).
  3. Implementa un metodo startize() (__init__.py).
  4. Può essere testato con qualsiasi test jax per CUDA. ```