AI Edge Torch Generative API

Our Generative API library provides PyTorch-native building blocks for composing Transformer models such as Gemma, TinyLlama, and others using mobile-friendly abstractions, through which we can guarantee conversion and performant execution on our mobile runtime, TensorFlow Lite.

Before proceeding, please note:

  • This is only v0.1 of the API, an early developer preview in the interest of developing openly in the community.
  • The API is unstable, and we expect it to change over the next several months.
  • The library is in early development. Please expect rough edges. Some known issues are listed below.

System Overview

The system is designed to help ML practitioners deploy their trained Large Language Models on mobile devices using the TFLite runtime. It assumes the user already has a trained model they are happy with, and is optimized for mobile inference.

  • Start with a trained PyTorch Large Language Model. You can choose any off-the-shelf model from huggingface.co or kaggle.com, or bring your own PyTorch model.
  • Re-author the model using the Edge Generative API. If our examples already include your model, you can reuse that example and save time.
  • Quantize the model using our Quantization APIs. This is critical for reducing model size and achieving reasonable performance.
  • Verify the model implementation and quality using your model evaluation pipeline, including the pre- and post-processing steps of the LLM pipeline.
  • Convert the model to get a TFLite FlatBuffer representing the mobile model.
  • Choose one of the approaches below to deploy the end-to-end LLM inference pipeline.

Model Authoring using Edge Generative API

The library provides basic building blocks for common transformer models (encoder-only, decoder-only, or encoder-decoder). If you are a mobile app developer who wants to integrate LLMs or other transformer models into your Android or iOS app, you can re-author your PyTorch Large Language Model using these layers.

See our examples, which explain in detail how to re-compose popular architectures like Gemma, TinyLlama, and Phi-2 using the library. To do so, you need to have an understanding of the model structure (attention mechanism used, MLP layers) and also be familiar with writing PyTorch code. Our examples should help you get familiar with the process.

Quantization

Quantization is done through the API exposed in the quantize module. To apply quantization, you create a configuration that fully expresses how the model should be quantized. This configuration is then passed to conversion, which generates a quantized model.

quant_recipes.py contains a list of recipes that are known to be well supported at runtime. For the average user, this is a good starting point for selecting the quantization scheme best suited to your deployment needs. After identifying the target recipe, you can quantize the model as follows. This example is extracted from generative/examples/quantize/example.py.

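import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes

# `model` is the re-authored nn.Module; `tokens` and `input_pos` are its sample inputs.
quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
    model, (tokens, input_pos), quant_config=quant_config
)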

Once converted, you will get a quantized .tflite model which will be ready for on-device execution.

Supported schemes

In the current release, the following schemes are supported:

  • Dynamic range quantization with FP32 activations and INT8 weights for linear ops
  • FP16 quantization with FP16 weights and FP32 activations and computation for all ops

These correspond to the recipes available in quant_recipes.py.
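It can help to see what the two schemes do to the weights numerically. The pure-Python sketch below assumes symmetric INT8 quantization with one scale per output row for the dynamic range scheme. The recipe's actual granularity and rounding may differ. The helpers `quantize_int8_per_channel`, `dequantize_int8`, and `to_fp16` are hypothetical and are not part of the library. In both schemes, activations stay in FP32. Only the stored weights change.

```python
import struct
from typing import List, Tuple


def quantize_int8_per_channel(
    weight: List[List[float]],
) -> Tuple[List[List[int]], List[float]]:
    """Symmetric INT8 quantization with one scale per output row (hypothetical helper)."""
    q_rows: List[List[int]] = []
    scales: List[float] = []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        # Map the largest magnitude in the row to 127; avoid dividing by zero.
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        # Round to nearest and clamp into the symmetric range [-127, 127].
        q_rows.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales


def dequantize_int8(q_rows: List[List[int]], scales: List[float]) -> List[List[float]]:
    """Recover approximate FP32 weights as q * scale."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


def to_fp16(x: float) -> float:
    """Round-trip a float through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


weight = [[0.5, -1.27, 0.01], [2.0, 0.25, -0.75]]
q, scales = quantize_int8_per_channel(weight)
approx = dequantize_int8(q, scales)
fp16 = [[to_fp16(w) for w in row] for row in weight]
print(q)       # INT8 values that would be stored in the model
print(approx)  # what the weights effectively become after dequantization
print(fp16)    # the same weights under the FP16 scheme
```

Storing each weight in one byte instead of four accounts for most of the roughly 4x weight-size reduction of INT8 weights. The per-row scales add a small overhead. FP16 halves weight storage.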

Convert PyTorch LLM to a TFLite model

Once you re-author the model and validate its numerical accuracy, you can convert the nn.Module to TFLite format. For LLMs, there are usually two entry functions (signatures) to export: prefill and decode. These two signatures differ only in the shapes of their arguments.
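One way to see why a single model can serve both signatures is to look at the KV cache. For illustration, the pure-Python sketch below assumes a fixed-capacity cache that the model writes into at the positions listed in `input_pos`. This is similar in spirit to an index copy along the sequence axis. The `KVCache` class is hypothetical and uses nested lists in place of tensors. Prefill writes a whole prompt's worth of positions in one call, and decode writes a single position, so only the argument shapes change.

```python
from typing import List, Sequence


class KVCache:
    """Hypothetical fixed-capacity cache holding one K and one V vector per position."""

    def __init__(self, max_len: int, dim: int) -> None:
        self.keys: List[List[float]] = [[0.0] * dim for _ in range(max_len)]
        self.values: List[List[float]] = [[0.0] * dim for _ in range(max_len)]
        self.filled = [False] * max_len

    def update(
        self,
        input_pos: Sequence[int],
        k: Sequence[Sequence[float]],
        v: Sequence[Sequence[float]],
    ) -> None:
        """Write the i-th K/V vector at position input_pos[i]."""
        if not (len(input_pos) == len(k) == len(v)):
            raise ValueError("input_pos, k and v must have the same length")
        for pos, k_vec, v_vec in zip(input_pos, k, v):
            self.keys[pos] = list(k_vec)
            self.values[pos] = list(v_vec)
            self.filled[pos] = True

    def attendable(self, query_pos: int) -> int:
        """Count cached positions a causal query at query_pos can attend to."""
        return sum(self.filled[: query_pos + 1])


cache = KVCache(max_len=16, dim=2)

# Prefill: a 5-token prompt written at positions 0..4 in one call.
prefill_pos = list(range(5))
cache.update(
    prefill_pos,
    [[float(p), 0.0] for p in prefill_pos],
    [[0.0, float(p)] for p in prefill_pos],
)

# Decode: one new token written at the next position, 5.
cache.update([5], [[5.0, 0.0]], [[0.0, 5.0]])
print(cache.attendable(5))  # the decode step sees all 6 cached positions
```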

For example, in the generative/examples/test_models/toy_model_with_kv_cache.py, you can define inputs for both signatures:

Sample inputs for the prefill signature:

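from typing import Tuple
import torch

def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
    idx = torch.unsqueeze(torch.arange(0, 100), 0)  # 100 token IDs, shape [1, 100]
    input_pos = torch.arange(0, 100)  # their positions 0..99, shape [100]
    return idx, input_pos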

Sample inputs for the decode signature:

def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
    idx = torch.tensor([[1]], dtype=torch.long)  # a single token ID, shape [1, 1]
    input_pos = torch.tensor([10], dtype=torch.int64)  # its position, shape [1]
    return idx, input_pos

Then export the model to TFLite with:

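import ai_edge_torch

# `model` is the re-authored toy model; sample inputs come from the helpers above.
idx, input_pos = get_sample_prefill_inputs()
decode_idx, decode_input_pos = get_sample_decode_inputs()

print('converting toy model to tflite with 2 signatures (prefill + decode)')
edge_model = (
    ai_edge_torch.signature('prefill', model, (idx, input_pos))
    .signature('decode', model, (decode_idx, decode_input_pos))
    .convert()
)
edge_model.export('/tmp/toy_kv_cache.tflite')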

Please note that following the prefill and decode signature conventions is required for easy integration with the MediaPipe LLM Inference API.

End-to-End Inference Pipeline

The model files typically perform only the core ML computation in the LLM pipeline. Deploying the full pipeline requires handling tokenization, sampling, and any other pre- or post-processing steps your system needs. There are two ways to deploy the converted LLMs on device as part of a full LLM inference pipeline.

Use TFLite Runtime APIs

With this approach, you implement the entire LLM pipeline yourself and call the TFLite Runtime APIs directly to invoke the model. A text generation pipeline typically requires a tokenizer/detokenizer and a sampler in addition to model inference. The tokenizer converts the input text from a string to a list of integer token IDs. The prefill signature ingests the sequence of input tokens, and the decode signature is invoked to obtain a tensor of logits. The sampler selects a token based on those logits, and the decode step is repeated autoregressively. Finally, the detokenizer maps the generated tokens back into human-readable text.
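The loop described above can be sketched in plain Python. This is an illustration under stated assumptions, not the reference pipeline. `prefill` and `decode` are hypothetical callables that stand in for the two model signatures; in an app, you would invoke those signatures through the TFLite runtime. The tokenizer is a toy whitespace vocabulary rather than SentencePiece, and sampling is greedy. The sketch also assumes that prefill only fills the KV cache and that decoding starts from the last prompt token. `ScriptedModel` is a hypothetical fake that lets the loop run without a real model.

```python
from typing import Callable, Dict, List, Sequence

# Toy vocabulary and whitespace tokenizer (hypothetical stand-ins for a real tokenizer).
VOCAB = ["<eos>", "hello", "world", "!"]
TOKEN_IDS: Dict[str, int] = {w: i for i, w in enumerate(VOCAB)}
EOS_ID = TOKEN_IDS["<eos>"]


def tokenize(text: str) -> List[int]:
    return [TOKEN_IDS[w] for w in text.split()]


def detokenize(ids: Sequence[int]) -> str:
    return " ".join(VOCAB[i] for i in ids)


def greedy_sample(logits: Sequence[float]) -> int:
    """Return the index of the highest logit (the first one on ties)."""
    return max(range(len(logits)), key=logits.__getitem__)


def generate(
    prompt_ids: List[int],
    prefill: Callable[[List[int], List[int]], None],
    decode: Callable[[List[int], List[int]], List[float]],
    sample: Callable[[Sequence[float]], int],
    max_new_tokens: int,
    eos_id: int = EOS_ID,
) -> List[int]:
    if not prompt_ids:
        raise ValueError("prompt must contain at least one token")
    # Prefill every prompt token except the last; here this only populates the KV cache.
    prefill(prompt_ids[:-1], list(range(len(prompt_ids) - 1)))
    # Start decoding from the last prompt token at its own position.
    next_id, pos = prompt_ids[-1], len(prompt_ids) - 1
    generated: List[int] = []
    for _ in range(max_new_tokens):
        logits = decode([next_id], [pos])  # one token in, next-token logits out
        next_id = sample(logits)
        if next_id == eos_id:
            break
        generated.append(next_id)  # a streaming UI could emit detokenize([next_id]) here
        pos += 1
    return generated


class ScriptedModel:
    """Hypothetical fake for the two signatures: replays a fixed continuation."""

    def __init__(self, continuation: List[int], vocab_size: int) -> None:
        self.continuation = continuation
        self.vocab_size = vocab_size
        self.prefill_positions: List[int] = []
        self.decode_positions: List[int] = []

    def prefill(self, ids: List[int], input_pos: List[int]) -> None:
        self.prefill_positions = list(input_pos)

    def decode(self, ids: List[int], input_pos: List[int]) -> List[float]:
        step = len(self.decode_positions)  # number of earlier decode calls
        self.decode_positions.extend(input_pos)
        target = self.continuation[min(step, len(self.continuation) - 1)]
        return [1.0 if i == target else 0.0 for i in range(self.vocab_size)]


fake_model = ScriptedModel(tokenize("! hello world <eos>"), vocab_size=len(VOCAB))
out = generate(tokenize("hello world"), fake_model.prefill, fake_model.decode,
               greedy_sample, max_new_tokens=10)
print(detokenize(out))  # prints: ! hello world
```

Keeping `sample` as a parameter lets you swap greedy decoding for another strategy without touching the loop.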

This approach gives you the most control. For example, you can implement streaming, manage system memory more precisely, or implement advanced features such as constrained grammar decoding and speculative decoding.

A very simple text generation pipeline based on a decoder-only transformer is provided here for reference. Note that this example serves only as a starting point. You are expected to implement your own pipeline based on your model's specific requirements.
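The sampler is one component you are likely to customize. The sketch below implements temperature-scaled top-k sampling in pure Python. It is not the sampler used by the reference pipeline or by MediaPipe. `top_k_sample` is a hypothetical helper that could replace greedy argmax in a custom pipeline.

```python
import math
import random
from typing import List, Optional


def top_k_sample(
    logits: List[float],
    k: int,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample a token ID from the k highest logits after temperature scaling."""
    if k < 1 or temperature <= 0:
        raise ValueError("k must be >= 1 and temperature must be > 0")
    rng = rng or random.Random()
    # Indices of the k largest logits (stable sort keeps the earlier index on ties).
    top = sorted(range(len(logits)), key=logits.__getitem__, reverse=True)[:k]
    scaled = [logits[i] / temperature for i in top]
    # Softmax weights over the kept logits; subtracting the max avoids overflow.
    m = max(scaled)
    weights = [math.exp(x - m) for x in scaled]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [2.0, 1.0, 0.5, -3.0]
samples = [top_k_sample(logits, k=2, temperature=0.8, rng=rng) for _ in range(1000)]
# Only IDs 0 and 1 (the two highest logits) can ever be drawn.
print(sorted(set(samples)))
```

With `k=1`, this reduces to greedy decoding. Lower temperatures concentrate probability on the top logit, and higher ones flatten the distribution among the k candidates.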

Use MediaPipe LLM Inference API

The MediaPipe LLM Inference API is a high-level API that supports LLM inference through a prompt-in/prompt-out interface. It supports some models out of the box. You can also provide it with LLMs converted via our Generative API, which gives you a simple high-level interface with Java and Swift bindings for easy integration into mobile apps. It handles the complexity of implementing the LLM pipeline under the hood and makes deployment much easier. Unless you need explicit control over the pipeline, we recommend using it for robustness and ease of use.

To deploy using the MediaPipe LLM Inference API, you need to:

  • Convert your model using the prefill and decode signature convention shown in the examples. The pipeline supports only SentencePiece tokenizers, but it can support a wide variety of models.
  • Bundle the converted TFLite file with additional configuration such as start/stop tokens and the tokenizer model. See here.
  • Once the bundle is created, invoke the pipeline using the mobile APIs here.

Model visualization

Install the Model Explorer package using the following command:

pip install ai-edge-model-explorer

Detailed installation instructions can be found here.

Visualize the model using the CLI

model-explorer 'gemma_seq512_kv1024.tflite'

Gemma-2b visualization demo

For an end-to-end example showing how to author, convert, quantize, and execute a model, please refer to the steps here.

What to expect

Future Roadmap

  • Expanded acceleration support on mobile and web GPUs, and on mobile NPUs.
  • Advanced quantization approaches suitable for LLMs.
  • Expanded support of models, including Diffusion models.
  • LoRA support.

Known Issues

The following are known product issues we are actively working to fix.

  • The conversion and serialization process is not yet optimized for LLMs. It requires keeping multiple copies of the weights in memory for transformations and for serialization/deserialization.
  • Runtime execution of the LLM in TFLite is missing some memory optimizations and is inefficient when unpacking memory on XNNPack.