ObjectDetector for building an object detection model.
```python
mediapipe_model_maker.object_detector.ObjectDetector(
    model_spec: mediapipe_model_maker.object_detector.ModelSpec,
    label_names: List[str],
    hparams: mediapipe_model_maker.object_detector.HParams,
    model_options: mediapipe_model_maker.object_detector.ModelOptions
) -> None
```
Methods
create
```python
@classmethod
create(
    train_data: mediapipe_model_maker.object_detector.Dataset,
    validation_data: mediapipe_model_maker.object_detector.Dataset,
    options: mediapipe_model_maker.object_detector.ObjectDetectorOptions
) -> 'ObjectDetector'
```
Creates and trains an ObjectDetector.
Loads data and trains the model based on data for object detection.
| Args | |
|---|---|
| `train_data` | Training data. |
| `validation_data` | Validation data. |
| `options` | Configurations for creating and training the object detector. |
| Returns |
|---|
| An instance of `ObjectDetector`. |
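The end-to-end flow through `create` can be sketched as below. This is a minimal sketch, not the canonical recipe: the dataset directories and the choice of `MOBILENET_V2` are hypothetical placeholders, and the import is deferred so the sketch stays loadable without `mediapipe-model-maker` installed.

```python
def train_object_detector(train_dir: str, val_dir: str, export_dir: str):
    """Sketch of ObjectDetector.create; paths and model choice are
    hypothetical placeholders, not recommendations."""
    # Imported lazily so this sketch can be loaded without the package.
    from mediapipe_model_maker import object_detector

    # Load COCO-format folders into object_detector Datasets.
    train_data = object_detector.Dataset.from_coco_folder(train_dir)
    validation_data = object_detector.Dataset.from_coco_folder(val_dir)

    # Bundle the model choice and hyperparameters into options.
    options = object_detector.ObjectDetectorOptions(
        supported_model=object_detector.SupportedModels.MOBILENET_V2,
        hparams=object_detector.HParams(export_dir=export_dir),
    )
    # Trains a float model and returns the ObjectDetector instance.
    return object_detector.ObjectDetector.create(
        train_data=train_data,
        validation_data=validation_data,
        options=options,
    )
```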
evaluate
```python
evaluate(
    dataset: mediapipe_model_maker.object_detector.Dataset,
    batch_size: int = 1
) -> Tuple[List[float], Dict[str, float]]
```
Overrides Classifier.evaluate to calculate COCO metrics.
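Since `evaluate` returns a `(losses, coco_metrics)` tuple rather than a single score, a small wrapper can make the two parts explicit. This is an illustrative sketch; the exact metric keys in the dict depend on the COCO evaluation backend and are not specified here.

```python
def evaluate_detector(model, dataset, batch_size: int = 8):
    """Unpacks the (losses, coco_metrics) tuple returned by evaluate.

    `model` is assumed to be a trained ObjectDetector (or anything with
    the same evaluate signature); batch_size here is an arbitrary example.
    """
    losses, coco_metrics = model.evaluate(dataset, batch_size=batch_size)
    # losses: List[float]; coco_metrics: Dict[str, float] of COCO-style
    # detection metrics.
    return losses, coco_metrics
```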
export_labels
```python
export_labels(
    export_dir: str,
    label_filename: str = 'labels.txt'
)
```
Exports classification labels into a label file.
| Args | |
|---|---|
| `export_dir` | The directory to save exported files. |
| `label_filename` | File name for the saved label file. The full export path is {export_dir}/{label_filename}. |
export_model
```python
export_model(
    model_name: str = 'model.tflite',
    quantization_config: Optional[mediapipe_model_maker.quantization.QuantizationConfig] = None
)
```
Converts and saves the model to a TFLite file with metadata included.
The model export format is set automatically based on whether `quantization_aware_training` (QAT) was run. The model exports to float32 by default, and to an int8 quantized model if QAT was run. To export a float32 model after running QAT, call `restore_float_ckpt` before this method. For custom post-training quantization without QAT, use the `quantization_config` parameter.
Note that only the TFLite file is needed for deployment. This function also saves a metadata.json file to the same directory as the TFLite file which can be used to interpret the metadata content in the TFLite file.
| Args | |
|---|---|
| `model_name` | File name to save the TFLite model with metadata. The full export path is {self._hparams.export_dir}/{model_name}. |
| `quantization_config` | The configuration for model quantization. Note that int8 quantization aware training is automatically applied when possible. This parameter is used to specify other post-training quantization options such as fp16 and int8 without QAT. |
| Raises | |
|---|---|
| `ValueError` | If a custom quantization_config is specified when the model has quantization aware training enabled. |
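The two export paths described above (default export vs. custom post-training quantization) can be sketched as follows. The file names are examples only, and `model` is assumed to be a trained `ObjectDetector`; the import is deferred so the sketch loads without the package installed.

```python
def export_variants(model, use_fp16: bool = False):
    """Sketch of the export flows for export_model; file names are
    illustrative examples, not fixed conventions."""
    if use_fp16:
        # Deferred import keeps this sketch loadable without the package.
        from mediapipe_model_maker import quantization

        # Custom post-training float16 quantization (no QAT) passed via
        # quantization_config.
        config = quantization.QuantizationConfig.for_float16()
        model.export_model('model_fp16.tflite', quantization_config=config)
    else:
        # Default export: float32, or int8 automatically if QAT was run.
        model.export_model('model.tflite')
```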
export_tflite
```python
export_tflite(
    export_dir: str,
    tflite_filename: str = 'model.tflite',
    quantization_config: Optional[mediapipe_model_maker.quantization.QuantizationConfig] = None,
    preprocess: Optional[Callable[..., bool]] = None
)
```
Converts the model to requested formats.
| Args | |
|---|---|
| `export_dir` | The directory to save exported files. |
| `tflite_filename` | File name to save the TFLite model. The full export path is {export_dir}/{tflite_filename}. |
| `quantization_config` | The configuration for model quantization. |
| `preprocess` | A callable to preprocess the representative dataset for quantization. The callable takes three arguments in order: feature, label, and is_training. |
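The `preprocess` argument's three-argument shape can be sketched as below. This is a hypothetical wrapper, not part of the API: it simply shows a callable with the `(feature, label, is_training)` signature that forwards to whatever preprocessing function the representative dataset needs.

```python
def make_preprocess(preprocessor):
    """Wraps an arbitrary preprocessing function into the
    (feature, label, is_training) shape expected by export_tflite.

    `preprocessor` is a stand-in for the real preprocessing logic; this
    helper is illustrative, not part of mediapipe_model_maker.
    """
    def preprocess(feature, label, is_training):
        # Apply the same preprocessing used during training so the
        # representative dataset matches the model's expected input.
        return preprocessor(feature, label, is_training)
    return preprocess
```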
quantization_aware_training
```python
quantization_aware_training(
    train_data: mediapipe_model_maker.object_detector.Dataset,
    validation_data: mediapipe_model_maker.object_detector.Dataset,
    qat_hparams: mediapipe_model_maker.object_detector.QATHParams
) -> None
```
Runs quantization aware training (QAT) on the model.
The QAT step happens after training a regular float model with the `create` method. This additional step fine-tunes the model at a lower precision in order to mimic the behavior of a quantized model. The resulting quantized model generally performs better than a model quantized without running QAT.
Just like training the float model with the `create` method, the QAT step also requires some manual tuning of hyperparameters. To run QAT more than once, for example for hyperparameter tuning, use the `restore_float_ckpt` method to restore the model state to the trained float checkpoint without having to rerun the `create` method.
| Args | |
|---|---|
| `train_data` | Training dataset. |
| `validation_data` | Validation dataset. |
| `qat_hparams` | Configuration for QAT. |
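A QAT run can be sketched as below. The `QATHParams` values shown are illustrative starting points only, not tuned recommendations; the import is deferred so the sketch loads without the package installed.

```python
def run_qat(model, train_data, validation_data):
    """Sketch of a quantization_aware_training run on a trained
    ObjectDetector; hyperparameter values are illustrative only."""
    # Deferred import keeps this sketch loadable without the package.
    from mediapipe_model_maker import object_detector

    qat_hparams = object_detector.QATHParams(
        learning_rate=0.3,
        batch_size=4,
        epochs=10,
        decay_steps=6,
        decay_rate=0.96,
    )
    # Fine-tunes the float model at lower precision in place.
    model.quantization_aware_training(
        train_data, validation_data, qat_hparams=qat_hparams
    )
```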
restore_float_ckpt
restore_float_ckpt() -> None
Loads a float checkpoint of the model from {hparams.export_dir}/float_ckpt.
The float checkpoint at {hparams.export_dir}/float_ckpt is automatically saved after training an ObjectDetector with the `create` method. This method restores the trained float checkpoint state of the model so that `quantization_aware_training` can be run multiple times. Example usage:
```python
# Train a model
model = object_detector.create(...)
# Run QAT
model.quantization_aware_training(...)
model.evaluate(...)
# Restore the float checkpoint to run QAT again
model.restore_float_ckpt()
# Run QAT with different parameters
model.quantization_aware_training(...)
model.evaluate(...)
```
summary
summary()
Prints a summary of the model.