1- from typing import Callable , Dict , List , Union
21import json
2+ from typing import Callable , Dict , List , Union
33
44from openai import AsyncOpenAI , OpenAI
5+ from openai .lib ._pydantic import _ensure_strict_json_schema
56
67from chat_rag .llms .types import Content , Message , ToolUse , Usage
78
@@ -27,7 +28,9 @@ def _format_tools(
2728 """
2829 Format the tools from a openai dict or a callable function to the OpenAI format.
2930 """
30- tools_formatted , tool_choice = format_tools (tools , tool_choice , mode = Mode .OPENAI_TOOLS )
31+ tools_formatted , tool_choice = format_tools (
32+ tools , tool_choice , mode = Mode .OPENAI_TOOLS
33+ )
3134
3235 # If the tool_choice is a named tool, then apply correct formatting
3336 if tool_choice in [tool ["function" ]["name" ] for tool in tools_formatted ]:
@@ -278,3 +281,67 @@ async def agenerate(
278281 )
279282
280283 return self ._map_openai_message (response )
284+
285+ def parse (
286+ self ,
287+ messages : List [Union [Dict , Message ]],
288+ schema : Dict ,
289+ ** kwargs ,
290+ ) -> Message :
291+ """
292+ Parse the response from the model into a structured format.
293+ Parameters
294+ ----------
295+ messages : List[Tuple[str, str]]
296+ The messages to use for the prompt. Pair of (role, message).
297+ schema : Dict
298+ The schema to use for the response. It must be a pydantic model json schema, it can be generated using the `model_json_schema` method of the pydantic model.
299+ Returns
300+ -------
301+ Message
302+ The parsed message.
303+ """
304+ messages = self ._format_messages (messages )
305+ response_format = {
306+ "type" : "json_schema" ,
307+ "json_schema" : {
308+ "schema" : _ensure_strict_json_schema (schema , path = (), root = schema ),
309+ "name" : schema ["title" ],
310+ "strict" : True ,
311+ },
312+ }
313+
314+ response = self .client .beta .chat .completions .parse (
315+ model = self .llm_name ,
316+ messages = messages ,
317+ response_format = response_format ,
318+ )
319+
320+ return json .loads (response .choices [0 ].message .content )
321+
322+ async def aparse (
323+ self ,
324+ messages : List [Union [Dict , Message ]],
325+ schema : Dict ,
326+ ** kwargs ,
327+ ) -> Message :
328+ """
329+ Parse the response from the model into a structured format.
330+ """
331+ messages = self ._format_messages (messages )
332+ response_format = {
333+ "type" : "json_schema" ,
334+ "json_schema" : {
335+ "schema" : _ensure_strict_json_schema (schema , path = (), root = schema ),
336+ "name" : schema ["title" ],
337+ "strict" : True ,
338+ },
339+ }
340+
341+ response = await self .aclient .beta .chat .completions .parse (
342+ model = self .llm_name ,
343+ messages = messages ,
344+ response_format = response_format ,
345+ )
346+
347+ return json .loads (response .choices [0 ].message .content )
0 commit comments