Skip to content

Commit ce80ae2

Browse files
committed
Merge branch 'develop' into feature/chatf-1409-support-tool-use-in-streaming
2 parents 9db2a46 + 7232e0b commit ce80ae2

File tree

26 files changed

+805
-709
lines changed

26 files changed

+805
-709
lines changed

back/back/apps/broker/consumers/bots/custom_ws.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ async def gather_initial_conversation_metadata(self):
3333
params = parse_qs(self.scope["query_string"])
3434
_meta = params.get(b"metadata", [b"{}"])
3535
return json.loads(_meta[0])
36+
37+
async def gather_fsm_state_overwrite(self):
38+
params = parse_qs(self.scope["query_string"])
39+
state = params.get(b"state_overwrite", "")
40+
return state[0].decode('utf-8') if state else None
3641

3742
@classmethod
3843
def platform_url_paths(cls) -> str:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Generated by Django 4.2.17 on 2025-02-18 15:08
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
dependencies = [
9+
('broker', '0046_remove_userfeedback_feedback_comment_and_more'),
10+
]
11+
12+
operations = [
13+
migrations.AddField(
14+
model_name='conversation',
15+
name='fsm_state_override',
16+
field=models.TextField(blank=True, null=True),
17+
),
18+
]

back/back/apps/broker/models/message.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class Conversation(ChangesMixin):
4242
name = models.CharField(max_length=255, null=True, blank=True)
4343
initial_conversation_metadata = models.JSONField(default=dict)
4444
authentication_required = models.BooleanField(default=False)
45+
fsm_state_override = models.TextField(null=True, blank=True)
4546

4647
def get_first_msg(self):
4748
return Message.objects.filter(

back/back/apps/fsm/lib/__init__.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -125,22 +125,28 @@ async def next_state(self):
125125
is reached it makes sure everything is saved and cached into the DB to keep the system stateful
126126
"""
127127
transitions = self.get_current_state_transitions()
128-
best_score = 0
129-
best_transition = None
130-
transition_data = {}
131-
for t in transitions:
132-
score, _data = await self.check_transition_condition(t)
133-
if score > best_score:
134-
best_transition = t
135-
best_score = score
136-
transition_data = _data
137-
if best_transition:
128+
state_overwrite = self.get_state_by_name(self.ctx.conversation.fsm_state_overwrite)
129+
if state_overwrite:
130+
logger.debug("Overriding FSM state")
138131
logger.debug(f"FSM from ---> {self.current_state}")
139-
self.current_state = self.get_state_by_name(best_transition.dest)
132+
self.current_state = state_overwrite
140133
logger.debug(f"FSM to -----> {self.current_state}")
141-
await self.run_current_state_events(transition_data)
142-
143-
await self.save_cache()
134+
await self.run_current_state_events()
135+
else:
136+
best_score = 0
137+
best_transition = None
138+
transition_data = {}
139+
for t in transitions:
140+
score, _data = await self.check_transition_condition(t)
141+
if score > best_score:
142+
best_transition = t
143+
best_score = score
144+
transition_data = _data
145+
if best_transition:
146+
logger.debug(f"FSM from ---> {self.current_state}")
147+
self.current_state = self.get_state_by_name(best_transition.dest)
148+
logger.debug(f"FSM to -----> {self.current_state}")
149+
await self.run_current_state_events(transition_data)
144150

145151
async def run_current_state_events(self, transition_data=None):
146152
"""

back/back/apps/language_model/consumers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ async def query_llm(
265265
# Decrypt the API key from the LLMConfig if available.
266266
api_key = None
267267
if llm_config.api_key:
268-
from back.utils import get_light_bringer
268+
from back.utils.encrypt import get_light_bringer
269269
lb = get_light_bringer()
270270
api_key = llm_config.api_key.decrypt(lb)
271271

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Generated by Django 4.2.17 on 2025-02-18 15:08
2+
3+
import back.utils.encrypt._django
4+
from django.db import migrations
5+
6+
7+
class Migration(migrations.Migration):
8+
9+
dependencies = [
10+
('language_model', '0068_historicalllmconfig_api_key_llmconfig_api_key'),
11+
]
12+
13+
operations = [
14+
migrations.AlterField(
15+
model_name='historicalllmconfig',
16+
name='api_key',
17+
field=back.utils.encrypt._django.NissaStringField(blank=True, editable=True, help_text='Optional API key for the LLM. This value will be stored encrypted. Note: API keys can only be saved when encryption is properly configured.', null=True),
18+
),
19+
migrations.AlterField(
20+
model_name='llmconfig',
21+
name='api_key',
22+
field=back.utils.encrypt._django.NissaStringField(blank=True, editable=True, help_text='Optional API key for the LLM. This value will be stored encrypted. Note: API keys can only be saved when encryption is properly configured.', null=True),
23+
),
24+
]

back/back/apps/language_model/models/rag_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
)
2222
from back.apps.language_model.tasks import index_task
2323
from back.common.models import ChangesMixin
24-
from back.utils import NissaStringField
24+
from back.utils.encrypt import NissaStringField
2525

2626
logger = getLogger(__name__)
2727

back/back/common/abs/bot_consumers/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ def gather_initial_conversation_metadata(self, mml: "Message"):
127127
raise NotImplemented(
128128
"Implement a method that creates/gathers the conversation covnersation metadata"
129129
)
130+
131+
def gather_fsm_state_overwrite(self, mml: "Message"):
132+
raise NotImplemented(
133+
"Implement a method that gathers the fsm state overwrite"
134+
)
130135

131136
def gather_conversation_id(self, mml: "Message"):
132137
raise NotImplemented(
@@ -147,13 +152,14 @@ async def authenticate(self):
147152
and not isinstance(self.scope["user"], AnonymousUser)
148153
)
149154

150-
async def set_conversation(self, platform_conversation_id, initial_conversation_metadata, authentication_required):
155+
async def set_conversation(self, platform_conversation_id, initial_conversation_metadata, authentication_required, fsm_state_overwrite=None):
151156
self.conversation, created = await Conversation.objects.aget_or_create(
152157
platform_conversation_id=platform_conversation_id
153158
)
154159
if created:
155160
self.conversation.initial_conversation_metadata = initial_conversation_metadata
156161
self.conversation.authentication_required = authentication_required
162+
self.conversation.fsm_state_overwrite = fsm_state_overwrite
157163
await database_sync_to_async(self.conversation.save)()
158164

159165
def set_fsm_def(self, fsm_def):

back/back/common/abs/bot_consumers/http.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ async def handle(self, body):
6969
await self.set_conversation(
7070
self.gather_conversation_id(serializer.validated_data),
7171
await self.gather_initial_conversation_metadata(serializer.validated_data),
72-
self.fsm_def.authentication_required
72+
self.fsm_def.authentication_required,
73+
await self.gather_fsm_state_overwrite(serializer.validated_data)
7374
)
7475

7576
mml = await database_sync_to_async(serializer.to_mml)(self)

back/back/common/abs/bot_consumers/ws.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ async def connect(self):
2828
self.set_fsm_def(fsm_def)
2929
self.set_user_id(await self.gather_user_id())
3030

31-
await self.set_conversation(self.gather_conversation_id(), await self.gather_initial_conversation_metadata(), self.fsm_def.authentication_required)
31+
await self.set_conversation(self.gather_conversation_id(), await self.gather_initial_conversation_metadata(), self.fsm_def.authentication_required, await self.gather_fsm_state_overwrite())
3232
if not await self.authenticate():
3333
await self.close(3000, reason="`Authentication failed`")
3434

0 commit comments

Comments
 (0)