Skip to content

Commit dfec7f9

Browse files
authored
Merge pull request #6 from adiazulay/tflite-patch
Turns on TFLite Support
2 parents 1a5af47 + 57ab0fe commit dfec7f9

2 files changed

Lines changed: 3 additions & 4 deletions

File tree

src/lobe/ImageModel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def load_from_signature(signature: Signature) -> ImageModel:
1717
model_format = signature.format
1818
if model_format == "tf":
1919
from .backends import _backend_tf as backend
20-
elif model_format == "tflite":
20+
elif model_format == "tf_lite":
2121
from .backends import _backend_tflite as backend
2222
else:
2323
raise ValueError("Model is an unsupported format")

src/lobe/backends/_backend_tflite.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@ class ImageClassificationModel:
1616
__MAX_UINT8 = 255
1717

1818
def __init__(self, signature):
19-
self.__model_path = "{}/{}".format(
19+
self.__model_path = "{}/{}.tflite".format(
2020
signature.model_path, signature.filename
2121
)
2222
self.__tflite_predict_fn = None
2323
self.__labels = signature.classes
2424

25-
raise ImportError("TFLite not yet supported")
2625

2726
def __load(self):
2827
self.__tflite_predict_fn = tflite.Interpreter(
@@ -55,7 +54,7 @@ def predict(self, image: Image.Image) -> PredictionResult:
5554
)
5655

5756
confidences_output = self.__tflite_predict_fn.get_tensor(
58-
output_details[1]["index"]
57+
output_details[2]["index"]
5958
)
6059

6160
confidences = np.squeeze(confidences_output)

0 commit comments

Comments
 (0)