33Provides a function to validate the input to the model.
44'''
55# pylint: disable=too-many-branches
6+
7+ import json
68from typing import Any , Dict , List , Union
79
10+ # Error messages
11+ UNEXPECTED_INPUT_ERROR = "Unexpected input. {} is not a valid input option."
12+ MISSING_REQUIRED_ERROR = "{} is a required input."
13+ MISSING_DEFAULT_ERROR = "Schema error, missing default value for {}."
14+ MISSING_TYPE_ERROR = "Schema error, missing type for {}."
15+ INVALID_TYPE_ERROR = "{} should be {} type, not {}."
16+ CONSTRAINTS_ERROR = "{} does not meet the constraints."
17+ SCHEMA_ERROR = "Schema error, {} is not a dictionary."
18+
819
920def _add_error (error_list : List [str ], message : str ) -> None :
1021 error_list .append (message )
@@ -25,29 +36,41 @@ def validate(raw_input: Dict[str, Any], schema: Dict[str, Any]
2536 {"validated_input": {"input1": "value1", "input2": "value2"}
2637 '''
2738 error_list = []
39+ validated_input = raw_input .copy ()
2840
2941 # Check for unexpected inputs.
3042 for key in raw_input :
3143 if key not in schema :
32- _add_error (error_list , f"Unexpected input. { key } is not a valid input option." )
44+ _add_error (error_list , UNEXPECTED_INPUT_ERROR .format (key ))
45+
46+ # Check that items are dictionaries.
47+ for key , rules in schema .items ():
48+ if not isinstance (rules , dict ):
49+ try :
50+ schema [key ] = json .loads (rules )
51+ except json .decoder .JSONDecodeError :
52+ _add_error (error_list , SCHEMA_ERROR .format (key ))
3353
3454 # Checks for missing required inputs or sets the default values.
3555 for key , rules in schema .items ():
56+ if 'type' not in rules :
57+ _add_error (error_list , MISSING_TYPE_ERROR .format (key ))
58+
3659 if 'required' not in rules :
37- _add_error (error_list , f"Schema error, missing 'required' for { key } ." )
60+ _add_error (error_list , MISSING_REQUIRED_ERROR . format ( key ) )
3861 elif rules ['required' ] and key not in raw_input :
39- _add_error (error_list , f" { key } is a required input." )
62+ _add_error (error_list , MISSING_REQUIRED_ERROR . format ( key ) )
4063 elif rules ['required' ] and key not in raw_input and "default" not in rules :
41- _add_error (error_list , f"Schema error, missing default value for { key } ." )
64+ _add_error (error_list , MISSING_DEFAULT_ERROR . format ( key ) )
4265 elif not rules ['required' ] and key not in raw_input and "default" not in rules :
43- _add_error (error_list , f"Schema error, missing default value for { key } ." )
66+ _add_error (error_list , MISSING_DEFAULT_ERROR . format ( key ) )
4467 elif not rules ['required' ] and key not in raw_input :
45- raw_input [key ] = raw_input .get (key , rules ['default' ])
68+ validated_input [key ] = raw_input .get (key , rules ['default' ])
4669
4770 for key , rules in schema .items ():
4871 # Enforce floats to be floats.
4972 if rules ['type' ] is float and type (raw_input [key ]) in [int , float ]:
50- raw_input [key ] = float (raw_input [key ])
73+ validated_input [key ] = float (raw_input [key ])
5174
5275 # Check for the correct type.
5376 if not isinstance (raw_input [key ], rules ['type' ]) and raw_input [key ] is not None :
@@ -57,9 +80,9 @@ def validate(raw_input: Dict[str, Any], schema: Dict[str, Any]
5780 # Check lambda constraints.
5881 if "constraints" in rules :
5982 if not rules ['constraints' ](raw_input [key ]):
60- _add_error (error_list , f" { key } does not meet the constraints." )
83+ _add_error (error_list , CONSTRAINTS_ERROR . format ( key ) )
6184
62- validation_return = {"validated_input" : raw_input }
85+ validation_return = {"validated_input" : validated_input }
6386 if error_list :
6487 validation_return = {"errors" : error_list }
6588
0 commit comments