1- # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
1+ # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- """Metadata populator for TFLite models."""
14+ # ==============================================================================
15+ """Writes metadata and label file to the image classifier models."""
1516
1617from __future__ import absolute_import
1718from __future__ import division
1819from __future__ import print_function
1920
2021import os
2122
23+ from absl import app
24+ from absl import flags
2225import tensorflow as tf
23- from tensorflow_examples .lite .model_maker .core .task import model_spec as ms
24-
25- TFLITE_SUPPORT_TOOLS_INSTALLED = True
26-
27- try :
28- # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
29- import flatbuffers
30- from tflite_support import metadata as _metadata
31- from tflite_support import metadata_schema_py_generated as _metadata_fb
32- # pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
33- except ImportError :
34- tf .compat .v1 .logging .warning ("Needs to install tflite-support package." )
35- TFLITE_SUPPORT_TOOLS_INSTALLED = False
36-
37-
38- def export_metadata_json_file (tflite_file ):
39- """Exports metadata to json file."""
40- displayer = _metadata .MetadataDisplayer .with_model_file (tflite_file )
41- export_directory = os .path .dirname (tflite_file )
42- try :
43- json_file = os .path .join (
44- export_directory ,
45- os .path .splitext (os .path .basename (tflite_file ))[0 ] + ".json" )
46- with open (json_file , "w" ) as f :
47- content = displayer .get_metadata_json ()
48- f .write (content )
49- except AttributeError :
50- # TODO(yuqili): Remove this line once the API is stable.
51- displayer .export_metadata_json_file (export_directory )
52-
53-
54- class ImageModelSpecificInfo (object ):
26+
27+ import flatbuffers
28+ # pylint: disable=g-direct-tensorflow-import
29+ from tflite_support import metadata as _metadata
30+ from tflite_support import metadata_schema_py_generated as _metadata_fb
31+ # pylint: enable=g-direct-tensorflow-import
32+
33+ FLAGS = flags .FLAGS
34+
35+
36+ def define_flags ():
37+ flags .DEFINE_string ("model_file" , None ,
38+ "Path and file name to the TFLite model file." )
39+ flags .DEFINE_string ("label_file" , None , "Path to the label file." )
40+ flags .DEFINE_string ("export_directory" , None ,
41+ "Path to save the TFLite model files with metadata." )
42+ flags .mark_flag_as_required ("model_file" )
43+ flags .mark_flag_as_required ("label_file" )
44+ flags .mark_flag_as_required ("export_directory" )
45+
46+
47+ class ModelSpecificInfo (object ):
5548 """Holds information that is specificly tied to an image classifier."""
5649
57- def __init__ (self ,
58- name ,
59- version ,
60- image_width ,
61- image_height ,
62- mean ,
63- std ,
64- image_min = 0 ,
65- image_max = 1 ):
50+ def __init__ (self , name , version , image_width , image_height , image_min ,
51+ image_max , mean , std , num_classes ):
6652 self .name = name
6753 self .version = version
6854 self .image_width = image_width
6955 self .image_height = image_height
70- self .mean = mean
71- self .std = std
7256 self .image_min = image_min
7357 self .image_max = image_max
58+ self .mean = mean
59+ self .std = std
60+ self .num_classes = num_classes
7461
7562
76- def get_model_info (model_spec , quantized = False , version = "v1" ):
77- if not isinstance (model_spec , ms .ImageModelSpec ):
78- raise ValueError ("Currently only support models for image classification." )
79-
80- name = model_spec .name
81- if quantized :
82- name += "_quantized"
83- return ImageModelSpecificInfo (
84- model_spec .name ,
85- version ,
86- image_width = model_spec .input_image_shape [1 ],
87- image_height = model_spec .input_image_shape [0 ],
88- mean = model_spec .mean_rgb ,
89- std = model_spec .stddev_rgb )
63+ _MODEL_INFO = {
64+ "mobilenet_v1_0.75_160_quantized.tflite" :
65+ ModelSpecificInfo (
66+ name = "MobileNetV1 image classifier" ,
67+ version = "v1" ,
68+ image_width = 160 ,
69+ image_height = 160 ,
70+ image_min = 0 ,
71+ image_max = 255 ,
72+ mean = [127.5 ],
73+ std = [127.5 ],
74+ num_classes = 1001 )
75+ }
9076
9177
9278class MetadataPopulatorForImageClassifier (object ):
@@ -110,9 +96,10 @@ def _create_metadata(self):
11096 model_meta = _metadata_fb .ModelMetadataT ()
11197 model_meta .name = self .model_info .name
11298 model_meta .description = ("Identify the most prominent object in the "
113- "image from a set of categories." )
99+ "image from a set of %d categories." %
100+ self .model_info .num_classes )
114101 model_meta .version = self .model_info .version
115- model_meta .author = "TFLite Model Maker "
102+ model_meta .author = "TensorFlow "
116103 model_meta .license = ("Apache License. Version 2.0 "
117104 "http://www.apache.org/licenses/LICENSE-2.0." )
118105
@@ -146,7 +133,7 @@ def _create_metadata(self):
146133 # Creates output info.
147134 output_meta = _metadata_fb .TensorMetadataT ()
148135 output_meta .name = "probability"
149- output_meta .description = "Probabilities of the labels respectively."
136+ output_meta .description = "Probabilities of the %d labels respectively." % self . model_info . num_classes
150137 output_meta .content = _metadata_fb .ContentT ()
151138 output_meta .content .content_properties = _metadata_fb .FeaturePropertiesT ()
152139 output_meta .content .contentPropertiesType = (
@@ -157,7 +144,7 @@ def _create_metadata(self):
157144 output_meta .stats = output_stats
158145 label_file = _metadata_fb .AssociatedFileT ()
159146 label_file .name = os .path .basename (self .label_file_path )
160- label_file .description = "Labels that %s can recognize." % model_meta . name
147+ label_file .description = "Labels for objects that the model can recognize."
161148 label_file .type = _metadata_fb .AssociatedFileType .TENSOR_AXIS_LABELS
162149 output_meta .associatedFiles = [label_file ]
163150
@@ -179,3 +166,42 @@ def _populate_metadata(self):
179166 populator .load_metadata_buffer (self .metadata_buf )
180167 populator .load_associated_files ([self .label_file_path ])
181168 populator .populate ()
169+
170+
171+ def main (_ ):
172+ model_file = FLAGS .model_file
173+ model_basename = os .path .basename (model_file )
174+ if model_basename not in _MODEL_INFO :
175+ raise ValueError (
176+ "The model info for, {0}, is not defined yet." .format (model_basename ))
177+
178+ export_model_path = os .path .join (FLAGS .export_directory , model_basename )
179+
180+ # Copies model_file to export_path.
181+ tf .io .gfile .copy (model_file , export_model_path , overwrite = True )
182+
183+ # Generate the metadata objects and put them in the model file
184+ populator = MetadataPopulatorForImageClassifier (
185+ export_model_path , _MODEL_INFO .get (model_basename ), FLAGS .label_file )
186+ populator .populate ()
187+
188+ # Validate the output model file by reading the metadata and produce
189+ # a json file with the metadata under the export path
190+ displayer = _metadata .MetadataDisplayer .with_model_file (export_model_path )
191+ export_json_file = os .path .join (FLAGS .export_directory ,
192+ os .path .splitext (model_basename )[0 ] + ".json" )
193+ json_file = displayer .get_metadata_json ()
194+ with open (export_json_file , "w" ) as f :
195+ f .write (json_file )
196+
197+ print ("Finished populating metadata and associated file to the model:" )
198+ print (model_file )
199+ print ("The metadata json file has been saved to:" )
200+ print (export_json_file )
201+ print ("The associated file that has been been packed to the model is:" )
202+ print (displayer .get_packed_associated_file_list ())
203+
204+
205+ if __name__ == "__main__" :
206+ define_flags ()
207+ app .run (main )
0 commit comments