33from pydantic import BaseModel , create_model
44from pydantic .fields import FieldInfo
55
6+ from docarray .base_doc .doc import BaseDocWithoutId
67from docarray import BaseDoc , DocList
78from docarray .typing import AnyTensor
89from docarray .utils ._internal ._typing import safe_issubclass
@@ -50,16 +51,19 @@ class MyDoc(BaseDoc):
5051 :param model: The input model
5152 :return: A new subclass of BaseDoc, where every DocList type in the schema is replaced by List.
5253 """
53- if is_pydantic_v2 :
54- raise NotImplementedError (
55- 'This method is not supported in Pydantic 2.0. Please use Pydantic 1.8.2 or lower.'
56- )
57-
5854 fields : Dict [str , Any ] = {}
59- for field_name , field in model .__annotations__ .items ():
60- if field_name not in model .__fields__ :
55+ import copy
56+
57+ fields_copy = copy .deepcopy (model .__fields__ )
58+ annotations_copy = copy .deepcopy (model .__annotations__ )
59+ for field_name , field in annotations_copy .items ():
60+ if field_name not in fields_copy :
6161 continue
62- field_info = model .__fields__ [field_name ].field_info
62+
63+ if is_pydantic_v2 :
64+ field_info = fields_copy [field_name ]
65+ else :
66+ field_info = fields_copy [field_name ].field_info
6367 try :
6468 if safe_issubclass (field , DocList ):
6569 t : Any = field .doc_type
@@ -68,9 +72,8 @@ class MyDoc(BaseDoc):
6872 fields [field_name ] = (field , field_info )
6973 except TypeError :
7074 fields [field_name ] = (field , field_info )
71- return create_model (
72- model .__name__ , __base__ = model , __validators__ = model .__validators__ , ** fields
73- )
75+
76+ return create_model (model .__name__ , __base__ = model , __doc__ = model .__doc__ , ** fields )
7477
7578
7679def _get_field_annotation_from_schema (
@@ -201,6 +204,8 @@ def _get_field_annotation_from_schema(
201204 num_recursions = num_recursions + 1 ,
202205 definitions = definitions ,
203206 )
207+ elif field_type == 'null' :
208+ ret = None
204209 else :
205210 if num_recursions > 0 :
206211 raise ValueError (
@@ -255,14 +260,18 @@ class MyDoc(BaseDoc):
255260 :return: A BaseDoc class dynamically created following the `schema`.
256261 """
257262 if not definitions :
258- definitions = schema .get ('definitions' , {})
263+ definitions = (
264+ schema .get ('definitions' , {}) if not is_pydantic_v2 else schema .get ('$defs' )
265+ )
259266
260267 cached_models = cached_models if cached_models is not None else {}
261268 fields : Dict [str , Any ] = {}
262269 if base_doc_name in cached_models :
263270 return cached_models [base_doc_name ]
271+ has_id = False
264272 for field_name , field_schema in schema .get ('properties' , {}).items ():
265-
273+ if field_name == 'id' :
274+ has_id = True
266275 field_type = _get_field_annotation_from_schema (
267276 field_schema = field_schema ,
268277 field_name = field_name ,
@@ -272,17 +281,43 @@ class MyDoc(BaseDoc):
272281 num_recursions = 0 ,
273282 definitions = definitions ,
274283 )
275- fields [field_name ] = (
276- field_type ,
277- FieldInfo (default = field_schema .pop ('default' , None ), ** field_schema ),
278- )
284+ if not is_pydantic_v2 :
285+ field_schema ['default' ] = field_schema .get ('default' , None )
286+ fields [field_name ] = (
287+ field_type ,
288+ FieldInfo (** field_schema ),
289+ )
290+ else :
291+ field_kwargs = {}
292+ field_json_schema_extra = {}
293+ for k , v in field_schema .items ():
294+ if k in FieldInfo .__slots__ :
295+ field_kwargs [k ] = v
296+ else :
297+ field_json_schema_extra [k ] = v
298+ fields [field_name ] = (
299+ field_type ,
300+ FieldInfo (
301+ json_schema_extra = field_json_schema_extra ,
302+ ** field_kwargs ,
303+ ),
304+ )
279305
280- model = create_model (base_doc_name , __base__ = BaseDoc , ** fields )
281- model .__config__ .title = schema .get ('title' , model .__config__ .title )
306+ base_model = BaseDoc if has_id else BaseDocWithoutId
307+ model = create_model (base_doc_name , __base__ = base_model , ** fields )
308+ if not is_pydantic_v2 :
309+ model .__config__ .title = schema .get ('title' , model .__config__ .title )
310+ else :
311+ set_title = schema .get ('title' , model .model_config .get ('title' , None ))
312+ if set_title :
313+ model .model_config ['title' ] = set_title
282314
283315 for k in RESERVED_KEYS :
284316 if k in schema :
285317 schema .pop (k )
286- model .__config__ .schema_extra = schema
318+ if not is_pydantic_v2 :
319+ model .__config__ .schema_extra = schema
320+ else :
321+ model .model_config ['json_schema_extra' ] = schema
287322 cached_models [base_doc_name ] = model
288323 return model
0 commit comments