Skip to content

Commit ce15879

Browse files
ziyeqinghan and copybara-github
authored and committed
Add inference_input_type and inference_output_type for quantization in model maker.
PiperOrigin-RevId: 308995760
1 parent 7a763b5 commit ce15879

3 files changed

Lines changed: 32 additions & 6 deletions

File tree

tensorflow_examples/lite/model_maker/core/task/custom_model.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,9 @@ def _export_tflite(self,
143143
tflite_filepath,
144144
quantized=False,
145145
quantization_steps=None,
146-
representative_data=None):
146+
representative_data=None,
147+
inference_input_type=tf.float32,
148+
inference_output_type=tf.float32):
147149
"""Converts the retrained model to tflite format and saves it.
148150
149151
Args:
@@ -153,6 +155,12 @@ def _export_tflite(self,
153155
to run. Used only if `quantized` is True.
154156
representative_data: Representative data used for post-training
155157
quantization. Used only if `quantized` is True.
158+
inference_input_type: Target data type of real-number input arrays. Allows
159+
for a different type for input arrays. Defaults to tf.float32. Must be
160+
`{tf.float32, tf.uint8, tf.int8}`
161+
inference_output_type: Target data type of real-number output arrays.
162+
Allows for a different type for output arrays. Defaults to tf.float32.
163+
Must be `{tf.float32, tf.uint8, tf.int8}`
156164
"""
157165
if tflite_filepath is None:
158166
raise ValueError(
@@ -181,8 +189,8 @@ def _export_tflite(self,
181189
get_representative_dataset_gen(ds, quantization_steps))
182190

183191
converter.optimizations = [tf.lite.Optimize.DEFAULT]
184-
converter.inference_input_type = tf.uint8
185-
converter.inference_output_type = tf.uint8
192+
converter.inference_input_type = inference_input_type
193+
converter.inference_output_type = inference_output_type
186194
converter.target_spec.supported_ops = [
187195
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
188196
]

tensorflow_examples/lite/model_maker/core/task/image_classifier.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,8 @@ def _export_tflite(self,
268268
quantized=False,
269269
quantization_steps=None,
270270
representative_data=None,
271+
inference_input_type=tf.float32,
272+
inference_output_type=tf.float32,
271273
with_metadata=False,
272274
export_metadata_json_file=False):
273275
"""Converts the retrained model to tflite format and saves it.
@@ -281,14 +283,21 @@ def _export_tflite(self,
281283
to run. Used only if `quantized` is True.
282284
representative_data: Representative data used for post-training
283285
quantization. Used only if `quantized` is True.
286+
inference_input_type: Target data type of real-number input arrays. Allows
287+
for a different type for input arrays. Defaults to tf.float32. Must be
288+
`{tf.float32, tf.uint8, tf.int8}`
289+
inference_output_type: Target data type of real-number output arrays.
290+
Allows for a different type for output arrays. Defaults to tf.float32.
291+
Must be `{tf.float32, tf.uint8, tf.int8}`
284292
with_metadata: Whether the output tflite model contains metadata.
285293
export_metadata_json_file: Whether to export metadata in json file. If
286294
True, export the metadata in the same directory as the tflite model. Used
287295
only if `with_metadata` is True.
288296
"""
289297
super(ImageClassifier,
290298
self)._export_tflite(tflite_filepath, quantized, quantization_steps,
291-
representative_data)
299+
representative_data, inference_input_type,
300+
inference_output_type)
292301
if with_metadata:
293302
if not metadata.TFLITE_SUPPORT_TOOLS_INSTALLED:
294303
tf.compat.v1.logging.warning('Needs to install tflite-support package.')

tensorflow_examples/lite/model_maker/core/task/text_classifier.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,9 @@ def _export_tflite(self,
187187
tflite_filepath,
188188
quantized=False,
189189
quantization_steps=None,
190-
representative_data=None):
190+
representative_data=None,
191+
inference_input_type=tf.float32,
192+
inference_output_type=tf.float32):
191193
"""Converts the retrained model to tflite format and saves it.
192194
193195
Args:
@@ -197,11 +199,18 @@ def _export_tflite(self,
197199
to run. Used only if `quantized` is True.
198200
representative_data: Representative data used for post-training
199201
quantization. Used only if `quantized` is True.
202+
inference_input_type: Target data type of real-number input arrays. Allows
203+
for a different type for input arrays. Defaults to tf.float32. Must be
204+
`{tf.float32, tf.uint8, tf.int8}`
205+
inference_output_type: Target data type of real-number output arrays.
206+
Allows for a different type for output arrays. Defaults to tf.float32.
207+
Must be `{tf.float32, tf.uint8, tf.int8}`
200208
"""
201209
# Sets batch size from None to 1 when converting to tflite.
202210
self._set_batch_size(self.model, batch_size=1)
203211
super(TextClassifier,
204212
self)._export_tflite(tflite_filepath, quantized, quantization_steps,
205-
representative_data)
213+
representative_data, inference_input_type,
214+
inference_output_type)
206215
# Sets batch size back to None to support retraining later.
207216
self._set_batch_size(self.model, batch_size=None)

0 commit comments

Comments
 (0)