PyTorch-Modelle in LiteRT konvertieren

AI Edge Torch ist eine Bibliothek, mit der Sie PyTorch-Modelle in ein .tflite-Format konvertieren können, um sie mit LiteRT und MediaPipe auszuführen. Das ist besonders hilfreich für Entwickler, die mobile Apps erstellen, in denen Modelle vollständig auf dem Gerät ausgeführt werden. AI Edge Torch bietet eine breite CPU-Abdeckung mit anfänglicher GPU- und NPU-Unterstützung.

Wenn Sie PyTorch-Modelle in LiteRT konvertieren möchten, verwenden Sie die PyTorch-Konvertierungs-Schnellstartanleitung. Weitere Informationen finden Sie im GitHub-Repository für AI Edge Torch.

Wenn Sie speziell Large Language Models (LLMs) oder transformerbasierte Modelle konvertieren, verwenden Sie die Generative Torch API. Diese API verarbeitet transformerspezifische Konvertierungsdetails wie das Erstellen und Quantisieren von Modellen.

Conversion-Workflow

Die folgenden Schritte veranschaulichen eine einfache End-to-End-Konvertierung eines PyTorch-Modells in LiteRT.

AI Edge Torch importieren

Importieren Sie zuerst das Pip-Paket „AI Edge Torch“ (ai-edge-torch) und dann PyTorch.

import ai_edge_torch
import torch

Für dieses Beispiel sind außerdem die folgenden Pakete erforderlich:

import numpy
import torchvision

Modell initialisieren und konvertieren

Wir konvertieren ResNet18, ein beliebtes Modell für die Bilderkennung.

resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()

Verwenden Sie die Methode convert aus der AI Edge Torch-Bibliothek, um das PyTorch-Modell zu konvertieren.

sample_input = (torch.randn(1, 3, 224, 224),)
edge_model = ai_edge_torch.convert(resnet18.eval(), sample_input)

Modell verwenden

Nachdem Sie das PyTorch-Modell konvertiert haben, können Sie mit dem neuen konvertierten LiteRT-Modell Inferenzen ausführen.

output = edge_model(*sample_inputs)

Sie können das konvertierte Modell für die spätere Verwendung im .tflite-Format exportieren und speichern.

edge_model.export('resnet.tflite')