Skip to content

Commit f5c673e

Browse files
committed
Merge branch 'feature/lef-70-canva-as-a-user-i-want-to-extend-the-chatfaq-widget-with-a' into new_file_upload_logic
2 parents 679cb02 + 8dac6ac commit f5c673e

24 files changed

Lines changed: 936 additions & 253 deletions

File tree

back/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ apply_fixtures:
3232
celery_worker:
3333
$(POETRY_RUN) celery -A back.config worker -l INFO -P solo
3434

35+
#
3536
# create_fsm_fixtures:
3637
# ./manage.py dumpdata fsm.fsmdefinition --indent 4 > back/apps/fsm/fixtures/initial.json
3738
#

back/back/apps/fsm/lib/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def next_state(self):
125125
is reached it makes sure everything is saved and cached into the DB to keep the system stateful
126126
"""
127127
transitions = self.get_current_state_transitions()
128-
128+
129129
if state_override := self.get_state_by_name(self.ctx.conversation.fsm_state_override):
130130
logger.debug("Overriding FSM state")
131131
logger.debug(f"FSM from ---> {self.current_state}")
@@ -148,6 +148,8 @@ async def next_state(self):
148148
logger.debug(f"FSM to -----> {self.current_state}")
149149
await self.run_current_state_events(transition_data)
150150

151+
await self.save_cache()
152+
151153
async def run_current_state_events(self, transition_data=None):
152154
"""
153155
It will call the RPC server, the procedure name is the event name declared in the fsm definition for the

back/back/apps/language_model/consumers/__init__.py

Lines changed: 119 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import time
23
import uuid
34
from logging import getLogger
45
from typing import Awaitable, Callable, Dict, List, Optional, Union
@@ -29,6 +30,8 @@
2930
from chat_rag.llms import load_llm
3031
from chat_rag.llms.types import Content, Message, ToolResult, ToolUse
3132

33+
from back.apps.health.models import Event
34+
3235
logger = getLogger(__name__)
3336

3437

@@ -198,6 +201,40 @@ async def resolve_references(reference_kis, retriever_config):
198201
}
199202

200203

204+
async def log_llm_event(
205+
event_type: str,
206+
is_success: bool,
207+
data: dict
208+
):
209+
"""
210+
Async function to log LLM-related events to the Event model.
211+
212+
Parameters:
213+
-----------
214+
event_type : str
215+
Type of event (e.g., 'llm_call_start', 'llm_call_complete')
216+
is_success : bool
217+
Whether the event represents a successful operation
218+
llm_call_id : str
219+
Unique identifier for the LLM call
220+
llm_config_name : str
221+
Name of the LLM configuration
222+
conversation_id : int
223+
ID of the conversation
224+
start_time : float
225+
Start time of the LLM call (used to calculate duration for 'complete' events)
226+
additional_data : dict, optional
227+
Any additional data to include in the event
228+
"""
229+
230+
# Create the event asynchronously
231+
await database_sync_to_async(Event.objects.create)(
232+
event_type=event_type,
233+
is_success=is_success,
234+
data=data
235+
)
236+
237+
201238
async def query_llm(
202239
llm_config_name: str,
203240
conversation_id: int,
@@ -221,18 +258,20 @@ async def query_llm(
221258
# if the llm config is mistral then return an error that mistral is not supported yet
222259
if llm_config.llm_type == LLMChoices.MISTRAL.value:
223260
await error_handler({
224-
"payload": {
225-
"errors": "Error: Mistral is temporarily unavailable. We're working to add support for it soon. For now, please select a different model like OpenAI.",
226-
"request_info": {"llm_config_name": llm_config_name},
227-
}
228-
})
261+
"errors": "Error: Mistral is temporarily unavailable. We're working to add support for it soon. For now, please select a different model provider like OpenAI.",
262+
"llm_config_name": llm_config_name,
263+
"conversation_id": conversation_id
264+
},
265+
event_type="llm_config_not_found"
266+
)
229267
except LLMConfig.DoesNotExist:
230268
await error_handler({
231-
"payload": {
232269
"errors": f"LLM config with name: {llm_config_name} does not exist.",
233-
"request_info": {"llm_config_name": llm_config_name},
234-
}
235-
})
270+
"llm_config_name": llm_config_name,
271+
"conversation_id": conversation_id
272+
},
273+
event_type="llm_config_not_found"
274+
)
236275
return
237276

238277
conv = await database_sync_to_async(Conversation.objects.get)(pk=conversation_id)
@@ -252,25 +291,28 @@ async def query_llm(
252291
messages = messages[1:]
253292
elif not prev_messages:
254293
await error_handler({
255-
"payload": {
256-
"errors": "Error: No previous messages and no messages provided.",
257-
"request_info": {"conversation_id": conversation_id},
258-
}
259-
})
294+
"errors": "Error: No previous messages and no messages provided.",
295+
"conversation_id": conversation_id,
296+
},
297+
)
260298
return
261299
if messages:
262300
new_messages.extend(messages)
263301
else:
264302
new_messages = messages
265303
if new_messages is None:
266304
await error_handler({
267-
"payload": {
268-
"errors": "Error: No messages provided.",
269-
"request_info": {"conversation_id": conversation_id},
270-
}
271-
})
305+
"errors": "Error: No messages provided.",
306+
"conversation_id": conversation_id,
307+
},
308+
)
272309
return
273310

311+
312+
# Generate a unique ID for this LLM call
313+
llm_call_id = str(uuid.uuid4())
314+
start_time = time.perf_counter()
315+
274316
try:
275317
# Decrypt the API key from the LLMConfig if available.
276318
api_key = None
@@ -288,6 +330,25 @@ async def query_llm(
288330
api_key=api_key,
289331
)
290332

333+
await log_llm_event(
334+
event_type="llm_call_start",
335+
is_success=True,
336+
data={
337+
"llm_call_id": llm_call_id,
338+
"llm_config_name": llm_config_name,
339+
"conversation_id": conversation_id,
340+
"temperature": temperature,
341+
"max_tokens": max_tokens,
342+
"seed": seed,
343+
"tools": tools,
344+
"tool_choice": tool_choice,
345+
"messages": new_messages,
346+
"schema": response_schema,
347+
"cache_config": cache_config,
348+
"stream": stream,
349+
}
350+
)
351+
291352
if response_schema:
292353
response_message = await llm.aparse(
293354
messages=new_messages,
@@ -356,14 +417,28 @@ async def query_llm(
356417
"last_chunk": True,
357418
}
358419

420+
await log_llm_event(
421+
event_type="llm_call_complete",
422+
is_success=True,
423+
data={
424+
"llm_call_id": llm_call_id,
425+
"duration_seconds": time.perf_counter() - start_time,
426+
}
427+
)
428+
359429
except Exception as e:
360-
logger.error("Error during LLM query", exc_info=e)
361-
await error_handler({
362-
"payload": {
430+
logger.exception(f"Error during llm call: {e}")
431+
await error_handler(
432+
{
363433
"errors": "There was an error generating the response. Please try again or contact the administrator.",
364-
"request_info": {"conversation_id": conversation_id},
365-
}
366-
})
434+
"error_message": str(e),
435+
"llm_config_name": llm_config_name,
436+
"conversation_id": conversation_id,
437+
"llm_call_id": llm_call_id,
438+
"duration_seconds": time.perf_counter() - start_time,
439+
},
440+
event_type="llm_call_complete",
441+
)
367442
return
368443

369444

@@ -449,6 +524,7 @@ async def process_llm_request(self, data):
449524

450525
lm_msg_id = str(uuid.uuid4())
451526
data = serializer.validated_data
527+
452528
async for chunk in query_llm(
453529
data["llm_config_name"],
454530
data["conversation_id"],
@@ -463,7 +539,7 @@ async def process_llm_request(self, data):
463539
data.get("cache_config"),
464540
data.get("response_schema"),
465541
data.get("stream"),
466-
error_handler=self.error_response,
542+
error_handler=self.llm_error_response,
467543
):
468544
await self.send(
469545
json.dumps(
@@ -555,9 +631,24 @@ async def process_prompt_request(self, data):
555631
}
556632
)
557633

558-
559-
560634
async def error_response(self, data: dict):
561635
data["status"] = WSStatusCodes.bad_request.value
562636
data["type"] = RPCMessageType.error.value
563637
await self.send(json.dumps(data))
638+
639+
async def llm_error_response(self, data: dict, event_type: str = None):
640+
if event_type:
641+
await log_llm_event(
642+
event_type=event_type,
643+
is_success=False,
644+
data=data
645+
)
646+
# This is info sent to the SDK, so don't send a detailed error message for now.
647+
return await self.error_response(
648+
{
649+
"payload": {
650+
"errors": data["errors"],
651+
"request_info": {"conversation_id": data["conversation_id"], "llm_config_name": data["llm_config_name"]},
652+
}
653+
}
654+
)

back/back/apps/language_model/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
urlpatterns += [
2323
path("tasks/", back.apps.language_model.views.tasks.ListTasksAPI.as_view()),
24+
path("query-llm/", back.apps.language_model.views.tasks.QueryLLM.as_view()),
2425
path("ray-status/", back.apps.language_model.views.tasks.RayStatusAPI.as_view()),
2526
path("retrieve/", back.apps.language_model.views.rag_pipeline.RetrieveAPI.as_view()),
2627
]

back/back/apps/language_model/views/tasks.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1+
import asyncio
2+
import json
13
from django.http import JsonResponse
2-
from rest_framework import status
4+
from django.views.decorators.csrf import csrf_exempt
5+
from django.utils.decorators import method_decorator
6+
from rest_framework import status, permissions
37
from rest_framework.response import Response
48
from back.apps.language_model.models import RayTaskState
59
from back.apps.language_model.serializers.tasks import RayTaskStateSerializer
610
from rest_framework.views import APIView
711
from drf_spectacular.utils import extend_schema
812
from django.conf import settings
9-
10-
13+
from django.views import View
14+
from back.apps.language_model.consumers import query_llm
15+
from back.apps.broker.serializers.rpc import RPCLLMRequestSerializer
16+
from back.config.middelware import KnoxAsyncAuthMixin
1117

1218
def get_paginated_response(data, limit, offset, count):
1319
return JsonResponse({
@@ -61,7 +67,30 @@ def get(self, request):
6167
page = [task for task in page if task.get('task_id') == task_id]
6268

6369
return get_paginated_response(page, limit, offset, len(data))
64-
70+
71+
@method_decorator(csrf_exempt, name='dispatch')
72+
class QueryLLM(KnoxAsyncAuthMixin, View):
73+
async def post(self, request):
74+
body_data = json.loads(request.body.decode('utf-8'))
75+
serializer = RPCLLMRequestSerializer(data=body_data)
76+
77+
if not serializer.is_valid():
78+
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
79+
80+
data = serializer.validated_data
81+
llm_config_name = data.pop("llm_config_name")
82+
conversation_id = data.pop("conversation_id")
83+
data.pop("bot_channel_name", None)
84+
85+
res = None
86+
async for _res in query_llm(
87+
llm_config_name,
88+
conversation_id,
89+
**data
90+
):
91+
res = _res
92+
93+
return JsonResponse(res, safe=False, status=status.HTTP_200_OK)
6594

6695
class RayStatusAPI(APIView):
6796
@extend_schema(

back/back/config/middelware.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
from back.apps.widget.models import Widget
1313
from urllib.parse import urlparse
1414

15+
from functools import wraps
16+
from asgiref.sync import sync_to_async
17+
from knox.auth import TokenAuthentication
18+
from django.http import JsonResponse
19+
1520
logger = getLogger(__name__)
1621

1722

@@ -90,3 +95,24 @@ def has_permission(self, request, view):
9095
if fnmatch.fnmatch(urlparse(origin).netloc, widget.domain):
9196
return True
9297
return False
98+
99+
class KnoxAsyncAuthMixin:
100+
"""Mixin to add Knox authentication to async class-based views."""
101+
102+
async def dispatch(self, request, *args, **kwargs):
103+
knox_auth = TokenAuthentication()
104+
105+
async def authenticate_async():
106+
try:
107+
return await sync_to_async(knox_auth.authenticate)(request)
108+
except Exception as e:
109+
return None
110+
111+
auth_result = await authenticate_async()
112+
113+
if auth_result:
114+
request.user = auth_result[0]
115+
request.auth = auth_result[1]
116+
return await super().dispatch(request, *args, **kwargs)
117+
else:
118+
return JsonResponse({"error": "Authentication failed"}, status=401)

back/back/config/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def pre_logging(self, env: EnvManager):
125125
return super().pre_logging(env)
126126

127127

128-
model_w_django = CustomPreset(enable_storages=not LOCAL_STORAGE, enable_celery=False)
128+
model_w_django = CustomPreset(enable_storages=not LOCAL_STORAGE, enable_celery=False, enable_wagtail=False)
129129

130130
with EnvManager(model_w_django) as env:
131131
# ---

0 commit comments

Comments
 (0)