Skip to content

Commit 69c65be

Browse files
committed
Add thinking
1 parent 10dfc0a commit 69c65be

File tree

9 files changed

+79
-24
lines changed

9 files changed

+79
-24
lines changed

back/back/apps/broker/serializers/rpc.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,16 @@ class CacheConfigSerializer(serializers.Serializer):
9191
name = serializers.CharField(required=False, allow_null=True)
9292

9393

94+
class ThinkingField(serializers.Field):
95+
"""Custom field that accepts both a string or a dictionary"""
96+
def to_internal_value(self, data):
97+
# Return as is - can be either string or dict
98+
return data
99+
100+
def to_representation(self, value):
101+
return value
102+
103+
94104
class RPCLLMRequestSerializer(serializers.Serializer):
95105
"""
96106
Represents the LLM requests coming from the RPC server
@@ -112,6 +122,8 @@ class RPCLLMRequestSerializer(serializers.Serializer):
112122
The seed to use in the LLM
113123
stream: bool
114124
Whether the LLM response should be streamed or not
125+
thinking: str or Dict
126+
The thinking to use in the LLM
115127
"""
116128

117129
llm_config_name = serializers.CharField(required=True, allow_blank=False, allow_null=False)
@@ -121,6 +133,7 @@ class RPCLLMRequestSerializer(serializers.Serializer):
121133
temperature = serializers.FloatField(default=0.7, required=False)
122134
max_tokens = serializers.IntegerField(default=1024, required=False)
123135
seed = serializers.IntegerField(default=42, required=False)
136+
thinking = ThinkingField(default=None, required=False, allow_null=True)
124137
tools = serializers.ListField(
125138
child=serializers.DictField(), allow_empty=True, required=False, allow_null=True
126139
)

back/back/apps/language_model/consumers/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import uuid
33
from logging import getLogger
4-
from typing import Awaitable, Callable, Dict, List, Optional
4+
from typing import Awaitable, Callable, Dict, List, Optional, Union
55

66
from channels.db import database_sync_to_async
77
from channels.generic.websocket import AsyncJsonWebsocketConsumer
@@ -205,6 +205,7 @@ async def query_llm(
205205
temperature: float = 0.7,
206206
max_tokens: int = 1024,
207207
seed: int = 42,
208+
thinking: Union[str, Dict] = None,
208209
tools: List[Dict] = None,
209210
tool_choice: str = None,
210211
use_conversation_context: bool = True,
@@ -303,6 +304,7 @@ async def query_llm(
303304
temperature=temperature,
304305
max_tokens=max_tokens,
305306
seed=seed,
307+
thinking=thinking,
306308
cache_config=cache_config,
307309
)
308310
async for res in response:
@@ -321,6 +323,7 @@ async def query_llm(
321323
temperature=temperature,
322324
max_tokens=max_tokens,
323325
seed=seed,
326+
thinking=thinking,
324327
tools=tools,
325328
tool_choice=tool_choice,
326329
cache_config=cache_config,
@@ -432,6 +435,7 @@ async def process_llm_request(self, data):
432435
data.get("temperature"),
433436
data.get("max_tokens"),
434437
data.get("seed"),
438+
data.get("thinking"),
435439
data.get("tools"),
436440
data.get("tool_choice"),
437441
data.get("use_conversation_context"),

chat_rag/chat_rag/llms/claude_client.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
class ClaudeChatModel(LLM):
14-
def __init__(self, llm_name: str = "claude-3-opus-20240229", **kwargs) -> None:
14+
def __init__(self, llm_name: str = "claude-3-7-sonnet-latest", **kwargs) -> None:
1515
self.llm_name = llm_name
1616
self.client = Anthropic(
1717
api_key=os.environ.get("ANTHROPIC_API_KEY"),
@@ -85,7 +85,7 @@ def format_content(message: Union[Dict, Message]):
8585
return content_list
8686

8787
messages_formatted = [
88-
{"role": message["role"], "content": format_content(message)}
88+
{"role": message["role"] if isinstance(message, Dict) else message.role, "content": format_content(message)}
8989
for message in messages
9090
]
9191

@@ -102,6 +102,7 @@ def _map_anthropic_message(self, message) -> Message:
102102
"""
103103
content_list = []
104104
for part in message.content:
105+
# TODO: Handle thinking output block types
105106
if part.type == "text":
106107
content_list.append(
107108
Content(
@@ -135,6 +136,7 @@ def stream(
135136
temperature: float = 0.2,
136137
max_tokens: int = 1024,
137138
seed: int = None,
139+
thinking: dict = None,
138140
**kwargs,
139141
):
140142
"""
@@ -157,18 +159,24 @@ def stream(
157159
temperature=temperature,
158160
max_tokens=max_tokens,
159161
stream=True,
162+
thinking=thinking if thinking else NOT_GIVEN,
160163
)
161164

162165
for event in stream:
163166
if event.type == "content_block_delta":
164-
yield event.delta.text
167+
if event.delta.type == "thinking_delta":
168+
pass # Pass for now until I figure out a common interface for thinking
169+
# yield event.delta.thinking
170+
elif event.delta.type == "text_delta":
171+
yield event.delta.text
165172

166173
async def astream(
167174
self,
168175
messages: List[Union[Dict, Message]],
169176
temperature: float = 0.2,
170177
max_tokens: int = 1024,
171178
seed: int = None,
179+
thinking: dict = None,
172180
**kwargs,
173181
):
174182
"""
@@ -191,18 +199,24 @@ async def astream(
191199
temperature=temperature,
192200
max_tokens=max_tokens,
193201
stream=True,
202+
thinking=thinking if thinking else NOT_GIVEN,
194203
)
195204

196205
async for event in stream:
197206
if event.type == "content_block_delta":
198-
yield event.delta.text
207+
if event.delta.type == "thinking_delta":
208+
pass # Pass for now until I figure out a common interface for thinking
209+
# yield event.delta.thinking
210+
elif event.delta.type == "text_delta":
211+
yield event.delta.text
199212

200213
def generate(
201214
self,
202215
messages: List[Union[Dict, Message]],
203216
temperature: float = 0.2,
204217
max_tokens: int = 1024,
205218
seed: int = None,
219+
thinking: dict = None,
206220
tools: List[Union[Callable, Dict]] = None,
207221
tool_choice: str = None,
208222
**kwargs,
@@ -232,6 +246,7 @@ def generate(
232246
temperature=temperature,
233247
max_tokens=max_tokens,
234248
**tool_kwargs,
249+
thinking=thinking if thinking else NOT_GIVEN,
235250
)
236251

237252
return self._map_anthropic_message(message)
@@ -242,6 +257,7 @@ async def agenerate(
242257
temperature: float = 0.2,
243258
max_tokens: int = 1024,
244259
seed: int = None,
260+
thinking: dict = None,
245261
tools: List[Union[Callable, Dict]] = None,
246262
tool_choice: str = None,
247263
**kwargs,
@@ -271,6 +287,7 @@ async def agenerate(
271287
temperature=temperature,
272288
max_tokens=max_tokens,
273289
**tool_kwargs,
290+
thinking=thinking if thinking else NOT_GIVEN,
274291
)
275292

276293
return self._map_anthropic_message(message)

chat_rag/chat_rag/llms/openai_client.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from typing import Callable, Dict, List, Union
33

4-
from openai import AsyncOpenAI, OpenAI
4+
from openai import AsyncOpenAI, OpenAI, NOT_GIVEN
55
from openai.lib._pydantic import _ensure_strict_json_schema
66

77
from chat_rag.llms.types import Content, Message, ToolUse, Usage
@@ -142,6 +142,7 @@ def stream(
142142
temperature: float = 1.0,
143143
max_tokens: int = 1024,
144144
seed: int = None,
145+
thinking: str = NOT_GIVEN,
145146
**kwargs,
146147
):
147148
"""
@@ -161,10 +162,11 @@ def stream(
161162
model=self.llm_name,
162163
messages=messages,
163164
temperature=temperature,
164-
max_tokens=max_tokens,
165+
max_completion_tokens=max_tokens,
165166
seed=seed,
166167
n=1,
167168
stream=True,
169+
reasoning_effort=thinking,
168170
)
169171
for chunk in response:
170172
if chunk.choices[0].finish_reason == "stop":
@@ -178,6 +180,7 @@ async def astream(
178180
temperature: float = 1.0,
179181
max_tokens: int = 1024,
180182
seed: int = None,
183+
thinking: str = NOT_GIVEN,
181184
**kwargs,
182185
):
183186
"""
@@ -196,10 +199,11 @@ async def astream(
196199
model=self.llm_name,
197200
messages=messages,
198201
temperature=temperature,
199-
max_tokens=max_tokens,
202+
max_completion_tokens=max_tokens,
200203
seed=seed,
201204
n=1,
202205
stream=True,
206+
reasoning_effort=thinking,
203207
)
204208
async for chunk in response:
205209
if chunk.choices[0].finish_reason == "stop":
@@ -213,6 +217,7 @@ def generate(
213217
temperature: float = 1.0,
214218
max_tokens: int = 1024,
215219
seed: int = None,
220+
thinking: str = NOT_GIVEN,
216221
tools: List[Union[Callable, Dict]] = None,
217222
tool_choice: str = None,
218223
**kwargs,
@@ -237,8 +242,9 @@ def generate(
237242
model=self.llm_name,
238243
messages=messages,
239244
temperature=temperature,
240-
max_tokens=max_tokens,
245+
max_completion_tokens=max_tokens,
241246
seed=seed,
247+
reasoning_effort=thinking,
242248
n=1,
243249
tools=tools,
244250
tool_choice=tool_choice,
@@ -253,6 +259,7 @@ async def agenerate(
253259
temperature: float = 1.0,
254260
max_tokens: int = 1024,
255261
seed: int = None,
262+
thinking: str = NOT_GIVEN,
256263
tools: List[Union[Callable, Dict]] = None,
257264
tool_choice: str = None,
258265
**kwargs,
@@ -276,8 +283,9 @@ async def agenerate(
276283
model=self.llm_name,
277284
messages=messages,
278285
temperature=temperature,
279-
max_tokens=max_tokens,
286+
max_completion_tokens=max_tokens,
280287
seed=seed,
288+
reasoning_effort=thinking,
281289
n=1,
282290
tools=tools,
283291
tool_choice=tool_choice,

chat_rag/poetry.lock

Lines changed: 11 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

chat_rag/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ certifi = "^2023.7.22"
1717
urllib3 = "^1.26.18"
1818
aiohttp = "^3.8.5"
1919
cryptography = "^41.0.4"
20-
openai = "^1.33.0"
21-
anthropic = "0.28.0"
20+
openai = "1.66.2"
21+
anthropic = "0.49.0"
2222
mistralai = "0.4.0"
2323
docstring-parser = "^0.16"
2424
torch = {version = "2.3.0", optional = true}

sdk/chatfaq_sdk/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ async def send_llm_request(
328328
temperature,
329329
max_tokens,
330330
seed,
331+
thinking,
331332
tools,
332333
tool_choice,
333334
conversation_id,
@@ -353,6 +354,7 @@ async def send_llm_request(
353354
"temperature": temperature,
354355
"max_tokens": max_tokens,
355356
"seed": seed,
357+
"thinking": thinking,
356358
"tools": tools,
357359
"tool_choice": tool_choice,
358360
"use_conversation_context": use_conversation_context,

0 commit comments

Comments
 (0)