Skip to content

Commit 163dfce

Browse files
committed
Merge remote-tracking branch 'origin/develop' into develop
2 parents c8b25ed + 6279b10 commit 163dfce

15 files changed

Lines changed: 3802 additions & 2595 deletions

File tree

BOOK_OF_FRUSTRATIONS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ This file is intended for writing down all those TODOs/tech. deb. you know it ne
1010
- [ ] RAGs should not be loaded in shared celery memory, but in another service which celery calls.
1111
- [ ] RAGs need to be deleted and recreated if the retriever type is changed. This deletion and recreation shouldn't be necessary; the RAG should be able to handle the change directly.
1212
- [ ] Find the correct faiss-gpu version and pin it.
13+
- [ ] Default PDF parser doesn't handle images yet.

back/poetry.lock

Lines changed: 579 additions & 548 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

back/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pgvector = "^0.2.3"
3333
django-filter = "^23.2"
3434
django-storages = "*"
3535
sqlalchemy = "^2.0.16"
36-
chat-rag = {version = "0.1.54"}
36+
chat-rag = {version = "0.1.56"}
3737
gevent = "23.9.0"
3838
torch = [
3939
{ version = "^2.0.1", source = "torch" },

chat_rag/chat_rag/data/splitters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ class SmartSplitter:
241241
Splits the text into information meaningful chunks using the GPT-4 model.
242242
This can reach API rate limits very quickly.
243243
"""
244-
def __init__(self, model_name='gpt-4'):
244+
def __init__(self, model_name='gpt-4-0125-preview'):
245245
"""
246246
Parameters
247247
----------

chat_rag/chat_rag/llms/claude_client.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
from typing import List, Dict
22
import os
33

4-
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
4+
from anthropic import Anthropic
55

66
from chat_rag.llms import RAGLLM, CONTEXT_PREFIX
77

88

99
class ClaudeChatModel(RAGLLM):
1010
def __init__(self, llm_name, **kwargs) -> None:
1111
self.llm_name = llm_name
12-
self.anthropic = Anthropic(
13-
api_key=os.environ["ANTHROPIC_API_KEY"],
12+
self.client = Anthropic(
13+
api_key=os.environ.get("ANTHROPIC_API_KEY"),
1414
)
1515

1616
def format_prompt(
1717
self,
18-
messages: List[Dict[str, str]],
1918
contexts: List[str],
2019
system_prefix: str,
2120
n_contexts_to_use: int = 3,
@@ -26,8 +25,6 @@ def format_prompt(
2625
Formats the prompt to be used by the model.
2726
Parameters
2827
----------
29-
messages : List[Tuple[str, str]]
30-
The messages to use for the prompt. Pair of (role, message).
3128
contexts : list
3229
The context to use.
3330
system_prefix : str
@@ -49,20 +46,14 @@ def format_prompt(
4946
list
5047
The formatted prompt.
5148
"""
52-
prompt = self.format_system_prompt(
49+
system_prompt = self.format_system_prompt(
5350
contexts=contexts,
5451
system_prefix=system_prefix,
5552
n_contexts_to_use=n_contexts_to_use,
5653
lang=lang,
5754
)
5855

59-
for message in messages:
60-
if message['role'] == 'user':
61-
prompt += f"{HUMAN_PROMPT} {message['content']}{AI_PROMPT}"
62-
elif message['role'] == 'assistant':
63-
prompt += " " + message['content']
64-
65-
return prompt
56+
return system_prompt
6657

6758
def generate(
6859
self,
@@ -93,18 +84,19 @@ def generate(
9384
The generated text.
9485
"""
9586

96-
prompt = self.format_prompt(messages, contexts, **prompt_structure_dict, lang=lang)
87+
system_prompt = self.format_prompt(contexts, **prompt_structure_dict, lang=lang)
9788

98-
completion = self.anthropic.completions.create(
89+
message = self.client.messages.create(
9990
model=self.llm_name,
100-
max_tokens_to_sample=generation_config_dict['max_new_tokens'],
91+
system=system_prompt,
92+
messages=messages,
93+
max_tokens=generation_config_dict['max_new_tokens'],
10194
temperature=generation_config_dict['temperature'],
10295
top_p=generation_config_dict['top_p'],
10396
top_k=generation_config_dict['top_k'],
104-
prompt=prompt,
10597
)
10698

107-
return completion.completion
99+
return message.content[0].text
108100

109101
def stream(
110102
self,
@@ -114,7 +106,7 @@ def stream(
114106
generation_config_dict: dict = None,
115107
lang: str = "en",
116108
**kwargs,
117-
) -> str:
109+
):
118110
"""
119111
Generate text from a prompt using the model.
120112
Parameters
@@ -135,20 +127,22 @@ def stream(
135127
The generated text.
136128
"""
137129

138-
prompt = self.format_prompt(messages, contexts, **prompt_structure_dict, lang=lang)
130+
system_prompt = self.format_prompt(contexts, **prompt_structure_dict, lang=lang)
139131

140-
stream = self.anthropic.completions.create(
132+
stream = self.client.messages.create(
141133
model=self.llm_name,
142-
max_tokens_to_sample=generation_config_dict['max_new_tokens'],
134+
system=system_prompt,
135+
messages=messages,
136+
max_tokens=generation_config_dict['max_new_tokens'],
143137
temperature=generation_config_dict['temperature'],
144138
top_p=generation_config_dict['top_p'],
145139
top_k=generation_config_dict['top_k'],
146-
prompt=prompt,
147140
stream=True,
148141
)
149142

150-
for completion in stream:
151-
yield completion.completion
143+
for event in stream:
144+
if event.type == "content_block_delta":
145+
yield event.delta.text
152146

153147

154148

0 commit comments

Comments
 (0)