Skip to content

Commit 3313f10

Browse files
committed
Bug fixing
1 parent 80efc37 commit 3313f10

6 files changed

Lines changed: 183 additions & 53 deletions

File tree

back/back/apps/broker/serializers/messages/__init__.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,14 @@ class Reference(serializers.Serializer):
107107
knowledge_base_id = serializers.CharField(required=False, allow_null=True, allow_blank=True)
108108

109109

110-
class ToolUse(serializers.Serializer):
110+
class ToolUsePayload(serializers.Serializer):
111111
id = serializers.CharField(required=True)
112112
name = serializers.CharField(required=True)
113113
args = serializers.JSONField(required=True)
114114
text = serializers.CharField(required=False, allow_null=True, allow_blank=True)
115115

116116

117-
class ToolResult(serializers.Serializer):
117+
class ToolResultPayload(serializers.Serializer):
118118
id = serializers.CharField(required=False, allow_null=True, allow_blank=True)
119119
name = serializers.CharField(required=False, allow_null=True, allow_blank=True)
120120
result = serializers.CharField(required=True)
@@ -124,19 +124,6 @@ class MessagePayload(serializers.Serializer):
124124
content = serializers.SerializerMethodField()
125125
references = Reference(required=False, allow_null=True)
126126

127-
@extend_schema_field(
128-
PolymorphicProxySerializer(
129-
component_name="MessageContent",
130-
serializers=[
131-
serializers.CharField,
132-
ToolUse,
133-
ToolResult
134-
],
135-
)
136-
)
137-
def get_content(self, obj):
138-
return obj
139-
140127

141128
class HTMLPayload(serializers.Serializer):
142129
@staticmethod
@@ -176,6 +163,8 @@ class QuickRepliesPayload(serializers.Serializer):
176163
"ImagePayload": ImagePayload,
177164
"SatisfactionPayload": SatisfactionPayload,
178165
"QuickRepliesPayload": QuickRepliesPayload,
166+
"ToolUsePayload": ToolUsePayload,
167+
"ToolResultPayload": ToolResultPayload,
179168
},
180169
)
181170
)

back/back/apps/language_model/consumers/__init__.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from back.config import settings
2626
from back.utils import WSStatusCodes
2727
from back.utils.custom_channels import CustomAsyncConsumer
28-
from chat_rag.llms.types import Content, Message, ToolResult, ToolUse
2928
from chat_rag.llms import load_llm
29+
from chat_rag.llms.types import Content, Message, ToolResult, ToolUse
3030

3131
logger = getLogger(__name__)
3232

@@ -59,29 +59,21 @@ def process_stack(stack) -> List[Content]:
5959
"""
6060
contents = []
6161
payload = stack.get("payload", {})
62+
type = stack.get("type")
6263

6364
# Create a text content if available.
64-
text = payload.get("content")
65-
if text:
66-
contents.append(Content(text=text, type="text"))
65+
if type == "message":
66+
contents.append(Content(text=payload.get("content"), type="text"))
6767

6868
# Check if this stack represents a tool call (tool use).
69-
if payload.get("tool_use"):
70-
try:
71-
tool_use_obj = ToolUse(**payload["tool_use"])
72-
contents.append(Content(tool_use=tool_use_obj, type="tool_use"))
73-
except Exception as e:
74-
# If it fails to parse tool_use, we simply skip it.
75-
pass
69+
if type == "tool_use":
70+
tool_use_obj = ToolUse(**payload)
71+
contents.append(Content(tool_use=tool_use_obj, type="tool_use"))
7672

7773
# Check if this stack represents a tool result.
78-
if payload.get("tool_result"):
79-
try:
80-
tool_result_obj = ToolResult(**payload["tool_result"])
81-
contents.append(Content(tool_result=tool_result_obj, type="tool_result"))
82-
except Exception as e:
83-
# If it fails to parse tool_result, we simply skip it.
84-
pass
74+
if type == "tool_result":
75+
tool_result_obj = ToolResult(**payload)
76+
contents.append(Content(tool_result=tool_result_obj, type="tool_result"))
8577

8678
return contents
8779

chat_rag/chat_rag/llms/format_tools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def format_tools(
135135
tools_formatted = []
136136
if mode in {Mode.OPENAI_TOOLS, Mode.MISTRAL_TOOLS}:
137137
for tool in tools:
138+
# As it is already in the openai format, we can just append it
138139
tools_formatted.append(tool)
139140

140141
elif mode == Mode.ANTHROPIC_TOOLS:

chat_rag/chat_rag/llms/openai_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def _format_tools(self, tools: List[Union[Callable, Dict]], tool_choice: str = N
2727
tools_formatted = format_tools(tools, mode=Mode.OPENAI_TOOLS)
2828
tool_choice = self._check_tool_choice(tools_formatted, tool_choice)
2929

30-
3130
# If the tool_choice is a named tool, then apply correct formatting
3231
if tool_choice in [tool['title'] for tool in tools]:
3332
tool_choice = {

sdk/chatfaq_sdk/layers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,14 +356,14 @@ def __init__(self, id: str = None, name: str = None, result: dict = None):
356356
super().__init__()
357357
self.id = id
358358
self.name = name
359-
self.result = result
359+
self.tool_result = result
360360

361361
async def build_payloads(self, ctx, data):
362362
payload = {
363363
"payload": {
364364
"id": self.id,
365365
"name": self.name,
366-
"result": self.result,
366+
"result": self.tool_result,
367367
}
368368
}
369369
yield [payload], True

sdk/chatfaq_sdk/utils.py

Lines changed: 166 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,173 @@
1-
def convert_mml_to_llm_format(mml):
1+
import inspect
2+
from typing import Dict, List
3+
4+
ROLES_MAP = {
5+
"bot": "assistant",
6+
"human": "user",
7+
}
8+
9+
10+
def convert_mml_to_llm_format(mml: List[Dict]) -> List[Dict]:
211
"""
312
Converts the MML (Message Markup Language) format to the common LLM message format.
13+
Analogous to format_msgs_chain_to_llm_context in back/back/apps/language_model/consumers/__init__.py
14+
415
516
:param mml: List of messages in MML format
617
:return: List of messages in LLM format {'role': 'user', 'content': '...'}
718
"""
8-
roles_map = {
9-
"bot": "assistant",
10-
"human": "user",
19+
aggregated_messages = []
20+
current_role = None # "user" for human and "assistant" for bot
21+
aggregated_contents = [] # list of Content objects for the current group
22+
23+
def process_stack(stack: Dict) -> List[Dict]:
24+
"""
25+
Process a single stack item into a list of LLM message format.
26+
"""
27+
contents = []
28+
type = stack.get("type")
29+
30+
if type == "message":
31+
contents.append(
32+
{
33+
"type": "text",
34+
"text": stack["payload"]["content"],
35+
}
36+
)
37+
elif type == "tool_use":
38+
contents.append(
39+
{
40+
"type": "tool_use",
41+
"tool_use": stack["payload"],
42+
}
43+
)
44+
elif type == "tool_result":
45+
contents.append(
46+
{
47+
"type": "tool_result",
48+
"tool_result": stack["payload"],
49+
}
50+
)
51+
52+
return contents
53+
54+
def process_msg(msg: Dict) -> List[Dict]:
55+
"""
56+
Process each broker message into a list of LLM message format by iterating over its stacks.
57+
"""
58+
contents = []
59+
for stack in msg.get("stack", []):
60+
contents.extend(process_stack(stack))
61+
return contents
62+
63+
def merge_contents(existing: List[Dict], new: List[Dict]) -> List[Dict]:
64+
"""
65+
Merge two lists of LLM message format.
66+
If the last element of the existing list and the first element of the new list are both text,
67+
then they are concatenated.
68+
"""
69+
if not existing:
70+
return new
71+
if not new:
72+
return existing
73+
74+
merged = existing.copy()
75+
if merged and new and merged[-1]["type"] == "text" and new[0]["type"] == "text":
76+
merged[-1]["text"] = (
77+
merged[-1]["text"].strip() + " " + new[0]["text"].strip()
78+
)
79+
merged.extend(new[1:])
80+
else:
81+
merged.extend(new)
82+
return merged
83+
84+
for msg in mml:
85+
role = ROLES_MAP[msg["sender"]["type"]]
86+
msg_contents = process_msg(msg)
87+
if not msg_contents:
88+
continue
89+
90+
if current_role is None:
91+
current_role = role
92+
aggregated_contents = msg_contents
93+
elif current_role == role:
94+
aggregated_contents = merge_contents(aggregated_contents, msg_contents)
95+
else:
96+
aggregated_messages.append(
97+
{
98+
"role": current_role,
99+
"content": aggregated_contents,
100+
}
101+
)
102+
current_role = role
103+
aggregated_contents = msg_contents
104+
105+
if aggregated_contents:
106+
aggregated_messages.append(
107+
{
108+
"role": current_role,
109+
"content": aggregated_contents,
110+
}
111+
)
112+
113+
return aggregated_messages
114+
115+
116+
def function_to_json(func) -> dict:
117+
"""
118+
Converts a Python function into a JSON-serializable dictionary
119+
that describes the function's signature, including its name,
120+
description, and parameters.
121+
Function from https://github.com/openai/swarm
122+
123+
Args:
124+
func: The function to be converted.
125+
126+
Returns:
127+
A dictionary representing the function's signature in JSON format.
128+
"""
129+
type_map = {
130+
str: "string",
131+
int: "integer",
132+
float: "number",
133+
bool: "boolean",
134+
list: "array",
135+
dict: "object",
136+
type(None): "null",
137+
}
138+
139+
try:
140+
signature = inspect.signature(func)
141+
except ValueError as e:
142+
raise ValueError(
143+
f"Failed to get signature for function {func.__name__}: {str(e)}"
144+
)
145+
146+
parameters = {}
147+
for param in signature.parameters.values():
148+
try:
149+
param_type = type_map.get(param.annotation, "string")
150+
except KeyError as e:
151+
raise KeyError(
152+
f"Unknown type annotation {param.annotation} for parameter {param.name}: {str(e)}"
153+
)
154+
parameters[param.name] = {"type": param_type}
155+
156+
required = [
157+
param.name
158+
for param in signature.parameters.values()
159+
if param.default == inspect._empty
160+
]
161+
162+
return {
163+
"type": "function",
164+
"function": {
165+
"name": func.__name__,
166+
"description": func.__doc__ or "",
167+
"parameters": {
168+
"type": "object",
169+
"properties": parameters,
170+
"required": required,
171+
},
172+
},
11173
}
12-
messages = []
13-
14-
for message in mml:
15-
for stack in message.get("stack", []):
16-
content = stack["payload"].get("content")
17-
if not content:
18-
continue
19-
messages.append({
20-
"role": roles_map[message["sender"]["type"]],
21-
"content": content,
22-
})
23-
24-
return messages

0 commit comments

Comments (0)