Skip to content

Commit 9ee6190

Browse files
ziyeqinghan authored and copybara-github committed
Change optimizer for image classification in TFLite Model Maker
PiperOrigin-RevId: 306180023
1 parent ff7d01d commit 9ee6190

5 files changed

Lines changed: 293 additions & 31 deletions

File tree

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Functions and classes related to optimization (weight updates)."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import tensorflow.compat.v2 as tf
21+
22+
23+
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule.

  For the first `warmup_steps` training steps, the learning rate ramps up
  linearly from 0 to `initial_learning_rate`; after that, the wrapped
  `decay_schedule_fn` takes over.
  """

  def __init__(self,
               initial_learning_rate,
               decay_schedule_fn,
               warmup_steps,
               name=None):
    """Initializes the warmup schedule.

    Args:
      initial_learning_rate: Peak learning rate reached at the end of warmup.
      decay_schedule_fn: Schedule (a callable taking the step) applied once
        warmup has finished.
      warmup_steps: Number of steps over which to warm up linearly.
      name: Optional name for the op scope.
    """
    super(WarmUp, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.warmup_steps = warmup_steps
    self.decay_schedule_fn = decay_schedule_fn
    self.name = name

  def __call__(self, step):
    with tf.name_scope(self.name or 'WarmUp') as name:
      # Linear warmup: while global_step < warmup_steps the learning rate is
      # `global_step / warmup_steps * initial_learning_rate`.
      step_f = tf.cast(step, tf.float32)
      warmup_f = tf.cast(self.warmup_steps, tf.float32)
      fraction_done = step_f / warmup_f
      warmup_lr = self.initial_learning_rate * fraction_done
      return tf.cond(
          step_f < warmup_f,
          lambda: warmup_lr,
          lambda: self.decay_schedule_fn(step),
          name=name)

  def get_config(self):
    """Returns the schedule's configuration for (de)serialization."""
    return {
        'initial_learning_rate': self.initial_learning_rate,
        'decay_schedule_fn': self.decay_schedule_fn,
        'warmup_steps': self.warmup_steps,
        'name': self.name
    }

tensorflow_examples/lite/model_maker/core/task/image_classifier.py

Lines changed: 58 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20-
import tensorflow as tf
20+
import tensorflow.compat.v2 as tf
2121

2222
from tensorflow_examples.lite.model_maker.core import compat
2323
from tensorflow_examples.lite.model_maker.core import model_export_format as mef
@@ -26,8 +26,15 @@
2626
from tensorflow_examples.lite.model_maker.core.task import image_preprocessing
2727
from tensorflow_examples.lite.model_maker.core.task import metadata
2828
from tensorflow_examples.lite.model_maker.core.task import model_spec as ms
29+
from tensorflow_examples.lite.model_maker.core.task import train_image_classifier_lib
2930

30-
from tensorflow_hub.tools.make_image_classifier import make_image_classifier_lib as lib
31+
from tensorflow_hub.tools.make_image_classifier import make_image_classifier_lib as hub_lib
32+
33+
34+
def get_hub_lib_hparams(**kwargs):
35+
"""Gets the hyperparameters for the tensorflow hub's library."""
36+
hparams = hub_lib.get_default_hparams()
37+
return train_image_classifier_lib.add_params(hparams, **kwargs)
3138

3239

3340
def create(train_data,
@@ -41,7 +48,10 @@ def create(train_data,
4148
dropout_rate=None,
4249
learning_rate=None,
4350
momentum=None,
44-
use_augmentation=False):
51+
use_augmentation=False,
52+
use_hub_library=True,
53+
warmup_steps=None,
54+
model_dir=None):
4555
"""Loads data and retrains the model based on data for image classification.
4656
4757
Args:
@@ -50,36 +60,51 @@ def create(train_data,
5060
model_spec: Specification for the model.
5161
shuffle: Whether the data should be shuffled.
5262
validation_data: Validation data. If None, skips validation process.
53-
batch_size: Number of samples per training step.
63+
batch_size: Number of samples per training step. If `use_hub_library` is
64+
False, it represents the base learning rate when train batch size is 256
65+
and it's linear to the batch size.
5466
epochs: Number of epochs for training.
5567
train_whole_model: If true, the Hub module is trained together with the
5668
classification layer on top. Otherwise, only train the top classification
5769
layer.
58-
dropout_rate: the rate for dropout.
59-
learning_rate: a Python float forwarded to the optimizer.
60-
momentum: a Python float forwarded to the optimizer.
70+
dropout_rate: The rate for dropout.
71+
learning_rate: Base learning rate when train batch size is 256. Linear to
72+
the batch size.
73+
momentum: a Python float forwarded to the optimizer. Only used when
74+
`use_hub_library` is True.
6175
use_augmentation: Use data augmentation for preprocessing.
76+
use_hub_library: Use `make_image_classifier_lib` from tensorflow hub to
77+
retrain the model.
78+
warmup_steps: Number of warmup steps for warmup schedule on learning rate.
79+
If None, the default warmup_steps is used which is the total training
80+
steps in two epochs. Only used when `use_hub_library` is False.
81+
model_dir: The location of the model checkpoint files. Only used when
82+
`use_hub_library` is False.
83+
6284
Returns:
6385
An instance of ImageClassifier class.
6486
"""
6587
if compat.get_tf_behavior() not in model_spec.compat_tf_versions:
6688
raise ValueError('Incompatible versions. Expect {}, but got {}.'.format(
6789
model_spec.compat_tf_versions, compat.get_tf_behavior()))
6890

69-
# The hyperparameters for make_image_classifier by tensorflow hub.
70-
hparams = lib.get_default_hparams()
71-
if batch_size is not None:
72-
hparams = hparams._replace(batch_size=batch_size)
73-
if epochs is not None:
74-
hparams = hparams._replace(train_epochs=epochs)
75-
if train_whole_model is not None:
76-
hparams = hparams._replace(do_fine_tuning=train_whole_model)
77-
if dropout_rate is not None:
78-
hparams = hparams._replace(dropout_rate=dropout_rate)
79-
if learning_rate is not None:
80-
hparams = hparams._replace(learning_rate=learning_rate)
81-
if momentum is not None:
82-
hparams = hparams._replace(momentum=momentum)
91+
if use_hub_library:
92+
hparams = get_hub_lib_hparams(
93+
batch_size=batch_size,
94+
train_epochs=epochs,
95+
do_fine_tuning=train_whole_model,
96+
dropout_rate=dropout_rate,
97+
learning_rate=learning_rate,
98+
momentum=momentum)
99+
else:
100+
hparams = train_image_classifier_lib.HParams.get_hparams(
101+
batch_size=batch_size,
102+
train_epochs=epochs,
103+
do_fine_tuning=train_whole_model,
104+
dropout_rate=dropout_rate,
105+
learning_rate=learning_rate,
106+
warmup_steps=warmup_steps,
107+
model_dir=model_dir)
83108

84109
image_classifier = ImageClassifier(
85110
model_export_format,
@@ -105,7 +130,7 @@ def __init__(self,
105130
index_to_label,
106131
num_classes,
107132
shuffle=True,
108-
hparams=lib.get_default_hparams(),
133+
hparams=hub_lib.get_default_hparams(),
109134
use_augmentation=False):
110135
"""Init function for ImageClassifier class.
111136
@@ -118,6 +143,8 @@ def __init__(self,
118143
hparams: A namedtuple of hyperparameters. This function expects
119144
.dropout_rate: The fraction of the input units to drop, used in dropout
120145
layer.
146+
.do_fine_tuning: If true, the Hub module is trained together with the
147+
classification layer on top.
121148
use_augmentation: Use data augmentation for preprocessing.
122149
"""
123150
super(ImageClassifier,
@@ -138,21 +165,18 @@ def _create_model(self, hparams=None):
138165

139166
module_layer = hub_loader.HubKerasLayerV1V2(
140167
self.model_spec.uri, trainable=hparams.do_fine_tuning)
141-
return lib.build_model(module_layer, hparams,
142-
self.model_spec.input_image_shape, self.num_classes)
168+
return hub_lib.build_model(module_layer, hparams,
169+
self.model_spec.input_image_shape,
170+
self.num_classes)
143171

144172
def train(self, train_data, validation_data=None, hparams=None):
145173
"""Feeds the training data for training.
146174
147175
Args:
148176
train_data: Training data.
149177
validation_data: Validation data. If None, skips validation process.
150-
hparams: A namedtuple of hyperparameters. This function expects
151-
.train_epochs: a Python integer with the number of passes over the
152-
training dataset;
153-
.learning_rate: a Python float forwarded to the optimizer;
154-
.momentum: a Python float forwarded to the optimizer;
155-
.batch_size: a Python integer, number of samples per training step.
178+
hparams: An instance of hub_lib.HParams or
179+
train_image_classifier_lib.HParams. A namedtuple of hyperparameters.
156180
157181
Returns:
158182
The tf.keras.callbacks.History object returned by tf.keras.Model.fit*().
@@ -170,7 +194,11 @@ def train(self, train_data, validation_data=None, hparams=None):
170194
validation_data, hparams.batch_size, is_training=False)
171195
validation_size = validation_data.size
172196
validation_data_and_size = (validation_ds, validation_size)
197+
173198
# Trains the models.
199+
lib = hub_lib
200+
if isinstance(hparams, train_image_classifier_lib.HParams):
201+
lib = train_image_classifier_lib
174202
return lib.train_model(self.model, hparams, train_data_and_size,
175203
validation_data_and_size)
176204

tensorflow_examples/lite/model_maker/core/task/image_classifier_test.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import os
2121

2222
import numpy as np
23-
import tensorflow as tf
23+
import tensorflow.compat.v2 as tf
2424
from tensorflow_examples.lite.model_maker.core import compat
2525
from tensorflow_examples.lite.model_maker.core import model_export_format as mef
2626
from tensorflow_examples.lite.model_maker.core import test_util
@@ -84,6 +84,19 @@ def test_mobilenetv2_model_create_v1_incompatible(self):
8484
_ = image_classifier.create(self.train_data, mef.ModelExportFormat.TFLITE,
8585
model_spec.mobilenet_v2_spec)
8686

87+
@test_util.test_in_tf_1and2
88+
def test_efficientnetlite0_model_with_model_maker_retraining_lib(self):
89+
model = image_classifier.create(
90+
self.train_data,
91+
mef.ModelExportFormat.TFLITE,
92+
model_spec.efficientnet_lite0_spec,
93+
epochs=2,
94+
batch_size=4,
95+
shuffle=True,
96+
use_hub_library=False)
97+
self._test_accuracy(model)
98+
self._test_export_to_tflite(model)
99+
87100
@test_util.test_in_tf_1and2
88101
def test_efficientnetlite0_model(self):
89102
model = image_classifier.create(

0 commit comments

Comments
 (0)