Skip to content

Commit 42c402a

Browse files
committed
Created and endpoint for rag retrieve and updating the SDK with http requests for retrieve and feedback
1 parent b2f394d commit 42c402a

File tree

6 files changed

+77
-7
lines changed

6 files changed

+77
-7
lines changed

back/back/apps/broker/serializers/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,9 @@ class StatsSerializer(serializers.Serializer):
118118
granularity = serializers.ChoiceField(
119119
required=False, choices=["year", "quarter", "month", "week", "day", "date", "time", "hour", "minute", "second"], default="day"
120120
)
121+
122+
123+
class RetrieverRequestSerializer(serializers.Serializer):
124+
retriever_config_name = serializers.CharField(required=True, allow_blank=False, allow_null=False)
125+
query = serializers.CharField(required=True, allow_blank=False, allow_null=False)
126+
top_k = serializers.IntegerField(default=3)

back/back/apps/broker/serializers/rpc.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
RPCNodeType,
1010
)
1111
from back.apps.broker.models.message import AgentType
12+
from back.apps.broker.serializers import RetrieverRequestSerializer
1213
from back.config.storage_backends import (
1314
PrivateMediaLocalStorage,
1415
select_private_storage,
@@ -168,12 +169,8 @@ class RPCPromptRequestSerializer(serializers.Serializer):
168169
bot_channel_name = serializers.CharField()
169170

170171

171-
class RPCRetrieverRequestSerializer(serializers.Serializer):
172-
172+
class RPCRetrieverRequestSerializer(RetrieverRequestSerializer):
173173
retriever_config_name = serializers.CharField(required=True, allow_blank=False, allow_null=False)
174-
bot_channel_name = serializers.CharField()
175-
query = serializers.CharField(required=True, allow_blank=False, allow_null=False)
176-
top_k = serializers.IntegerField(default=3)
177174

178175

179176
class RegisterParsersSerializer(serializers.Serializer):

back/back/apps/language_model/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222
urlpatterns += [
2323
path("tasks/", back.apps.language_model.views.tasks.ListTasksAPI.as_view()),
2424
path("ray-status/", back.apps.language_model.views.tasks.RayStatusAPI.as_view()),
25+
path("retrieve/", back.apps.language_model.views.rag_pipeline.RetrieveAPI.as_view()),
2526
]

back/back/apps/language_model/views/rag_pipeline.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from rest_framework import viewsets
1+
from rest_framework import views, viewsets
22
from django.http import JsonResponse
33
from rest_framework.decorators import action
44

5+
from back.apps.broker.serializers.rpc import RetrieverRequestSerializer
6+
from back.apps.language_model.consumers import query_retriever
57
from back.apps.language_model.models.rag_pipeline import LLMConfig, GenerationConfig, PromptConfig, RetrieverConfig
68
from back.apps.language_model.serializers.rag_pipeline import LLMConfigSerializer, \
79
GenerationConfigSerializer, PromptConfigSerializer, RetrieverConfigSerializer
@@ -74,3 +76,25 @@ class PromptConfigAPIViewSet(viewsets.ModelViewSet):
7476
filter_backends = [DjangoFilterBackend, SearchFilter, OrderingFilter]
7577
filterset_fields = ["id", "name"]
7678
search_fields = ['name']
79+
80+
81+
class RetrieveAPI(views.APIView):
82+
def get(self, request):
83+
serializer = RetrieverRequestSerializer(data=request.data)
84+
if not serializer.is_valid():
85+
return Response(
86+
{"error": serializer.errors},
87+
status=status.HTTP_400_BAD_REQUEST
88+
)
89+
90+
data = serializer.validated_data
91+
result = query_retriever(
92+
data["retriever_config_name"],
93+
data["query"],
94+
data.get("top_k"),
95+
)
96+
97+
if result.get("error"):
98+
return Response(result, status=status.HTTP_400_BAD_REQUEST)
99+
100+
return Response(result, status=status.HTTP_200_OK)

sdk/chatfaq_sdk/__init__.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,20 @@ async def send_retriever_request(
387387
)
388388
)
389389

390+
async def retriever_request(
391+
self, retriever_config_name, query, top_k=3
392+
):
393+
async with httpx.AsyncClient() as client:
394+
response = await client.get(
395+
urllib.parse.urljoin(
396+
self.chatfaq_http,
397+
f"back/api/language-model/retrieve/?retriever_config__name={retriever_config_name}&query={query}&top_k={top_k}"
398+
),
399+
headers={"Authorization": f"Token {self.token}"},
400+
)
401+
response.raise_for_status()
402+
return response.json()
403+
390404
async def query_kis(self, knowledge_base_name, query: Optional[dict] = None) -> List[KnowledgeItem]:
391405
res = []
392406
offset = 0
@@ -423,6 +437,34 @@ async def query_prompt(self, prompt_name) -> List[KnowledgeItem]:
423437
if results:
424438
return results[0]["prompt"]
425439

440+
async def submit_feedback_last_bot_msg(self, ctx: dict, value: str, comment: str):
441+
conv_mml = ctx.get("conv_mml", [])
442+
last_bot_message_id = None
443+
for i in reversed(conv_mml):
444+
if i["sender"]["type"] == "bot" and i["stack"]:
445+
last_bot_message_id = i["id"]
446+
break
447+
if last_bot_message_id:
448+
async with httpx.AsyncClient() as client:
449+
response = await client.post(
450+
urllib.parse.urljoin(
451+
self.chatfaq_http,
452+
f"back/api/broker/user-feedback/",
453+
),
454+
json={
455+
"message_source": last_bot_message_id,
456+
"message_target": last_bot_message_id,
457+
"feedback_data": {
458+
"thumb_value": value,
459+
"feedback_comment": comment,
460+
},
461+
},
462+
headers={"Authorization": f"Token {self.token}"},
463+
)
464+
response.raise_for_status()
465+
return response.json()
466+
return False
467+
426468
async def send_prompt_request(self, prompt_config_name, bot_channel_name):
427469
logger.info(f"[PROMPT] Requesting Prompt ({prompt_config_name})")
428470
self.prompt_request_futures[bot_channel_name] = (

sdk/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "chatfaq_sdk"
3-
version = "0.1.47"
3+
version = "0.1.48"
44
description = "ChatFAQ SDK"
55
authors = ["Hector Soria <[email protected]>"]
66
readme = "README.md"

0 commit comments

Comments
 (0)