Skip to content

Commit 2d32c60

Browse files
committed
Update Google GenAI SDK and refactor tool handling across LLM clients
1 parent ce80ae2 commit 2d32c60

File tree

8 files changed

+45
-41
lines changed

8 files changed

+45
-41
lines changed

chat_rag/chat_rag/llms/base_llm.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,6 @@
33

44

55
class LLM:
6-
def _check_tool_choice(self, tools: List[Dict], tool_choice: str) -> str:
7-
"""
8-
Adhere to the tool_choice parameter requirements.
9-
"""
10-
11-
if tool_choice:
12-
tool_choices = ["required", "auto"] + [
13-
tool["function"]["name"] for tool in tools
14-
]
15-
assert tool_choice in tool_choices, (
16-
f"tool_choice must be one of {tool_choices}"
17-
)
18-
19-
return tool_choice
206

217
def stream(
228
self,

chat_rag/chat_rag/llms/claude_client.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ def _format_tools(
2626
"""
2727
Format the tools from a generic BaseModel to the OpenAI format.
2828
"""
29-
tools_formatted = format_tools(tools, mode=Mode.ANTHROPIC_TOOLS)
30-
tool_choice = self._check_tool_choice(tools_formatted, tool_choice)
31-
29+
tools_formatted, tool_choice = format_tools(tools, tool_choice, mode=Mode.ANTHROPIC_TOOLS)
3230

3331
# If any messages have cache_control, add cache_control to the last tool so they are cached also
3432
if any(message.get("cache_control") for message in messages):
@@ -37,7 +35,7 @@ def _format_tools(
3735

3836
if tool_choice:
3937
# If the tool_choice is a named tool, then apply correct formatting
40-
if tool_choice in [tool["title"] for tool in tools]:
38+
if tool_choice in [tool["name"] for tool in tools_formatted]:
4139
tool_choice = {"type": "tool", "name": tool_choice}
4240
else: # if it's required or auto, then apply the correct formatting
4341
tool_choice = (
@@ -127,7 +125,7 @@ def _map_anthropic_message(self, message) -> Message:
127125
input_tokens=message.usage.input_tokens,
128126
output_tokens=message.usage.output_tokens,
129127
cache_creation_input_tokens=message.usage.cache_creation_input_tokens,
130-
cache_creation_read_tokens=message.usage.cache_creation_read_tokens,
128+
cache_creation_read_tokens=message.usage.cache_read_input_tokens,
131129
),
132130
)
133131

chat_rag/chat_rag/llms/format_tools.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import enum
22
import inspect
3-
from typing import Any, Callable, Dict, List, Union
3+
from typing import Any, Callable, Dict, List, Tuple, Union
44

55

66
def function_to_json(func) -> dict:
@@ -114,15 +114,33 @@ def uppercase_types_recursively(schema: Dict[str, Any]) -> Dict[str, Any]:
114114
return schema
115115

116116

117+
def _check_tool_choice(tools: List[Dict], tool_choice: str) -> str:
118+
"""
119+
Adhere to the tool_choice parameter requirements.
120+
"""
121+
122+
if tool_choice:
123+
tool_choices = ["required", "auto"] + [
124+
tool["function"]["name"] for tool in tools
125+
]
126+
assert tool_choice in tool_choices, (
127+
f"tool_choice must be one of {tool_choices}"
128+
)
129+
130+
return tool_choice
131+
132+
117133
def format_tools(
118-
tools: List[Union[Callable, Dict]], mode: Mode
119-
) -> List[Dict[str, Any]]:
134+
tools: List[Union[Callable, Dict]], tool_choice: str, mode: Mode
135+
) -> Tuple[List[Dict[str, Any]], str]:
120136
"""
121137
Given a series of functions, return the JSON schema required by each LLM provider.
122138
Parameters
123139
----------
124140
tools : List[Union[Callable, Dict]]
125141
A list of tools in the OpenAI JSON format or a callable function.
142+
tool_choice : str
143+
The tool_choice parameter to use for the LLM provider
126144
mode : Mode
127145
The LLM provider to format the tools for
128146
Returns
@@ -132,6 +150,7 @@ def format_tools(
132150
"""
133151
# first convert to the openai dict format if they are a callable
134152
tools = [function_to_json(tool) if callable(tool) else tool for tool in tools]
153+
tool_choice = _check_tool_choice(tools, tool_choice)
135154
tools_formatted = []
136155
if mode in {Mode.OPENAI_TOOLS, Mode.MISTRAL_TOOLS}:
137156
for tool in tools:
@@ -150,4 +169,4 @@ def format_tools(
150169
else:
151170
raise ValueError(f"Unknown mode {mode}")
152171

153-
return tools_formatted
172+
return tools_formatted, tool_choice

chat_rag/chat_rag/llms/gemini_client.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,14 @@ def _format_tools(self, tools: List[Union[Callable, Dict]], tool_choice: str = N
4949
"""
5050
Format the tools from a generic BaseModel to the Gemini format.
5151
"""
52-
tools_formatted = format_tools(tools, mode=Mode.GEMINI_TOOLS)
53-
tool_choice = self._check_tool_choice(tools_formatted, tool_choice)
52+
tools_formatted, tool_choice = format_tools(tools, tool_choice, mode=Mode.GEMINI_TOOLS)
5453

5554
tools_formatted = [
5655
Tool(function_declarations=[tool]) for tool in tools_formatted
5756
]
5857

5958
# If the tool_choice is a named tool, then apply correct formatting
60-
if tool_choice in [tool["title"] for tool in tools]:
59+
if tool_choice in [tool.function_declarations[0].name for tool in tools_formatted]:
6160
tool_choice = ToolConfig(
6261
function_calling_config=FunctionCallingConfig(
6362
mode="ANY", allowed_function_names=[tool_choice]
@@ -66,7 +65,7 @@ def _format_tools(self, tools: List[Union[Callable, Dict]], tool_choice: str = N
6665
elif tool_choice == "required":
6766
tool_choice = ToolConfig(
6867
function_calling_config=FunctionCallingConfig(
69-
mode="ANY", allowed_function_names=[tool["title"] for tool in tools]
68+
mode="ANY", allowed_function_names=[tool.function_declarations[0].name for tool in tools_formatted]
7069
)
7170
)
7271
elif tool_choice == "auto":
@@ -232,18 +231,20 @@ def _prepare_messages(
232231
"""
233232
def format_content(message: Message):
234233
parts = []
235-
tool_calls = []
236-
tool_results = []
237234
if isinstance(message.content, str):
238235
parts = [Part(text=message.content)]
239236
else:
240237
for content in message.content:
241238
if content.type == "text":
242239
parts.append(Part(text=content.text))
243240
elif content.type == "tool_use":
244-
tool_calls.append(Part(function_call=FunctionCall(id=content.tool_use.id, name=content.tool_use.name, args=content.tool_use.args)))
241+
parts.append(Part(function_call=FunctionCall(id=content.tool_use.id, name=content.tool_use.name, args=content.tool_use.args)))
245242
elif content.type == "tool_result":
246-
tool_results.append(Part(function_response=FunctionResponse(id=content.tool_result.id, name=content.tool_result.name, response=content.tool_result.result)))
243+
result = content.tool_result.result
244+
if isinstance(result, str):
245+
result = {"content": result}
246+
parts.append(Part(function_response=FunctionResponse(id=content.tool_result.id, name=content.tool_result.name, response=result)))
247+
247248
return parts
248249

249250
messages = [Message(**m) if isinstance(m, Dict) else m for m in messages]

chat_rag/chat_rag/llms/openai_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ def _format_tools(
2727
"""
2828
Format the tools from a openai dict or a callable function to the OpenAI format.
2929
"""
30-
tools_formatted = format_tools(tools, mode=Mode.OPENAI_TOOLS)
31-
tool_choice = self._check_tool_choice(tools_formatted, tool_choice)
30+
tools_formatted, tool_choice = format_tools(tools, tool_choice, mode=Mode.OPENAI_TOOLS)
3231

3332
# If the tool_choice is a named tool, then apply correct formatting
3433
if tool_choice in [tool["function"]["name"] for tool in tools_formatted]:

chat_rag/poetry.lock

Lines changed: 4 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

chat_rag/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ ragatouille = {version = "^0.0.8.post2", optional = true}
3535
hdbscan = {version = "^0.8.36", optional = true}
3636
umap-learn = {version = "^0.5.6", optional = true}
3737
bm25s = {version = "^0.1.7", optional = true}
38-
google-genai = "0.5.0"
38+
google-genai = "1.2.0"
3939

4040
[tool.poetry.extras]
4141
full = [

sdk/examples/agent_example/fsm_definition.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from chatfaq_sdk.utils import convert_mml_to_llm_format
66

77

8+
MODEL_NAME = "gemini-2.0-flash"
9+
810
def get_weather(location: str) -> str:
911
"""
1012
Get current temperature for a given location.
@@ -24,7 +26,7 @@ async def send_answer(sdk: ChatFAQSDK, ctx: dict):
2426

2527
response = await llm_request(
2628
sdk,
27-
"gpt-4o",
29+
MODEL_NAME,
2830
use_conversation_context=False,
2931
conversation_id=ctx["conversation_id"],
3032
bot_channel_name=ctx["bot_channel_name"],
@@ -66,7 +68,7 @@ async def send_answer(sdk: ChatFAQSDK, ctx: dict):
6668

6769
response = await llm_request(
6870
sdk,
69-
"gpt-4o",
71+
MODEL_NAME,
7072
messages=messages,
7173
use_conversation_context=False,
7274
conversation_id=ctx["conversation_id"],

0 commit comments

Comments
 (0)