1717from __future__ import division
1818from __future__ import print_function
1919
20- import tensorflow as tf
20+ import tensorflow . compat . v2 as tf
2121
2222from tensorflow_examples .lite .model_maker .core import compat
2323from tensorflow_examples .lite .model_maker .core import model_export_format as mef
2626from tensorflow_examples .lite .model_maker .core .task import image_preprocessing
2727from tensorflow_examples .lite .model_maker .core .task import metadata
2828from tensorflow_examples .lite .model_maker .core .task import model_spec as ms
29+ from tensorflow_examples .lite .model_maker .core .task import train_image_classifier_lib
2930
30- from tensorflow_hub .tools .make_image_classifier import make_image_classifier_lib as lib
31+ from tensorflow_hub .tools .make_image_classifier import make_image_classifier_lib as hub_lib
32+
33+
34+ def get_hub_lib_hparams (** kwargs ):
35+ """Gets the hyperparameters for the tensorflow hub's library."""
36+ hparams = hub_lib .get_default_hparams ()
37+ return train_image_classifier_lib .add_params (hparams , ** kwargs )
3138
3239
3340def create (train_data ,
@@ -41,7 +48,10 @@ def create(train_data,
4148 dropout_rate = None ,
4249 learning_rate = None ,
4350 momentum = None ,
44- use_augmentation = False ):
51+ use_augmentation = False ,
52+ use_hub_library = True ,
53+ warmup_steps = None ,
54+ model_dir = None ):
4555 """Loads data and retrains the model based on data for image classification.
4656
4757 Args:
@@ -50,36 +60,51 @@ def create(train_data,
5060 model_spec: Specification for the model.
5161 shuffle: Whether the data should be shuffled.
5262 validation_data: Validation data. If None, skips validation process.
53- batch_size: Number of samples per training step.
63+     batch_size: Number of samples per training step. If None, a default
64+       batch size is chosen by the underlying training library
65+       (`make_image_classifier_lib` or `train_image_classifier_lib`).
5466 epochs: Number of epochs for training.
5567 train_whole_model: If true, the Hub module is trained together with the
5668 classification layer on top. Otherwise, only train the top classification
5769 layer.
58- dropout_rate: the rate for dropout.
59- learning_rate: a Python float forwarded to the optimizer.
60- momentum: a Python float forwarded to the optimizer.
70+ dropout_rate: The rate for dropout.
71+ learning_rate: Base learning rate when train batch size is 256. Linear to
72+ the batch size.
73+ momentum: a Python float forwarded to the optimizer. Only used when
74+ `use_hub_library` is True.
6175 use_augmentation: Use data augmentation for preprocessing.
76+ use_hub_library: Use `make_image_classifier_lib` from tensorflow hub to
77+ retrain the model.
78+ warmup_steps: Number of warmup steps for warmup schedule on learning rate.
79+ If None, the default warmup_steps is used which is the total training
80+ steps in two epochs. Only used when `use_hub_library` is False.
81+ model_dir: The location of the model checkpoint files. Only used when
82+ `use_hub_library` is False.
83+
6284 Returns:
6385 An instance of ImageClassifier class.
6486 """
6587 if compat .get_tf_behavior () not in model_spec .compat_tf_versions :
6688 raise ValueError ('Incompatible versions. Expect {}, but got {}.' .format (
6789 model_spec .compat_tf_versions , compat .get_tf_behavior ()))
6890
69- # The hyperparameters for make_image_classifier by tensorflow hub.
70- hparams = lib .get_default_hparams ()
71- if batch_size is not None :
72- hparams = hparams ._replace (batch_size = batch_size )
73- if epochs is not None :
74- hparams = hparams ._replace (train_epochs = epochs )
75- if train_whole_model is not None :
76- hparams = hparams ._replace (do_fine_tuning = train_whole_model )
77- if dropout_rate is not None :
78- hparams = hparams ._replace (dropout_rate = dropout_rate )
79- if learning_rate is not None :
80- hparams = hparams ._replace (learning_rate = learning_rate )
81- if momentum is not None :
82- hparams = hparams ._replace (momentum = momentum )
91+ if use_hub_library :
92+ hparams = get_hub_lib_hparams (
93+ batch_size = batch_size ,
94+ train_epochs = epochs ,
95+ do_fine_tuning = train_whole_model ,
96+ dropout_rate = dropout_rate ,
97+ learning_rate = learning_rate ,
98+ momentum = momentum )
99+ else :
100+ hparams = train_image_classifier_lib .HParams .get_hparams (
101+ batch_size = batch_size ,
102+ train_epochs = epochs ,
103+ do_fine_tuning = train_whole_model ,
104+ dropout_rate = dropout_rate ,
105+ learning_rate = learning_rate ,
106+ warmup_steps = warmup_steps ,
107+ model_dir = model_dir )
83108
84109 image_classifier = ImageClassifier (
85110 model_export_format ,
@@ -105,7 +130,7 @@ def __init__(self,
105130 index_to_label ,
106131 num_classes ,
107132 shuffle = True ,
108- hparams = lib .get_default_hparams (),
133+ hparams = hub_lib .get_default_hparams (),
109134 use_augmentation = False ):
110135 """Init function for ImageClassifier class.
111136
@@ -118,6 +143,8 @@ def __init__(self,
118143 hparams: A namedtuple of hyperparameters. This function expects
119144 .dropout_rate: The fraction of the input units to drop, used in dropout
120145 layer.
146+ .do_fine_tuning: If true, the Hub module is trained together with the
147+ classification layer on top.
121148 use_augmentation: Use data augmentation for preprocessing.
122149 """
123150 super (ImageClassifier ,
@@ -138,21 +165,18 @@ def _create_model(self, hparams=None):
138165
139166 module_layer = hub_loader .HubKerasLayerV1V2 (
140167 self .model_spec .uri , trainable = hparams .do_fine_tuning )
141- return lib .build_model (module_layer , hparams ,
142- self .model_spec .input_image_shape , self .num_classes )
168+ return hub_lib .build_model (module_layer , hparams ,
169+ self .model_spec .input_image_shape ,
170+ self .num_classes )
143171
144172 def train (self , train_data , validation_data = None , hparams = None ):
145173 """Feeds the training data for training.
146174
147175 Args:
148176 train_data: Training data.
149177 validation_data: Validation data. If None, skips validation process.
150- hparams: A namedtuple of hyperparameters. This function expects
151- .train_epochs: a Python integer with the number of passes over the
152- training dataset;
153- .learning_rate: a Python float forwarded to the optimizer;
154- .momentum: a Python float forwarded to the optimizer;
155- .batch_size: a Python integer, number of samples per training step.
178+ hparams: An instance of hub_lib.HParams or
179+ train_image_classifier_lib.HParams. Anamedtuple of hyperparameters.
156180
157181 Returns:
158182 The tf.keras.callbacks.History object returned by tf.keras.Model.fit*().
@@ -170,7 +194,11 @@ def train(self, train_data, validation_data=None, hparams=None):
170194 validation_data , hparams .batch_size , is_training = False )
171195 validation_size = validation_data .size
172196 validation_data_and_size = (validation_ds , validation_size )
197+
173198 # Trains the models.
199+ lib = hub_lib
200+ if isinstance (hparams , train_image_classifier_lib .HParams ):
201+ lib = train_image_classifier_lib
174202 return lib .train_model (self .model , hparams , train_data_and_size ,
175203 validation_data_and_size )
176204
0 commit comments