Skip to content

Commit 445d41a

Browse files
ziyeqinghancopybara-github
authored andcommitted
Add meta_data for image classification in model maker as default
PiperOrigin-RevId: 310107770
1 parent 25a0d86 commit 445d41a

7 files changed

Lines changed: 151 additions & 89 deletions

File tree

setup.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@
5555
with open(REQUIRMENTS) as f:
5656
MODEL_MAKER_REQUIRE = [l.strip() for l in f.read().splitlines() if l.strip()]
5757

58-
METADATA_REQUIRE = [
59-
'tflite-support==0.1.0a0',
60-
]
6158
if sys.version_info.major == 3:
6259
# Packages only for Python 3
6360
pass
@@ -86,7 +83,6 @@
8683
extras_require={
8784
'tests': TESTS_REQUIRE,
8885
'model_maker': MODEL_MAKER_REQUIRE,
89-
'metadata': METADATA_REQUIRE,
9086
},
9187
entry_points={
9288
'console_scripts': [

tensorflow_examples/lite/model_maker/core/task/image_classifier.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,18 @@
2121

2222
import tensorflow.compat.v2 as tf
2323

24+
from tensorflow_examples.lite.model_maker.core.task import metadata_writer_for_image_classifier as metadata_writer
25+
2426
from tensorflow_examples.lite.model_maker.core import compat
2527
from tensorflow_examples.lite.model_maker.core.export_format import ExportFormat
2628
from tensorflow_examples.lite.model_maker.core.task import classification_model
2729
from tensorflow_examples.lite.model_maker.core.task import hub_loader
2830
from tensorflow_examples.lite.model_maker.core.task import image_preprocessing
29-
from tensorflow_examples.lite.model_maker.core.task import metadata
3031
from tensorflow_examples.lite.model_maker.core.task import model_spec as ms
3132
from tensorflow_examples.lite.model_maker.core.task import train_image_classifier_lib
3233

3334
from tensorflow_hub.tools.make_image_classifier import make_image_classifier_lib as hub_lib
35+
from tflite_support import metadata as _metadata # pylint: disable=g-direct-tensorflow-import
3436

3537

3638
def get_hub_lib_hparams(**kwargs):
@@ -120,6 +122,35 @@ def create(train_data,
120122
return image_classifier
121123

122124

125+
def _get_model_info(model_spec, num_classes, quantized=False, version='v1'):
126+
"""Gets the specific info for the image model."""
127+
128+
if not isinstance(model_spec, ms.ImageModelSpec):
129+
raise ValueError('Currently only support models for image classification.')
130+
131+
name = model_spec.name
132+
if quantized:
133+
name += '_quantized'
134+
135+
if quantized and compat.get_tf_behavior() == 1:
136+
image_min = 0
137+
image_max = 255
138+
else:
139+
image_min = 0
140+
image_max = 1
141+
142+
return metadata_writer.ModelSpecificInfo(
143+
model_spec.name,
144+
version,
145+
image_width=model_spec.input_image_shape[1],
146+
image_height=model_spec.input_image_shape[0],
147+
mean=model_spec.mean_rgb,
148+
std=model_spec.stddev_rgb,
149+
image_min=image_min,
150+
image_max=image_max,
151+
num_classes=num_classes)
152+
153+
123154
class ImageClassifier(classification_model.ClassificationModel):
124155
"""ImageClassifier class for inference and exporting to tflite."""
125156

@@ -270,7 +301,7 @@ def _export_tflite(self,
270301
representative_data=None,
271302
inference_input_type=tf.float32,
272303
inference_output_type=tf.float32,
273-
with_metadata=False,
304+
with_metadata=True,
274305
export_metadata_json_file=False):
275306
"""Converts the retrained model to tflite format and saves it.
276307
@@ -299,22 +330,34 @@ def _export_tflite(self,
299330
representative_data, inference_input_type,
300331
inference_output_type)
301332
if with_metadata:
302-
if not metadata.TFLITE_SUPPORT_TOOLS_INSTALLED:
303-
tf.compat.v1.logging.warning('Needs to install tflite-support package.')
304-
return
305-
306333
if label_filepath is None:
307334
tf.compat.v1.logging.warning(
308335
'Label filepath is needed when exporting TFLite with metadata.')
309336
return
310337

311-
model_info = metadata.get_model_info(self.model_spec, quantized=quantized)
312-
populator = metadata.MetadataPopulatorForImageClassifier(
313-
tflite_filepath, model_info, label_filepath)
338+
model_basename = os.path.basename(tflite_filepath)
339+
export_directory = os.path.dirname(tflite_filepath)
340+
export_model_path = os.path.join(export_directory, model_basename)
341+
342+
model_info = _get_model_info(
343+
self.model_spec, self.num_classes, quantized=quantized)
344+
# Generate the metadata objects and put them in the model file
345+
populator = metadata_writer.MetadataPopulatorForImageClassifier(
346+
export_model_path, model_info, label_filepath)
314347
populator.populate()
315348

349+
# Validate the output model file by reading the metadata and produce
350+
# a json file with the metadata under the export path
316351
if export_metadata_json_file:
317-
metadata.export_metadata_json_file(tflite_filepath)
352+
displayer = _metadata.MetadataDisplayer.with_model_file(
353+
export_model_path)
354+
export_json_file = os.path.join(
355+
export_directory,
356+
os.path.splitext(model_basename)[0] + '.json')
357+
358+
content = displayer.get_metadata_json()
359+
with open(export_json_file, 'w') as f:
360+
f.write(content)
318361

319362
def _get_hparams_or_default(self, hparams):
320363
"""Returns hparams if not none, otherwise uses default one."""

tensorflow_examples/lite/model_maker/core/task/image_classifier_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from tensorflow_examples.lite.model_maker.core.data_util import image_dataloader
2727
from tensorflow_examples.lite.model_maker.core.export_format import ExportFormat
2828
from tensorflow_examples.lite.model_maker.core.task import image_classifier
29-
from tensorflow_examples.lite.model_maker.core.task import metadata
3029
from tensorflow_examples.lite.model_maker.core.task import model_spec
3130

3231

@@ -221,9 +220,6 @@ def _test_export_to_tflite_with_metadata(self,
221220

222221
self._check_label_file(labels_output_file)
223222

224-
if not metadata.TFLITE_SUPPORT_TOOLS_INSTALLED:
225-
return
226-
227223
self.assertTrue(os.path.isfile(json_output_file))
228224
self.assertGreater(os.path.getsize(json_output_file), 0)
229225

tensorflow_examples/lite/model_maker/core/task/metadata.py renamed to tensorflow_examples/lite/model_maker/core/task/metadata_writer_for_image_classifier.py

Lines changed: 89 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -11,82 +11,68 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Metadata populator for TFLite models."""
14+
# ==============================================================================
15+
"""Writes metadata and label file to the image classifier models."""
1516

1617
from __future__ import absolute_import
1718
from __future__ import division
1819
from __future__ import print_function
1920

2021
import os
2122

23+
from absl import app
24+
from absl import flags
2225
import tensorflow as tf
23-
from tensorflow_examples.lite.model_maker.core.task import model_spec as ms
24-
25-
TFLITE_SUPPORT_TOOLS_INSTALLED = True
26-
27-
try:
28-
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
29-
import flatbuffers
30-
from tflite_support import metadata as _metadata
31-
from tflite_support import metadata_schema_py_generated as _metadata_fb
32-
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
33-
except ImportError:
34-
tf.compat.v1.logging.warning("Needs to install tflite-support package.")
35-
TFLITE_SUPPORT_TOOLS_INSTALLED = False
36-
37-
38-
def export_metadata_json_file(tflite_file):
39-
"""Exports metadata to json file."""
40-
displayer = _metadata.MetadataDisplayer.with_model_file(tflite_file)
41-
export_directory = os.path.dirname(tflite_file)
42-
try:
43-
json_file = os.path.join(
44-
export_directory,
45-
os.path.splitext(os.path.basename(tflite_file))[0] + ".json")
46-
with open(json_file, "w") as f:
47-
content = displayer.get_metadata_json()
48-
f.write(content)
49-
except AttributeError:
50-
# TODO(yuqili): Remove this line once the API is stable.
51-
displayer.export_metadata_json_file(export_directory)
52-
53-
54-
class ImageModelSpecificInfo(object):
26+
27+
import flatbuffers
28+
# pylint: disable=g-direct-tensorflow-import
29+
from tflite_support import metadata as _metadata
30+
from tflite_support import metadata_schema_py_generated as _metadata_fb
31+
# pylint: enable=g-direct-tensorflow-import
32+
33+
FLAGS = flags.FLAGS
34+
35+
36+
def define_flags():
37+
flags.DEFINE_string("model_file", None,
38+
"Path and file name to the TFLite model file.")
39+
flags.DEFINE_string("label_file", None, "Path to the label file.")
40+
flags.DEFINE_string("export_directory", None,
41+
"Path to save the TFLite model files with metadata.")
42+
flags.mark_flag_as_required("model_file")
43+
flags.mark_flag_as_required("label_file")
44+
flags.mark_flag_as_required("export_directory")
45+
46+
47+
class ModelSpecificInfo(object):
5548
"""Holds information that is specificly tied to an image classifier."""
5649

57-
def __init__(self,
58-
name,
59-
version,
60-
image_width,
61-
image_height,
62-
mean,
63-
std,
64-
image_min=0,
65-
image_max=1):
50+
def __init__(self, name, version, image_width, image_height, image_min,
51+
image_max, mean, std, num_classes):
6652
self.name = name
6753
self.version = version
6854
self.image_width = image_width
6955
self.image_height = image_height
70-
self.mean = mean
71-
self.std = std
7256
self.image_min = image_min
7357
self.image_max = image_max
58+
self.mean = mean
59+
self.std = std
60+
self.num_classes = num_classes
7461

7562

76-
def get_model_info(model_spec, quantized=False, version="v1"):
77-
if not isinstance(model_spec, ms.ImageModelSpec):
78-
raise ValueError("Currently only support models for image classification.")
79-
80-
name = model_spec.name
81-
if quantized:
82-
name += "_quantized"
83-
return ImageModelSpecificInfo(
84-
model_spec.name,
85-
version,
86-
image_width=model_spec.input_image_shape[1],
87-
image_height=model_spec.input_image_shape[0],
88-
mean=model_spec.mean_rgb,
89-
std=model_spec.stddev_rgb)
63+
_MODEL_INFO = {
64+
"mobilenet_v1_0.75_160_quantized.tflite":
65+
ModelSpecificInfo(
66+
name="MobileNetV1 image classifier",
67+
version="v1",
68+
image_width=160,
69+
image_height=160,
70+
image_min=0,
71+
image_max=255,
72+
mean=[127.5],
73+
std=[127.5],
74+
num_classes=1001)
75+
}
9076

9177

9278
class MetadataPopulatorForImageClassifier(object):
@@ -110,9 +96,10 @@ def _create_metadata(self):
11096
model_meta = _metadata_fb.ModelMetadataT()
11197
model_meta.name = self.model_info.name
11298
model_meta.description = ("Identify the most prominent object in the "
113-
"image from a set of categories.")
99+
"image from a set of %d categories." %
100+
self.model_info.num_classes)
114101
model_meta.version = self.model_info.version
115-
model_meta.author = "TFLite Model Maker"
102+
model_meta.author = "TensorFlow"
116103
model_meta.license = ("Apache License. Version 2.0 "
117104
"http://www.apache.org/licenses/LICENSE-2.0.")
118105

@@ -146,7 +133,7 @@ def _create_metadata(self):
146133
# Creates output info.
147134
output_meta = _metadata_fb.TensorMetadataT()
148135
output_meta.name = "probability"
149-
output_meta.description = "Probabilities of the labels respectively."
136+
output_meta.description = "Probabilities of the %d labels respectively." % self.model_info.num_classes
150137
output_meta.content = _metadata_fb.ContentT()
151138
output_meta.content.content_properties = _metadata_fb.FeaturePropertiesT()
152139
output_meta.content.contentPropertiesType = (
@@ -157,7 +144,7 @@ def _create_metadata(self):
157144
output_meta.stats = output_stats
158145
label_file = _metadata_fb.AssociatedFileT()
159146
label_file.name = os.path.basename(self.label_file_path)
160-
label_file.description = "Labels that %s can recognize." % model_meta.name
147+
label_file.description = "Labels for objects that the model can recognize."
161148
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
162149
output_meta.associatedFiles = [label_file]
163150

@@ -179,3 +166,42 @@ def _populate_metadata(self):
179166
populator.load_metadata_buffer(self.metadata_buf)
180167
populator.load_associated_files([self.label_file_path])
181168
populator.populate()
169+
170+
171+
def main(_):
172+
model_file = FLAGS.model_file
173+
model_basename = os.path.basename(model_file)
174+
if model_basename not in _MODEL_INFO:
175+
raise ValueError(
176+
"The model info for, {0}, is not defined yet.".format(model_basename))
177+
178+
export_model_path = os.path.join(FLAGS.export_directory, model_basename)
179+
180+
# Copies model_file to export_path.
181+
tf.io.gfile.copy(model_file, export_model_path, overwrite=True)
182+
183+
# Generate the metadata objects and put them in the model file
184+
populator = MetadataPopulatorForImageClassifier(
185+
export_model_path, _MODEL_INFO.get(model_basename), FLAGS.label_file)
186+
populator.populate()
187+
188+
# Validate the output model file by reading the metadata and produce
189+
# a json file with the metadata under the export path
190+
displayer = _metadata.MetadataDisplayer.with_model_file(export_model_path)
191+
export_json_file = os.path.join(FLAGS.export_directory,
192+
os.path.splitext(model_basename)[0] + ".json")
193+
json_file = displayer.get_metadata_json()
194+
with open(export_json_file, "w") as f:
195+
f.write(json_file)
196+
197+
print("Finished populating metadata and associated file to the model:")
198+
print(model_file)
199+
print("The metadata json file has been saved to:")
200+
print(export_json_file)
201+
print("The associated file that has been been packed to the model is:")
202+
print(displayer.get_packed_associated_file_list())
203+
204+
205+
if __name__ == "__main__":
206+
define_flags()
207+
app.run(main)

tensorflow_examples/lite/model_maker/core/task/testdata/efficientnet_lite0_metadata.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "efficientnet_lite0",
3-
"description": "Identify the most prominent object in the image from a set of categories.",
3+
"description": "Identify the most prominent object in the image from a set of 3 categories.",
44
"version": "v1",
55
"subgraph_metadata": [
66
{
@@ -40,7 +40,7 @@
4040
"output_tensor_metadata": [
4141
{
4242
"name": "probability",
43-
"description": "Probabilities of the labels respectively.",
43+
"description": "Probabilities of the 3 labels respectively.",
4444
"content": {
4545
"content_properties_type": "FeatureProperties"
4646
},
@@ -55,14 +55,14 @@
5555
"associated_files": [
5656
{
5757
"name": "labels.txt",
58-
"description": "Labels that efficientnet_lite0 can recognize.",
58+
"description": "Labels for objects that the model can recognize.",
5959
"type": "TENSOR_AXIS_LABELS"
6060
}
6161
]
6262
}
6363
]
6464
}
6565
],
66-
"author": "TFLite Model Maker",
66+
"author": "TensorFlow",
6767
"license": "Apache License. Version 2.0 http://www.apache.org/licenses/LICENSE-2.0."
6868
}

0 commit comments

Comments
 (0)