דגמי JAX עם LiteRT

הדף הזה מספק נתיב למשתמשים שרוצים לאמן מודלים ב-JAX ולפרוס לנייד לצורך הסקת מסקנות. השיטות במדריך הזה מפיקות tflite_model אפשר להשתמש בו ישירות עם הדוגמה של קוד התרגום ב-LiteRT, או לשמור אותו קובץ FlatBuffer של tflite.

דוגמה מקצה לקצה מופיעה במדריך למתחילים.

דרישות מוקדמות

מומלץ לנסות את התכונה הזו בגרסת Python החדשה ביותר של TensorFlow בכל לילה חבילה.

pip install tf-nightly --upgrade

נשתמש ב-Orbax ייצוא הספרייה אל לייצא מודלים של JAX. מוודאים שגרסת JAX היא לפחות 0.4.20 ואילך.

pip install jax --upgrade
pip install orbax-export --upgrade

המרת מודלים של JAX ל-LiteRT

אנחנו משתמשים ב-SavedModel של TensorFlow בתור הפורמט הביניים בין JAX ל-LiteRT. לאחר שמירה של Model ניתן להשתמש בממשקי API קיימים של LiteRT להשלמת תהליך ההמרה.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

בדיקת מודל ה-TFLite שהומר

אחרי שהמודל יומר ל-TFLite, תוכלו להריץ ממשקי API של המתורגמן של TFLite כדי לבדוק את הפלט של המודל.

# Run the model with LiteRT
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])