PJRT eklentisi entegrasyonu

Arka plan

PJRT, ML ekosistemine eklemek istediğimiz tek tip Device API'dir. Uzun vadeli vizyon şu şekildedir: (1) Çerçeveler (JAX, TF vb.) çerçeveler için opak olan cihaza özel uygulamaları olan PJRT'yi çağırır; (2) her cihaz, PJRT API'lerini uygulamaya odaklanır ve çerçeveler için opak olabilir.

Bu belgede, PJRT ile entegrasyon ve JAX ile PJRT entegrasyonunun nasıl test edileceği ile ilgili öneriler ele alınmaktadır.

PJRT entegrasyonu

1. Adım: PJRT C API arayüzünü uygulayın

Seçenek A: PJRT C API'yi doğrudan uygulayabilirsiniz.

B seçeneği: xla deposunda C++ koduna göre derleme yapabiliyorsanız (çatal veya bazel aracılığıyla) PJRT C++ API'sini uygulayıp C→C++ sarmalayıcıyı da kullanabilirsiniz:

  1. Temel PJRT istemcisinden (ve ilgili PJRT sınıflarından) devralan bir C++ PJRT istemcisini uygulayın. C++ PJRT istemcisi için bazı örnekler şunlardır: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. C++ PJRT istemcisinin parçası olmayan birkaç C API yöntemini uygulayın:
    • PJRT_Client_Create Aşağıda bazı örnek kod verilmiştir (GetPluginPjRtClient hizmetinin yukarıda uygulanan bir C++ PJRT istemcisi döndürdüğü varsayılır):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

PJRT_Client_Create öğesinin, çerçeveden iletilen seçenekleri alabileceğini unutmayın. Bir GPU istemcisinin bu özelliği nasıl kullandığına ilişkin bir örneği burada bulabilirsiniz.

Sarmalayıcı ile kalan C API'lerini uygulamanız gerekmez.

2. Adım: GetPjRtApi'yi uygulayın

PJRT C API uygulamalarına yönelik işlev işaretçileri içeren bir PJRT_Api* döndüren bir GetPjRtApi yöntemi uygulamanız gerekiyor. Aşağıda, sarmalayıcı üzerinden uygulandığı varsayılan bir örnek verilmiştir (pjrt_c_api_cpu.cc'ye benzer şekilde):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

3. Adım: C API uygulamalarını test etme

Temel PJRT C API davranışlarıyla ilgili küçük bir test grubu çalıştırmak için RegisterPjRtCApiTestFactory'yi çağırabilirsiniz.

JAX'tan PJRT eklentisi nasıl kullanılır?

1. Adım: JAX'i kurun

JAX'ı gece kullanabilirsiniz

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

veya kaynaktan JAX oluşturun.

Şimdilik jaxlib sürümünü PJRT C API sürümüyle eşleştirmeniz gerekiyor. Genellikle, eklentinizi oluştururken kullandığınız TF kaydı ile aynı gün içindeki jaxlib gece sürümünü kullanmanız yeterlidir, ör.

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Kaynaktan tam olarak kullandığınız XLA kaydında da bir jaxlib oluşturabilirsiniz (instructions).

ABI uyumluluğunu yakında desteklemeye başlayacağız.

2. Adım: jax_plugins ad alanını kullanın veya giriş_noktasını ayarlayın

Eklentinizin JAX tarafından keşfedilmesi için iki seçenek vardır.

  1. Ad alanı paketlerini kullanma (ref). jax_plugins ad alanı paketi altında genel olarak benzersiz bir modül tanımlayın (yani bir jax_plugins dizini oluşturup modülünüzü onun altında tanımlayın). Aşağıda örnek bir dizin yapısı verilmiştir:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Paket meta verilerini kullanma (ref). pyproject.toml veya setup.py aracılığıyla bir paket oluşturuyorsanız jax_plugins grubu altında tam modül adınıza işaret eden bir giriş noktası ekleyerek eklenti modülünüzün adını tanıtın. pyproject.toml veya setup.py üzerinden alınan bir örneği burada görebilirsiniz:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

2. Seçenek kullanılarak openxla-pjrt-plugin'in nasıl uygulandığına dair örnekler: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

3. Adım: initialize() yöntemini uygulayın

Eklentiyi kaydetmek için python modülünüzde bir initialize() yöntemi uygulamanız gerekir. Örneğin:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

xla_bridge.register_plugin uygulamasını nasıl kullanacağınızı öğrenmek için lütfen buraya göz atın. Bu yöntem şu anda gizli bir yöntemdir. Gelecekte herkese açık bir API yayınlanacaktır.

Eklentinin kayıtlı olduğunu doğrulamak ve yüklenememesi durumunda hata bildirmek için aşağıdaki satırı çalıştırabilirsiniz.

jax.config.update("jax_platforms", "my_plugin")

JAX'ın birden fazla arka ucu/eklentisi olabilir. Eklentinizin varsayılan arka uç olarak kullanıldığından emin olmak için birkaç seçenek vardır:

  • 1. seçenek: jax.config.update("jax_platforms", "my_plugin") uygulamasını programın başında çalıştırma.
  • 2. seçenek: ENV'yi JAX_PLATFORMS=my_plugin ayarlayın.
  • 3. Seçenek: xb.register_plugin'i çağırırken yeterince yüksek bir öncelik belirleyin (varsayılan değer, diğer mevcut arka uçlardan daha yüksek olan 400'dür). En yüksek önceliğe sahip arka ucun yalnızca JAX_PLATFORMS='' durumunda kullanılacağını unutmayın. JAX_PLATFORMS öğesinin varsayılan değeri '' olsa da bazen bu değerin üzerine yazılabilir.

JAX ile nasıl test yapılır?

Deneyebileceğiniz bazı temel test durumları:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(Yakında eklentinize karşı jax birimi testlerini çalıştırmaya ilişkin talimatlar ekleyeceğiz!)

Örnek: JAX CUDA eklentisi

  1. Sarmalayıcı (pjrt_c_api_gpu.h) üzerinden PJRT C API uygulaması.
  2. Paketin giriş noktasını ayarlayın (setup.py).
  3. Bir initialize() yöntemini uygulayın (__init__.py).
  4. CUDA için tüm jax testleriyle test edilebilir. ```