This is an end-to-end LiteRT.js guide covering the process of converting a PyTorch model to run in the browser with WebGPU acceleration. This example uses ResNet18 as the vision model and TensorFlow.js for pre- and post-processing.
The guide will cover the following steps:
- Convert your PyTorch model to LiteRT using AI Edge Torch.
- Add the LiteRT package to your web app.
- Load the model.
- Write pre- and post-processing logic.
Convert to LiteRT
Use the PyTorch Converter notebook to convert a PyTorch model to the appropriate .tflite format. For an in-depth guide on the types of errors you may encounter and how to fix them, see the AI Edge Torch Converter README.
Your model must be compatible with torch.export.export, which means it must be exportable with TorchDynamo. Therefore, it must not have any Python conditional branches that depend on runtime values within tensors; if torch.export.export raises errors, your model is not exportable. Your model also must not have any dynamic input or output dimensions on its tensors, including the batch dimension.
You can also start with a TensorRT-compatible or ONNX-exportable PyTorch model:
- A TensorRT-compatible version of a model can be a good starting point, since some types of TensorRT conversions also require models to be TorchDynamo-exportable. If you use any NVIDIA / CUDA ops in the model, you will need to replace them with standard PyTorch ops.
- An ONNX-exportable PyTorch model can be a good starting point, though some ONNX models are exported with TorchScript instead of TorchDynamo, in which case the model may not be TorchDynamo-exportable (although it's likely closer than the original model code).
For more information, see Convert PyTorch models to LiteRT.
Add the LiteRT package
Install the @litertjs/core package from npm:
npm install @litertjs/core
Import the package and load its Wasm files:
import {loadLiteRt} from '@litertjs/core';
// They are located in node_modules/@litertjs/core/wasm/
// Serve them statically on your server.
await loadLiteRt(`your/path/to/wasm/`);
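How you serve the Wasm files depends on your setup. The following is a minimal sketch assuming an Express server; the /wasm route, the dist directory, and the port are placeholders for your own configuration, not part of LiteRT.js:

import express from 'express';

const app = express();
// Serve the LiteRT Wasm assets so loadLiteRt('/wasm/') can fetch them.
app.use('/wasm', express.static('node_modules/@litertjs/core/wasm'));
// Serve your built web app (adjust the directory to match your build output).
app.use(express.static('dist'));
app.listen(8080);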
Load the model
Import and initialize LiteRT.js and the LiteRT-TFJS conversion utilities. You also need to import TensorFlow.js to pass tensors to LiteRT.js.
import {CompileOptions, loadAndCompile, loadLiteRt, setWebGpuDevice} from '@litertjs/core';
import {runWithTfjsTensors} from '@litertjs/tfjs-interop';
// TensorFlow.js imports
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgpu'; // Only WebGPU is supported
import {WebGPUBackend} from '@tensorflow/tfjs-backend-webgpu';
async function main() {
  // Initialize TensorFlow.js WebGPU backend
  await tf.setBackend('webgpu');

  // Initialize LiteRT.js's Wasm files
  await loadLiteRt('your/path/to/wasm/');

  // Make LiteRt use the same GPU device as TFJS (for tensor conversion)
  const backend = tf.backend() as WebGPUBackend;
  setWebGpuDevice(backend.device);
  // ...
}

main();
Load the converted LiteRT model:
const model = await loadAndCompile('path_to_model.tflite', {
  accelerator: 'webgpu', // or 'wasm'
});
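If you need to support browsers without WebGPU, one possible pattern is to fall back to the Wasm accelerator when WebGPU compilation fails. This is a sketch rather than documented behavior; it assumes loadAndCompile throws on failure:

let model;
try {
  // Prefer the WebGPU accelerator when it is available.
  model = await loadAndCompile('path_to_model.tflite', {accelerator: 'webgpu'});
} catch (e) {
  // Fall back to the Wasm (CPU) accelerator.
  console.warn('WebGPU compilation failed; falling back to Wasm.', e);
  model = await loadAndCompile('path_to_model.tflite', {accelerator: 'wasm'});
}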
Write the model pipeline
Write the pre- and post-processing logic that connects the model to your app.
Using TensorFlow.js for pre- and post-processing is recommended, but if it is not written in TensorFlow.js, you can call await tensor.data to get the value as an ArrayBuffer or await tensor.array to get a structured JS array.
The following is an example end-to-end pipeline for ResNet18:
// Wrap in a tf.tidy call to automatically clean up intermediate TensorFlow.js tensors.
// (Note: tidy only supports synchronous functions.)
const top5 = tf.tidy(() => {
  // Get RGB data values from an image element and convert them to the range [0, 1].
  const image = tf.browser.fromPixels(dogs, 3).div(255);

  // These preprocessing steps come from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L315
  // The mean and standard deviation for the image normalization come from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_presets.py#L38
  const imageData = image.resizeBilinear([224, 224])
      .sub([0.485, 0.456, 0.406])
      .div([0.229, 0.224, 0.225])
      .reshape([1, 224, 224, 3])
      .transpose([0, 3, 1, 2]);

  // You can pass inputs as a single tensor, an array, or a JS Object
  // where keys are the tensor names in the TFLite model.
  // When passing an Object, the output is also an Object.
  // Here, we're passing a single tensor, so the output is an array.
  const probabilities = runWithTfjsTensors(model, imageData)[0];

  // Get the top five classes.
  return tf.topk(probabilities, 5);
});
const values = await top5.values.data();
const indices = await top5.indices.data();
top5.values.dispose(); // Clean up the tfjs tensors.
top5.indices.dispose();
// Print the top five classes.
const classes = ... // Class names are loaded from a JSON file in the demo.
for (let i = 0; i < 5; ++i) {
  const text = `${classes[indices[i]]}: ${values[i]}`;
  console.log(text);
}
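If your post-processing is not written in TensorFlow.js, you can read raw values from the output tensors instead, as noted above. A minimal sketch, assuming a single output tensor returned by runWithTfjsTensors and a preprocessed input like imageData from the pipeline above:

// imageData: a preprocessed input tensor, as in the pipeline above.
const [output] = runWithTfjsTensors(model, imageData);
const flat = await output.data();    // flat typed array of output values
const nested = await output.array(); // nested JS array matching the tensor shape
output.dispose();                    // clean up the output tensor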
Testing and troubleshooting
Refer to the following sections for ways to test your application and handle errors.
Testing with fake inputs
After loading a model, it's a good idea to test it with fake inputs first. This will catch any runtime errors before you spend time writing the pre- and post-processing logic for your model pipeline. To do this, you can use the LiteRT.js Model Tester or test the model manually.
LiteRT.js Model Tester
The LiteRT.js Model Tester runs your model on GPU and CPU using random inputs to verify that the model runs correctly on GPU. It checks the following:
- Whether the input and output data types are supported.
- Whether all ops are available on GPU.
- How closely the GPU outputs match the reference CPU outputs.
- The performance of GPU inference.
To run the LiteRT.js Model Tester, run npm i @litertjs/model-tester and then npx model-tester. It will open a browser tab for you to run your model.
Manual model testing
If you prefer to manually test the model instead of using the LiteRT.js model
tester (@litertjs/model-tester
), you can generate fake inputs and run the
model with runWithTfjsTensors
.
To generate fake inputs, you need to know the names and shapes of the input
tensors. These can be found with LiteRT.js by calling model.getInputDetails
or
model.getOutputDetails
. A simple way to find them is to set a breakpoint after
the model is created. Alternatively, use Model
Explorer.
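For example, you can log these details at that breakpoint or directly in code (a small sketch based on the calls mentioned above):

// Inspect the model's input and output metadata after loading it.
console.log(model.getInputDetails());  // e.g. names, shapes, and dtypes of the inputs
console.log(model.getOutputDetails()); // e.g. names, shapes, and dtypes of the outputs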
Once you know the input and output shapes and names, you can test the model with a fake input. This gives some confidence that the model will run before you write the rest of the machine learning pipeline, and it verifies that all model operations are supported. For example:
// Imports, initialization, and model loading...
// Create fake inputs for the model
const fakeInputs = model.getInputDetails().map(
    ({shape, dtype}) => tf.ones(shape, dtype));
// Run the model
const outputs = runWithTfjsTensors(model, fakeInputs);
console.log(outputs);
Error types
Some LiteRT models may not be supported by LiteRT.js. Errors usually fall into these categories:
- Shape Mismatch: A known bug that only affects GPU.
- Operation Not Supported: The runtime doesn't support an operation in the model. The WebGPU backend has more limited coverage than CPU, so if you're seeing this error on GPU, you may be able to run the model on CPU instead.
- Unsupported Tensor Type: LiteRT.js only supports int32 and float32 tensors for model inputs and outputs.
- Model Too Large: LiteRT.js is limited in the size of models it can load.
Operation Not Supported
This indicates that the backend being used does not support one of the operations in the model. You will need to rewrite the original PyTorch model to avoid this op and re-convert it, or you may be able to run the model on CPU.
In the case of BROADCAST_TO, this may be solved by making the batch dimension the same for every input tensor to the model. Other cases may be more complicated.
Unsupported Tensor Type
LiteRT.js only supports int32 and float32 tensors for the model's inputs and outputs.
Model Too Large
This usually appears as a call to Aborted()
or a memory allocation failure at
model-loading time. LiteRT.js is limited in the size of models it can load, so
if you're seeing this, your model may be too large. You can try quantizing the
weights with the
ai-edge-quantizer, but
keep computations at float32 or float16, and model inputs and outputs as float32
or int32.