Skip to content

Commit 28ac71e

Browse files
ziyeqinghan authored and copybara-github committed
Add CustomModel to be the base class for QA task.
PiperOrigin-RevId: 305842035
1 parent 2440842 commit 28ac71e

6 files changed

Lines changed: 387 additions & 104 deletions

File tree

tensorflow_examples/lite/model_maker/core/task/classification_model.py

Lines changed: 10 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -11,35 +11,20 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Custom model that is already retained by data."""
14+
"""Custom classification model that is already retained by data."""
1515

1616
from __future__ import absolute_import
1717
from __future__ import division
1818
from __future__ import print_function
1919

20-
import abc
21-
import os
22-
import tempfile
2320

2421
import numpy as np
25-
import tensorflow as tf
26-
from tensorflow_examples.lite.model_maker.core import compat
22+
import tensorflow.compat.v2 as tf
2723
from tensorflow_examples.lite.model_maker.core import model_export_format as mef
24+
from tensorflow_examples.lite.model_maker.core.task import custom_model
2825

29-
DEFAULT_QUANTIZATION_STEPS = 2000
3026

31-
32-
def get_representative_dataset_gen(dataset, num_steps):
33-
34-
def representative_dataset_gen():
35-
"""Generates representative dataset for quantized."""
36-
for image, _ in dataset.take(num_steps):
37-
yield [image]
38-
39-
return representative_dataset_gen
40-
41-
42-
class ClassificationModel(abc.ABC):
27+
class ClassificationModel(custom_model.CustomModel):
4328
""""The abstract base class that represents a Tensorflow classification model."""
4429

4530
def __init__(self, model_export_format, model_spec, index_to_label,
@@ -60,28 +45,11 @@ def __init__(self, model_export_format, model_spec, index_to_label,
6045
raise ValueError('Model export format %s is not supported currently.' %
6146
str(model_export_format))
6247

63-
self.model_export_format = model_export_format
64-
self.model_spec = model_spec
48+
super(ClassificationModel, self).__init__(model_export_format, model_spec,
49+
shuffle)
6550
self.index_to_label = index_to_label
6651
self.num_classes = num_classes
67-
self.shuffle = shuffle
6852
self.train_whole_model = train_whole_model
69-
self.model = None
70-
71-
@abc.abstractmethod
72-
def preprocess(self, sample_data, label):
73-
return
74-
75-
@abc.abstractmethod
76-
def train(self, train_data, validation_data=None, **kwargs):
77-
return
78-
79-
@abc.abstractmethod
80-
def export(self, **kwargs):
81-
return
82-
83-
def summary(self):
84-
self.model.summary()
8553

8654
def evaluate(self, data, batch_size=32):
8755
"""Evaluates the model.
@@ -122,31 +90,6 @@ def predict_top_k(self, data, k=1, batch_size=32):
12290

12391
return label_prob
12492

125-
def _gen_dataset(self,
126-
data,
127-
batch_size=32,
128-
is_training=True,
129-
input_pipeline_context=None):
130-
"""Generates training / validation dataset."""
131-
# The dataset is always sharded by number of hosts.
132-
# num_input_pipelines is the number of hosts rather than number of cores.
133-
ds = data.dataset
134-
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
135-
ds = ds.shard(input_pipeline_context.num_input_pipelines,
136-
input_pipeline_context.input_pipeline_id)
137-
138-
ds = ds.map(
139-
self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
140-
141-
if is_training:
142-
if self.shuffle:
143-
ds = ds.shuffle(buffer_size=min(data.size, 100))
144-
ds = ds.repeat()
145-
146-
ds = ds.batch(batch_size)
147-
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
148-
return ds
149-
15093
def _export_tflite(self,
15194
tflite_filename,
15295
label_filename,
@@ -164,41 +107,11 @@ def _export_tflite(self,
164107
representative_data: Representative data used for post-training
165108
quantization. Used only if `quantized` is True.
166109
"""
167-
temp_dir = None
168-
if compat.get_tf_behavior() == 1:
169-
temp_dir = tempfile.TemporaryDirectory()
170-
save_path = os.path.join(temp_dir.name, 'saved_model')
171-
self.model.save(save_path, include_optimizer=False, save_format='tf')
172-
converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(save_path)
173-
else:
174-
converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
175-
176-
if quantized:
177-
if quantization_steps is None:
178-
quantization_steps = DEFAULT_QUANTIZATION_STEPS
179-
if representative_data is None:
180-
raise ValueError(
181-
'representative_data couldn\'t be None if model is quantized.')
182-
ds = self._gen_dataset(
183-
representative_data, batch_size=1, is_training=False)
184-
converter.representative_dataset = tf.lite.RepresentativeDataset(
185-
get_representative_dataset_gen(ds, quantization_steps))
186-
187-
converter.optimizations = [tf.lite.Optimize.DEFAULT]
188-
converter.inference_input_type = tf.uint8
189-
converter.inference_output_type = tf.uint8
190-
converter.target_spec.supported_ops = [
191-
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
192-
]
193-
tflite_model = converter.convert()
194-
if temp_dir:
195-
temp_dir.cleanup()
196-
197-
with tf.io.gfile.GFile(tflite_filename, 'wb') as f:
198-
f.write(tflite_model)
110+
super(ClassificationModel,
111+
self)._export_tflite(tflite_filename, quantized, quantization_steps,
112+
representative_data)
199113

200114
with tf.io.gfile.GFile(label_filename, 'w') as f:
201115
f.write('\n'.join(self.index_to_label))
202116

203-
tf.compat.v1.logging.info('Export to tflite model %s, saved labels in %s.',
204-
tflite_filename, label_filename)
117+
tf.compat.v1.logging.info('Saved labels in %s.', label_filename)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the 'License');
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an 'AS IS' BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
import tensorflow.compat.v2 as tf
20+
from tensorflow_examples.lite.model_maker.core import model_export_format as mef
21+
from tensorflow_examples.lite.model_maker.core import test_util
22+
from tensorflow_examples.lite.model_maker.core.task import classification_model
23+
24+
25+
class MockClassificationModel(classification_model.ClassificationModel):
26+
27+
def train(self, train_data, validation_data=None, **kwargs):
28+
pass
29+
30+
def export(self, **kwargs):
31+
pass
32+
33+
def evaluate(self, data, **kwargs):
34+
pass
35+
36+
37+
class ClassificationModelTest(tf.test.TestCase):
38+
39+
def test_predict_top_k(self):
40+
input_shape = [24, 24, 3]
41+
num_classes = 2
42+
model = MockClassificationModel(
43+
model_export_format=mef.ModelExportFormat.TFLITE,
44+
model_spec=None,
45+
index_to_label=['pos', 'neg'],
46+
num_classes=2,
47+
train_whole_model=False,
48+
shuffle=False)
49+
model.model = test_util.build_model(input_shape, num_classes)
50+
data = test_util.get_dataloader(2, input_shape, num_classes)
51+
52+
topk_results = model.predict_top_k(data, k=2, batch_size=1)
53+
for topk_result in topk_results:
54+
top1_result, top2_result = topk_result[0], topk_result[1]
55+
top1_label, top1_prob = top1_result[0], top1_result[1]
56+
top2_label, top2_prob = top2_result[0], top2_result[1]
57+
58+
self.assertIn(top1_label, model.index_to_label)
59+
self.assertIn(top2_label, model.index_to_label)
60+
self.assertNotEqual(top1_label, top2_label)
61+
62+
self.assertLessEqual(top1_prob, 1)
63+
self.assertGreaterEqual(top1_prob, top2_prob)
64+
self.assertGreaterEqual(top2_prob, 0)
65+
66+
self.assertEqual(top1_prob + top2_prob, 1.0)
67+
68+
69+
if __name__ == '__main__':
70+
tf.test.main()
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Base custom model that is already retained by data."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import abc
21+
import os
22+
import tempfile
23+
24+
import tensorflow.compat.v2 as tf
25+
from tensorflow_examples.lite.model_maker.core import compat
26+
from tensorflow_examples.lite.model_maker.core import model_export_format as mef
27+
28+
DEFAULT_QUANTIZATION_STEPS = 2000
29+
30+
31+
def get_representative_dataset_gen(dataset, num_steps):
32+
33+
def representative_dataset_gen():
34+
"""Generates representative dataset for quantized."""
35+
for image, _ in dataset.take(num_steps):
36+
yield [image]
37+
38+
return representative_dataset_gen
39+
40+
41+
class CustomModel(abc.ABC):
42+
""""The abstract base class that represents a Tensorflow classification model."""
43+
44+
def __init__(self, model_export_format, model_spec, shuffle):
45+
"""Initialize a instance with data, deploy mode and other related parameters.
46+
47+
Args:
48+
model_export_format: Model export format such as saved_model / tflite.
49+
model_spec: Specification for the model.
50+
shuffle: Whether the data should be shuffled.
51+
"""
52+
if model_export_format != mef.ModelExportFormat.TFLITE:
53+
raise ValueError('Model export format %s is not supported currently.' %
54+
str(model_export_format))
55+
56+
self.model_export_format = model_export_format
57+
self.model_spec = model_spec
58+
self.shuffle = shuffle
59+
self.model = None
60+
61+
def preprocess(self, sample_data, label):
62+
"""Preprocess the data."""
63+
# TODO(yuqili): remove this method once preprocess for image classifier is
64+
# also moved to DataLoader part.
65+
return sample_data, label
66+
67+
@abc.abstractmethod
68+
def train(self, train_data, validation_data=None, **kwargs):
69+
return
70+
71+
@abc.abstractmethod
72+
def export(self, **kwargs):
73+
return
74+
75+
def summary(self):
76+
self.model.summary()
77+
78+
@abc.abstractmethod
79+
def evaluate(self, data, **kwargs):
80+
return
81+
82+
def _gen_dataset(self,
83+
data,
84+
batch_size=32,
85+
is_training=True,
86+
input_pipeline_context=None):
87+
"""Generates training / validation dataset."""
88+
# The dataset is always sharded by number of hosts.
89+
# num_input_pipelines is the number of hosts rather than number of cores.
90+
ds = data.dataset
91+
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
92+
ds = ds.shard(input_pipeline_context.num_input_pipelines,
93+
input_pipeline_context.input_pipeline_id)
94+
95+
ds = ds.map(
96+
self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
97+
98+
if is_training:
99+
if self.shuffle:
100+
ds = ds.shuffle(buffer_size=min(data.size, 100))
101+
ds = ds.repeat()
102+
103+
ds = ds.batch(batch_size)
104+
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
105+
return ds
106+
107+
def _export_tflite(self,
108+
tflite_filename,
109+
quantized=False,
110+
quantization_steps=None,
111+
representative_data=None):
112+
"""Converts the retrained model to tflite format and saves it.
113+
114+
Args:
115+
tflite_filename: File name to save tflite model.
116+
quantized: boolean, if True, save quantized model.
117+
quantization_steps: Number of post-training quantization calibration steps
118+
to run. Used only if `quantized` is True.
119+
representative_data: Representative data used for post-training
120+
quantization. Used only if `quantized` is True.
121+
"""
122+
temp_dir = None
123+
if compat.get_tf_behavior() == 1:
124+
temp_dir = tempfile.TemporaryDirectory()
125+
save_path = os.path.join(temp_dir.name, 'saved_model')
126+
self.model.save(save_path, include_optimizer=False, save_format='tf')
127+
converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(save_path)
128+
else:
129+
converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
130+
131+
if quantized:
132+
if quantization_steps is None:
133+
quantization_steps = DEFAULT_QUANTIZATION_STEPS
134+
if representative_data is None:
135+
raise ValueError(
136+
'representative_data couldn\'t be None if model is quantized.')
137+
ds = self._gen_dataset(
138+
representative_data, batch_size=1, is_training=False)
139+
converter.representative_dataset = tf.lite.RepresentativeDataset(
140+
get_representative_dataset_gen(ds, quantization_steps))
141+
142+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
143+
converter.inference_input_type = tf.uint8
144+
converter.inference_output_type = tf.uint8
145+
converter.target_spec.supported_ops = [
146+
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
147+
]
148+
tflite_model = converter.convert()
149+
if temp_dir:
150+
temp_dir.cleanup()
151+
152+
with tf.io.gfile.GFile(tflite_filename, 'wb') as f:
153+
f.write(tflite_model)
154+
155+
tf.compat.v1.logging.info('Export to tflite model in %s.', tflite_filename)

0 commit comments

Comments
 (0)