The hyperparameters for a Bert Classifier.
Inherits From: BaseHParams
mediapipe_model_maker.text_classifier.BertHParams(
    learning_rate: float = 3e-05,
    batch_size: int = 48,
    epochs: int = 2,
    steps_per_epoch: Optional[int] = None,
    class_weights: Optional[Mapping[int, float]] = None,
    shuffle: bool = False,
    repeat: bool = False,
    export_dir: str = tempfile.mkdtemp(),
    distribution_strategy: str = 'off',
    num_gpus: int = 0,
    tpu: str = '',
    end_learning_rate: float = 0.0,
    optimizer: mediapipe_model_maker.text_classifier.BertOptimizer = mediapipe_model_maker.text_classifier.BertHParams.optimizer,
    weight_decay: float = 0.01,
    desired_precisions: Sequence[float] = dataclasses.field(default_factory=list),
    desired_recalls: Sequence[float] = dataclasses.field(default_factory=list),
    gamma: float = 2.0,
    tokenizer: mediapipe_model_maker.text_classifier.SupportedBertTokenizers = mediapipe_model_maker.text_classifier.BertHParams.tokenizer,
    checkpoint_frequency: int = 0
)
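For orientation, the sketch below shows one way these hyperparameters are typically passed into the text classifier workflow. It assumes the mediapipe_model_maker text_classifier API (SupportedModels, TextClassifierOptions, TextClassifier.create); train_data and validation_data are placeholders for datasets you load yourself.

from mediapipe_model_maker import text_classifier

# Override a few hyperparameters; everything else keeps its default.
hparams = text_classifier.BertHParams(
    learning_rate=3e-5,
    batch_size=48,
    epochs=3,
    export_dir="bert_exported_models",
)

# Bundle the hyperparameters with a supported BERT model and train.
options = text_classifier.TextClassifierOptions(
    supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
    hparams=hparams,
)
model = text_classifier.TextClassifier.create(
    train_data=train_data,            # placeholder dataset
    validation_data=validation_data,  # placeholder dataset
    options=options,
)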
Attributes

learning_rate: Learning rate to use for gradient descent training.
end_learning_rate: End learning rate for linear decay. Defaults to 0.
batch_size: Batch size for training. Defaults to 48.
epochs: Number of training iterations over the dataset. Defaults to 2.
optimizer: Optimizer to use for training. Supported values are defined in the BertOptimizer enum: ADAMW and LAMB.
weight_decay: Weight decay of the optimizer. Defaults to 0.01.
desired_precisions: If specified, adds a RecallAtPrecision metric per desired_precisions[i] entry which tracks the recall given the constraint on precision. Only supported for binary classification.
desired_recalls: If specified, adds a PrecisionAtRecall metric per desired_recalls[i] entry which tracks the precision given the constraint on recall. Only supported for binary classification.
gamma: Gamma parameter for focal loss. To use cross entropy loss, set this value to 0. Defaults to 2.0. See the sketch after this list.
tokenizer: Tokenizer to use for preprocessing. Must be one of the enum options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER.
checkpoint_frequency: Frequency (in epochs) of saving checkpoints during training. Defaults to 0, which does not save training checkpoints.
steps_per_epoch, class_weights, shuffle, repeat, export_dir, distribution_strategy, num_gpus, tpu: Dataclass fields inherited from BaseHParams.
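A short illustrative sketch of the loss- and metric-related fields above (arbitrary values, only meaningful for a binary classifier): setting gamma to 0 switches from focal loss to plain cross entropy, and each desired_precisions / desired_recalls entry adds the corresponding constrained metric.

from mediapipe_model_maker import text_classifier

hparams = text_classifier.BertHParams(
    gamma=0.0,                 # 0 disables focal loss in favor of cross entropy
    desired_precisions=[0.9],  # track recall at a precision constraint of 0.9
    desired_recalls=[0.8],     # track precision at a recall constraint of 0.8
    checkpoint_frequency=1,    # save a training checkpoint every epoch
)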
Methods
get_strategy
get_strategy()
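A hedged sketch of get_strategy(): the distribution fields (distribution_strategy, num_gpus, tpu) are inherited from BaseHParams, and get_strategy() is assumed here to resolve them into a tf.distribute strategy object; the 'mirrored' value is an assumption about accepted strategy names, not confirmed by this page.

from mediapipe_model_maker import text_classifier

# Assumption: 'mirrored' is an accepted distribution_strategy value.
hparams = text_classifier.BertHParams(
    distribution_strategy="mirrored",
    num_gpus=2,
)
strategy = hparams.get_strategy()  # assumed to return a tf.distribute strategy
print(type(strategy))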
__eq__
__eq__(other)
Class Variables

batch_size: 48
checkpoint_frequency: 0
class_weights: None
distribution_strategy: 'off'
end_learning_rate: 0.0
epochs: 2
export_dir: '/tmpfs/tmp/tmpnt_h4p9w' (a temporary directory from tempfile.mkdtemp(); the exact path varies per run)
gamma: 2.0
learning_rate: 3e-05
num_gpus: 0
optimizer: <BertOptimizer.ADAMW: 'adamw'>
repeat: False
shuffle: False
steps_per_epoch: None
tokenizer: <SupportedBertTokenizers.FULL_TOKENIZER: 'fulltokenizer'>
tpu: ''
weight_decay: 0.01