Skip to content

Commit 8f429e0

Browse files
committed
keras prediction functions now works both with model path and model object
1 parent f953f94 commit 8f429e0

3 files changed

Lines changed: 16 additions & 11 deletions

File tree

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ from imagepreprocessing.utilities import create_confusion_matrix, train_test_sp
4848

4949
images_path = "deep_learning/test_images/food2"
5050
save_path = "food"
51-
model_path = "deep_learning/saved_models/alexnet.h5"
5251

5352
# Create training data split the data
5453
x, y, x_val, y_val = create_training_data_keras(images_path, save_path = save_path, validation_split=0.2, percent_to_use=0.5)
@@ -63,7 +62,7 @@ x, y, test_x, test_y = train_test_split(x,y,save_path = save_path)
6362
class_names = ["apple", "melon", "orange"]
6463

6564
# make prediction
66-
predictions = make_prediction_from_array_keras(test_x, model_path, print_output=False)
65+
predictions = make_prediction_from_array_keras(test_x, model, print_output=False)
6766

6867
# create confusion matrix
6968
create_confusion_matrix(predictions, test_y, class_names=class_names, one_hot=True)

imagepreprocessing/keras_functions.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -265,13 +265,13 @@ def create_training_data_keras(source_path, save_path = None, image_size = (224,
265265
return x, y
266266

267267

268-
def make_prediction_from_directory_keras(images_path, keras_model_path, image_size = (224,224), print_output=True, model_summary=True, show_images=False, grayscale = False, files_to_exclude = [".DS_Store",""]):
268+
def make_prediction_from_directory_keras(images_path, keras_model, image_size = (224,224), print_output=True, model_summary=True, show_images=False, grayscale = False, files_to_exclude = [".DS_Store",""]):
269269
"""
270270
Reads test data from directory resizes it and makes prediction with using a keras model
271271
272272
# Arguments:
273273
images_path: source path of the test images see input format
274-
keras_model_path: path of the keras model
274+
keras_model: a keras model object or path of the model
275275
img_size (224): size of the images for resizing
276276
print_output (True): prints output
277277
model_summary (True): shows keras model summary
@@ -323,8 +323,11 @@ def make_prediction_from_directory_keras(images_path, keras_model_path, image_si
323323
if exclude in images:
324324
images.remove(exclude)
325325

326-
# load model
327-
model = keras.models.load_model(keras_model_path)
326+
# prepare model
327+
if(isinstance(keras_model, keras.Model)):
328+
model = keras_model
329+
else:
330+
model = keras.models.load_model(keras_model)
328331

329332
# get all images
330333
for image in images:
@@ -367,13 +370,13 @@ def make_prediction_from_directory_keras(images_path, keras_model_path, image_si
367370
return predictions
368371

369372

370-
def make_prediction_from_array_keras(test_x, keras_model_path, print_output=True, model_summary=True, show_images=False):
373+
def make_prediction_from_array_keras(test_x, keras_model, print_output=True, model_summary=True, show_images=False):
371374
"""
372375
makes prediction with using a keras model
373376
374377
# Arguments:
375378
test_x: numpy array of images
376-
keras_model_path: path of the keras model
379+
keras_model: a keras model object or path of the model
377380
print_output (True): prints output
378381
model_summary (True): shows keras model summary
379382
show_images (False): shows the predicted image
@@ -392,8 +395,11 @@ def make_prediction_from_array_keras(test_x, keras_model_path, print_output=True
392395
import keras
393396
import cv2
394397

395-
# load model
396-
model = keras.models.load_model(keras_model_path)
398+
# prepare model
399+
if(isinstance(keras_model, keras.Model)):
400+
model = keras_model
401+
else:
402+
model = keras.models.load_model(keras_model)
397403

398404
# show model summary
399405
if(model_summary):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name="imagepreprocessing",
8-
version="1.4.1",
8+
version="1.5.0",
99
author="Can Kurt",
1010
author_email="[email protected]",
1111
description="image preprocessing",

0 commit comments

Comments
 (0)