TensorLabel is a utility wrapper for TensorBuffers that attaches meaningful labels to an axis.
For example, an image classification model may have an output tensor with shape {1, 10},
where 1 is the batch size and 10 is the number of categories. On the 2nd axis, each sub-tensor
can be labeled with the name or description of the corresponding category. TensorLabel
converts the plain Tensor in a TensorBuffer into a map from predefined labels to sub-tensors.
In this case, given 10 labels for the 2nd axis, TensorLabel converts the original {1, 10} Tensor
into a 10-element map, each value of which is a Tensor with shape {} (a scalar). Usage example:
    TensorBuffer outputTensor = ...;
    List<String> labels = FileUtil.loadLabels(context, labelFilePath);
    // Labels the first axis with size greater than one.
    TensorLabel labeled = new TensorLabel(labels, outputTensor);
    // If each sub-tensor effectively has size 1, we can directly get a float value.
    Map<String, Float> probabilities = labeled.getMapWithFloatValue();
    // Or get sub-tensors, when each sub-tensor has more than one element.
    Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
Note: currently, tensor-to-map conversion is only supported for the first axis with size greater than 1.
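The label-to-scalar conversion described above can be pictured with a small plain-Java sketch. It does not use the library: the class `LabelMapSketch` and the method `toFloatMap` are hypothetical names, and a `float[]` stands in for the flattened {1, N} tensor. It only illustrates the idea of pairing each label with the value at the same index.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Illustration only; not part of the TensorFlow Lite Support Library. */
class LabelMapSketch {

    /**
     * Pairs each label with the score at the same index of a flattened {1, N} tensor.
     * Rejecting a size mismatch here is an assumption modeled on the documented
     * requirement that the label count matches the axis size.
     */
    static Map<String, Float> toFloatMap(List<String> labels, float[] scores) {
        if (labels.size() != scores.length) {
            throw new IllegalArgumentException(
                "Expected " + scores.length + " labels but got " + labels.size());
        }
        // LinkedHashMap keeps the labels in their original order.
        Map<String, Float> result = new LinkedHashMap<>();
        for (int i = 0; i < scores.length; i++) {
            result.put(labels.get(i), scores[i]);
        }
        return result;
    }
}
```

For instance, calling `LabelMapSketch.toFloatMap` with the labels "cat", "dog", "bird" and the scores 0.7, 0.2, 0.1 yields a three-entry map in which "cat" maps to 0.7.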
Public Constructors
TensorLabel(Map<Integer, List<String>> axisLabels, TensorBuffer tensorBuffer) | Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors. |
---|---|
TensorLabel(List<String> axisLabels, TensorBuffer tensorBuffer) | Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors. |
Public Methods
List<Category> | getCategoryList() Gets a list of Category from the TensorLabel object. |
---|---|
Map<String, Float> | getMapWithFloatValue() Gets a map that maps label to float. |
Map<String, TensorBuffer> | getMapWithTensorBuffer() Gets the map with a pair of the label and the corresponding TensorBuffer. |
Public Constructors
public TensorLabel (Map<Integer, List<String>> axisLabels, TensorBuffer tensorBuffer)
Creates a TensorLabel object that can label the axes of a multi-dimensional tensor.
Parameters
axisLabels | A map whose keys are axis ids (starting from 0) and whose values are the corresponding labels. Note: The number of labels must equal the size of the tensor on that axis. |
---|---|
tensorBuffer | The TensorBuffer to be labeled. |
Throws
NullPointerException | if axisLabels or tensorBuffer is null, or any value in axisLabels is null. |
---|---|
IllegalArgumentException | if any key in axisLabels is out of range (compared to the shape of tensorBuffer), or any value (labels) has a different size from the tensorBuffer on the given dimension. |
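The documented preconditions can be sketched as a standalone check. This is a minimal illustration, not the library's code: `AxisLabelValidationSketch` and `validate` are hypothetical names, an `int[]` stands in for the shape of the TensorBuffer, and the order in which the checks run is an assumption.

```java
import java.util.List;
import java.util.Map;

/** Illustration only; mirrors the documented exceptions, not the library's internals. */
class AxisLabelValidationSketch {

    /** Validates axis labels against a tensor shape, as described in the Throws table. */
    static void validate(Map<Integer, List<String>> axisLabels, int[] shape) {
        if (axisLabels == null || shape == null) {
            throw new NullPointerException("axisLabels and shape must not be null");
        }
        for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
            int axis = entry.getKey();
            List<String> labels = entry.getValue();
            if (labels == null) {
                throw new NullPointerException("Labels for axis " + axis + " are null");
            }
            // The axis id must refer to an existing dimension of the tensor.
            if (axis < 0 || axis >= shape.length) {
                throw new IllegalArgumentException("Axis " + axis + " is out of range");
            }
            // The number of labels must equal the tensor size on that axis.
            if (labels.size() != shape[axis]) {
                throw new IllegalArgumentException(
                    "Axis " + axis + " has size " + shape[axis]
                        + " but " + labels.size() + " labels were given");
            }
        }
    }
}
```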
public TensorLabel (List<String> axisLabels, TensorBuffer tensorBuffer)
Creates a TensorLabel object that can label one axis of a multi-dimensional tensor.
Note: The labels are applied to the first axis whose size is larger than 1. For example, if the shape of the tensor is [1, 10, 3], the labels will be applied to axis 1 (ids start from 0), and the size of axisLabels must be 10 as well.
Parameters
axisLabels | A list of labels, whose size must equal the size of the tensor on the to-be-labeled axis. |
---|---|
tensorBuffer | The TensorBuffer to be labeled. |
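The axis selection rule in the note above can be sketched in a few lines of plain Java. The class `FirstAxisSketch` and its method are hypothetical names used for illustration. Returning -1 when no axis qualifies is an assumption of this sketch, since this page does not describe that case.

```java
/** Illustration only; not part of the TensorFlow Lite Support Library. */
class FirstAxisSketch {

    /**
     * Returns the id (starting from 0) of the first axis whose size is larger than 1,
     * or -1 if there is none (an assumption of this sketch).
     */
    static int firstAxisWithSizeGreaterThanOne(int[] shape) {
        for (int axis = 0; axis < shape.length; axis++) {
            if (shape[axis] > 1) {
                return axis;
            }
        }
        return -1;
    }
}
```

With the shape [1, 10, 3] from the note, the sketch returns axis 1, so the label list would need 10 entries.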
Public Methods
public List<Category> getCategoryList ()
Gets a list of Category from the TensorLabel object.
The labeled axis should be effectively the last axis (meaning every sub-tensor specified by this axis has a flat size of 1), so that each labeled sub-tensor can be converted into a float score. Example: A TensorLabel with shape {2, 5, 3} and axis 2 is valid. If the axis is 1 or 0, it cannot be converted into a Category.
getMapWithFloatValue() is an alternative that returns a Map as the result.
Throws
IllegalStateException | if the size of the sub-tensor for any label is not 1. |
---|
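One way to read the "effectively the last axis" condition is sketched below. It assumes that the sub-tensor for a label spans the axes after the labeled one, so its flat size is the product of those trailing dimensions. That reading matches the {2, 5, 3} example but is an interpretation, not the library's code, and the names `SubTensorSizeSketch`, `subTensorFlatSize` and `isEffectivelyLastAxis` are hypothetical.

```java
/** Illustration only; one interpretation of the documented condition. */
class SubTensorSizeSketch {

    /** Product of the dimensions that come after the labeled axis. */
    static int subTensorFlatSize(int[] shape, int axis) {
        int size = 1;
        for (int i = axis + 1; i < shape.length; i++) {
            size *= shape[i];
        }
        return size;
    }

    /** True when every labeled sub-tensor holds exactly one element. */
    static boolean isEffectivelyLastAxis(int[] shape, int axis) {
        return subTensorFlatSize(shape, axis) == 1;
    }
}
```

For shape {2, 5, 3}, axis 2 gives a sub-tensor flat size of 1, while axis 1 gives 3 and axis 0 gives 15, so only axis 2 passes the check.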
public Map<String, Float> getMapWithFloatValue ()
Gets a map that maps each label to a float. Mapping is only allowed on the first axis with size greater than 1, and that axis should be effectively the last axis (meaning every sub-tensor specified by this axis has a flat size of 1).
getCategoryList() is an alternative API to get the result.
Throws
IllegalStateException | if the size of the sub-tensor for any label is not 1. |
---|
public Map<String, TensorBuffer> getMapWithTensorBuffer ()
Gets a map from each label to the corresponding TensorBuffer. Currently, mapping is only allowed on the first axis with size greater than 1.
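To illustrate what a label-to-sub-tensor map looks like, here is a plain-Java sketch that splits a flattened tensor into equal chunks, one per label. It assumes a row-major layout and that the labeled axis is the first axis with size greater than 1, so each label's data is contiguous. A `float[]` stands in for TensorBuffer, and `SubTensorMapSketch` and `split` are hypothetical names.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Illustration only; float[] stands in for TensorBuffer. */
class SubTensorMapSketch {

    /** Splits a flattened row-major tensor into one contiguous chunk per label. */
    static Map<String, float[]> split(List<String> labels, float[] flat) {
        if (labels.isEmpty() || flat.length % labels.size() != 0) {
            throw new IllegalArgumentException(
                "Flat size " + flat.length + " cannot be divided among "
                    + labels.size() + " labels");
        }
        int subLength = flat.length / labels.size();
        Map<String, float[]> result = new LinkedHashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            // Each label owns the i-th block of subLength consecutive values.
            result.put(
                labels.get(i),
                Arrays.copyOfRange(flat, i * subLength, (i + 1) * subLength));
        }
        return result;
    }
}
```

For a tensor of shape [1, 2, 3] flattened to six values and the labels "a" and "b", the sketch maps "a" to the first three values and "b" to the last three.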