Skip to content

Commit 98be67f

Browse files
committed
Add structured outputs for openai
1 parent b94979e commit 98be67f

File tree

3 files changed

+87
-3
lines changed

3 files changed

+87
-3
lines changed

chat_rag/chat_rag/llms/base_llm.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Dict, List, Optional, Union, Tuple, Callable
1+
from typing import Callable, Dict, List, Optional, Tuple, Union
2+
23
from chat_rag.llms.types import Message
34

45

@@ -43,3 +44,18 @@ async def agenerate(
4344
tool_choice: str = None,
4445
) -> Message:
4546
pass
47+
48+
def parse(
49+
self,
50+
messages: List[Dict[str, str]],
51+
schema: Dict,
52+
) -> Message:
53+
raise NotImplementedError("This LLM does not support enforced structured output.")
54+
55+
async def aparse(
56+
self,
57+
messages: List[Dict[str, str]],
58+
schema: Dict,
59+
) -> Message:
60+
raise NotImplementedError("This LLM does not support enforced structured output.")
61+

chat_rag/chat_rag/llms/openai_client.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import Callable, Dict, List, Union
21
import json
2+
from typing import Callable, Dict, List, Union
33

44
from openai import AsyncOpenAI, OpenAI
5+
from openai.lib._pydantic import _ensure_strict_json_schema
56

67
from chat_rag.llms.types import Content, Message, ToolUse, Usage
78

@@ -27,7 +28,9 @@ def _format_tools(
2728
"""
2829
Format the tools from a openai dict or a callable function to the OpenAI format.
2930
"""
30-
tools_formatted, tool_choice = format_tools(tools, tool_choice, mode=Mode.OPENAI_TOOLS)
31+
tools_formatted, tool_choice = format_tools(
32+
tools, tool_choice, mode=Mode.OPENAI_TOOLS
33+
)
3134

3235
# If the tool_choice is a named tool, then apply correct formatting
3336
if tool_choice in [tool["function"]["name"] for tool in tools_formatted]:
@@ -278,3 +281,67 @@ async def agenerate(
278281
)
279282

280283
return self._map_openai_message(response)
284+
285+
def parse(
286+
self,
287+
messages: List[Union[Dict, Message]],
288+
schema: Dict,
289+
**kwargs,
290+
) -> Message:
291+
"""
292+
Parse the response from the model into a structured format.
293+
Parameters
294+
----------
295+
messages : List[Tuple[str, str]]
296+
The messages to use for the prompt. Pair of (role, message).
297+
schema : Dict
298+
The schema to use for the response. It must be a pydantic model json schema, it can be generated using the `model_json_schema` method of the pydantic model.
299+
Returns
300+
-------
301+
Message
302+
The parsed message.
303+
"""
304+
messages = self._format_messages(messages)
305+
response_format = {
306+
"type": "json_schema",
307+
"json_schema": {
308+
"schema": _ensure_strict_json_schema(schema, path=(), root=schema),
309+
"name": schema["title"],
310+
"strict": True,
311+
},
312+
}
313+
314+
response = self.client.beta.chat.completions.parse(
315+
model=self.llm_name,
316+
messages=messages,
317+
response_format=response_format,
318+
)
319+
320+
return json.loads(response.choices[0].message.content)
321+
322+
async def aparse(
323+
self,
324+
messages: List[Union[Dict, Message]],
325+
schema: Dict,
326+
**kwargs,
327+
) -> Message:
328+
"""
329+
Parse the response from the model into a structured format.
330+
"""
331+
messages = self._format_messages(messages)
332+
response_format = {
333+
"type": "json_schema",
334+
"json_schema": {
335+
"schema": _ensure_strict_json_schema(schema, path=(), root=schema),
336+
"name": schema["title"],
337+
"strict": True,
338+
},
339+
}
340+
341+
response = await self.aclient.beta.chat.completions.parse(
342+
model=self.llm_name,
343+
messages=messages,
344+
response_format=response_format,
345+
)
346+
347+
return json.loads(response.choices[0].message.content)

sdk/examples/agent_example/fsm_definition.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ async def send_greeting(sdk: ChatFAQSDK, ctx: dict):
2121
async def send_answer(sdk: ChatFAQSDK, ctx: dict):
2222
agent = Agent(
2323
sdk=sdk,
24+
model_name=MODEL_NAME,
2425
tools=[get_weather],
2526
system_instruction="You are a knowledgeable weather assistant. Use provided tools when necessary."
2627
)

0 commit comments

Comments
 (0)