1111
1212
1313class ClaudeChatModel (LLM ):
14- def __init__ (self , llm_name : str = "claude-3-opus-20240229 " , ** kwargs ) -> None :
14+ def __init__ (self , llm_name : str = "claude-3-7-sonnet-latest " , ** kwargs ) -> None :
1515 self .llm_name = llm_name
1616 self .client = Anthropic (
1717 api_key = os .environ .get ("ANTHROPIC_API_KEY" ),
@@ -85,7 +85,7 @@ def format_content(message: Union[Dict, Message]):
8585 return content_list
8686
8787 messages_formatted = [
88- {"role" : message ["role" ], "content" : format_content (message )}
88+ {"role" : message ["role" ] if isinstance ( message , Dict ) else message . role , "content" : format_content (message )}
8989 for message in messages
9090 ]
9191
@@ -102,6 +102,7 @@ def _map_anthropic_message(self, message) -> Message:
102102 """
103103 content_list = []
104104 for part in message .content :
105+ # TODO: Handle thinking output block types
105106 if part .type == "text" :
106107 content_list .append (
107108 Content (
@@ -135,6 +136,7 @@ def stream(
135136 temperature : float = 0.2 ,
136137 max_tokens : int = 1024 ,
137138 seed : int = None ,
139+ thinking : dict = None ,
138140 ** kwargs ,
139141 ):
140142 """
@@ -157,18 +159,24 @@ def stream(
157159 temperature = temperature ,
158160 max_tokens = max_tokens ,
159161 stream = True ,
162+ thinking = thinking if thinking else NOT_GIVEN ,
160163 )
161164
162165 for event in stream :
163166 if event .type == "content_block_delta" :
164- yield event .delta .text
167+ if event .delta .type == "thinking_delta" :
168+ pass # Pass for now until I figure out a common interface for thinking
169+ # yield event.delta.thinking
170+ elif event .delta .type == "text_delta" :
171+ yield event .delta .text
165172
166173 async def astream (
167174 self ,
168175 messages : List [Union [Dict , Message ]],
169176 temperature : float = 0.2 ,
170177 max_tokens : int = 1024 ,
171178 seed : int = None ,
179+ thinking : dict = None ,
172180 ** kwargs ,
173181 ):
174182 """
@@ -191,18 +199,24 @@ async def astream(
191199 temperature = temperature ,
192200 max_tokens = max_tokens ,
193201 stream = True ,
202+ thinking = thinking if thinking else NOT_GIVEN ,
194203 )
195204
196205 async for event in stream :
197206 if event .type == "content_block_delta" :
198- yield event .delta .text
207+ if event .delta .type == "thinking_delta" :
208+ pass # Pass for now until I figure out a common interface for thinking
209+ # yield event.delta.thinking
210+ elif event .delta .type == "text_delta" :
211+ yield event .delta .text
199212
200213 def generate (
201214 self ,
202215 messages : List [Union [Dict , Message ]],
203216 temperature : float = 0.2 ,
204217 max_tokens : int = 1024 ,
205218 seed : int = None ,
219+ thinking : dict = None ,
206220 tools : List [Union [Callable , Dict ]] = None ,
207221 tool_choice : str = None ,
208222 ** kwargs ,
@@ -232,6 +246,7 @@ def generate(
232246 temperature = temperature ,
233247 max_tokens = max_tokens ,
234248 ** tool_kwargs ,
249+ thinking = thinking if thinking else NOT_GIVEN ,
235250 )
236251
237252 return self ._map_anthropic_message (message )
@@ -242,6 +257,7 @@ async def agenerate(
242257 temperature : float = 0.2 ,
243258 max_tokens : int = 1024 ,
244259 seed : int = None ,
260+ thinking : dict = None ,
245261 tools : List [Union [Callable , Dict ]] = None ,
246262 tool_choice : str = None ,
247263 ** kwargs ,
@@ -271,6 +287,7 @@ async def agenerate(
271287 temperature = temperature ,
272288 max_tokens = max_tokens ,
273289 ** tool_kwargs ,
290+ thinking = thinking if thinking else NOT_GIVEN ,
274291 )
275292
276293 return self ._map_anthropic_message (message )
0 commit comments