1818
1919import os
2020import numpy as np
21- import tensorflow as tf
21+ import tensorflow . compat . v2 as tf
2222
2323from tensorflow_examples .lite .model_maker .core import compat
2424from tensorflow_examples .lite .model_maker .core import model_export_format as mef
@@ -67,6 +67,23 @@ def test_average_wordvec_model_create_v1_incompatible(self):
6767 model_spec = model_spec ,
6868 )
6969
70+ @test_util .test_in_tf_2
71+ def test_bert_model (self ):
72+ model_spec = ms .BertModelSpec (seq_len = 2 , trainable = False )
73+ all_data = text_dataloader .TextClassifierDataLoader .from_folder (
74+ self .text_dir , model_spec = model_spec )
75+ # Splits data, 90% data for training, 10% for testing
76+ self .train_data , self .test_data = all_data .split (0.9 )
77+
78+ model = text_classifier .create (
79+ self .train_data ,
80+ mef .ModelExportFormat .TFLITE ,
81+ model_spec = model_spec ,
82+ epochs = 1 ,
83+ batch_size = 1 ,
84+ shuffle = True )
85+ self ._test_accuracy (model , 0.5 )
86+
7087 @test_util .test_in_tf_2
7188 def test_average_wordvec_model (self ):
7289 model_spec = ms .AverageWordVecModelSpec (seq_len = 2 )
@@ -86,9 +103,9 @@ def test_average_wordvec_model(self):
86103 self ._test_export_to_tflite (model )
87104 self ._test_predict_top_k (model )
88105
89- def _test_accuracy (self , model ):
106+ def _test_accuracy (self , model , threshold = 1.0 ):
90107 _ , accuracy = model .evaluate (self .test_data )
91- self .assertEqual (accuracy , 1.0 )
108+ self .assertEqual (accuracy , threshold )
92109
93110 def _test_predict_top_k (self , model ):
94111 topk = model .predict_top_k (self .test_data , batch_size = 4 )
0 commit comments