Skip to content

Commit 104d7ce

Browse files
ziyeqinghancopybara-github
authored andcommitted
Add unittest for bert text classifier.
PiperOrigin-RevId: 302816228
1 parent d75d95b commit 104d7ce

1 file changed

Lines changed: 20 additions & 3 deletions

File tree

tensorflow_examples/lite/model_maker/core/task/text_classifier_test.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import os
2020
import numpy as np
21-
import tensorflow as tf
21+
import tensorflow.compat.v2 as tf
2222

2323
from tensorflow_examples.lite.model_maker.core import compat
2424
from tensorflow_examples.lite.model_maker.core import model_export_format as mef
@@ -67,6 +67,23 @@ def test_average_wordvec_model_create_v1_incompatible(self):
6767
model_spec=model_spec,
6868
)
6969

70+
@test_util.test_in_tf_2
71+
def test_bert_model(self):
72+
model_spec = ms.BertModelSpec(seq_len=2, trainable=False)
73+
all_data = text_dataloader.TextClassifierDataLoader.from_folder(
74+
self.text_dir, model_spec=model_spec)
75+
# Splits data, 90% data for training, 10% for testing
76+
self.train_data, self.test_data = all_data.split(0.9)
77+
78+
model = text_classifier.create(
79+
self.train_data,
80+
mef.ModelExportFormat.TFLITE,
81+
model_spec=model_spec,
82+
epochs=1,
83+
batch_size=1,
84+
shuffle=True)
85+
self._test_accuracy(model, 0.5)
86+
7087
@test_util.test_in_tf_2
7188
def test_average_wordvec_model(self):
7289
model_spec = ms.AverageWordVecModelSpec(seq_len=2)
@@ -86,9 +103,9 @@ def test_average_wordvec_model(self):
86103
self._test_export_to_tflite(model)
87104
self._test_predict_top_k(model)
88105

89-
def _test_accuracy(self, model):
106+
def _test_accuracy(self, model, threshold=1.0):
90107
_, accuracy = model.evaluate(self.test_data)
91-
self.assertEqual(accuracy, 1.0)
108+
self.assertEqual(accuracy, threshold)
92109

93110
def _test_predict_top_k(self, model):
94111
topk = model.predict_top_k(self.test_data, batch_size=4)

0 commit comments

Comments
 (0)