Skip to content

Commit 64cfb5f

Browse files
committed
Update websocket
2 parents 878959b + 8c823b0 commit 64cfb5f

File tree

4 files changed

+379
-339
lines changed

4 files changed

+379
-339
lines changed

back/back/apps/language_model/consumers/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def merge_contents(existing: List[Content], new: List[Content]) -> List[Content]
141141
content=aggregated_contents,
142142
usage=None,
143143
stop_reason="end_turn"
144-
)
144+
).model_dump()
145145
)
146146
# Start a new group for the new role.
147147
current_role = role
@@ -155,7 +155,7 @@ def merge_contents(existing: List[Content], new: List[Content]) -> List[Content]
155155
content=aggregated_contents,
156156
usage=None,
157157
stop_reason="end_turn"
158-
)
158+
).model_dump()
159159
)
160160

161161
return aggregated_messages
@@ -283,9 +283,9 @@ async def query_llm(
283283
if messages: # In case the fsm sends messages
284284
if messages[0]["role"] == AgentType.system.value:
285285
if prev_messages[0].role == AgentType.system.value:
286-
new_messages[0] = Message(**messages[0]) # replace the original system message with the new one from the fsm
286+
new_messages[0] = messages[0].copy() # replace the original system message with the new one from the fsm
287287
else:
288-
new_messages.insert(0, Message(**messages[0])) # or add the fsm system message
288+
new_messages.insert(0, messages[0].copy()) # or add the fsm system message
289289

290290
# pop the system message
291291
messages = messages[1:]

sdk/chatfaq_sdk/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import httpx
1414
import sentry_sdk
1515
import websockets
16+
from websockets.protocol import State
1617

1718
from chatfaq_sdk import settings
1819
from chatfaq_sdk.conditions import Condition
@@ -183,7 +184,7 @@ async def producer(self, actions, consumer_route):
183184
ws_attrs = [attr for attr in dir(self) if attr.startswith("ws_")]
184185
while True:
185186
if any(
186-
getattr(self, ws_attr) is None or not getattr(self, ws_attr).open
187+
getattr(self, ws_attr) is None or getattr(self, ws_attr).state != State.OPEN
187188
for ws_attr in ws_attrs
188189
):
189190
await asyncio.sleep(0.01)
@@ -245,7 +246,7 @@ async def _disconnect(self):
245246
logger.info("Shutting Down...")
246247
wss = [getattr(self, attr) for attr in dir(self) if attr.startswith("ws_")]
247248
for ws in wss:
248-
if ws is not None and ws.open:
249+
if ws is not None and ws.state == State.OPEN:
249250
await ws.close()
250251

251252
async def rpc_request_callback(self, payload):

0 commit comments

Comments
 (0)