Skip to content

Commit 8236013

Browse files
committed
Fix writing and reading indexes from DO
1 parent 8c48125 commit 8236013

6 files changed

Lines changed: 66 additions & 32 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ sw.*
245245
.out
246246

247247
back/models/
248-
back/indexes/
248+
*indexes/
249249
*.ipynb
250250
chat_rag/data/
251251
chat_rag/examples/

back/back/apps/language_model/ray_deployments/colbert_deployment.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from ray import serve
55

66
from ragatouille import RAGPretrainedModel
7-
7+
from back.apps.language_model.ray_tasks import get_filesystem
88
from chat_rag.inf_retrieval.reference_checker import clean_relevant_references
99

1010

@@ -22,21 +22,22 @@ class ColBERTDeployment:
2222
ColBERTDeployment class for serving a ColBERT retriever in a Ray Serve deployment in a Ray cluster.
2323
"""
2424

25-
def __init__(self, index_path):
25+
def __init__(self, index_path, remote_ray_cluster, storages_mode):
2626
print(f"Initializing ColBERTDeployment")
2727

28-
index_path = self._read_index(index_path)
28+
index_path = self._read_index(index_path, remote_ray_cluster, storages_mode)
2929

3030
self.retriever = RAGPretrainedModel.from_index(index_path)
3131

3232
# Test query for loading the searcher for the first time
3333
self.retriever.search("test query", k=1)
3434
print(f"ColBERTDeployment initialized with index_path={index_path}")
3535

36-
def _read_index(self, index_path):
36+
def _read_index(self, index_path, remote_ray_cluster, storages_mode):
3737
"""
3838
If the index_path is an S3 path, read the index from object storage and write it to the local storage.
3939
"""
40+
fs_ref = get_filesystem.remote(storages_mode)
4041

4142
from tqdm import tqdm
4243

@@ -61,11 +62,37 @@ def write_file(row, base_path):
6162
print(f"Reading index from {index_path}")
6263

6364
if "s3://" in index_path:
64-
index = ray.data.read_binary_files(index_path, include_paths=True)
65-
print(f"Read {index.count()} files from S3")
65+
66+
from pyarrow.fs import FileSelector
67+
68+
fs = ray.get(fs_ref)
69+
70+
if fs is not None:
71+
# unwrap the filesystem object
72+
fs = fs.unwrap()
73+
74+
files = fs.get_file_info(FileSelector(index_path.split('s3://')[1]))
75+
76+
file_paths = [file.path for file in files]
77+
78+
# print total index size in MB
79+
print(f"Downloading index with size: {sum(file.size for file in files) / 1024 / 1024:.2f} MB")
80+
81+
index = ray.data.read_parquet_bulk(file_paths, filesystem=fs)
82+
print(f"Downloaded {index.count()} files from S3")
6683

6784
index_name = os.path.basename(index_path)
68-
index_path = os.path.join("/", "indexes", index_name)
85+
index_path = os.path.join("indexes", index_name)
86+
87+
# If remote Ray cluster running on containers
88+
if remote_ray_cluster:
89+
index_path = os.path.join("/", index_path)
90+
91+
# if the directory exists, delete it
92+
if os.path.exists(index_path):
93+
print(f"Deleting existing index at {index_path}")
94+
import shutil
95+
shutil.rmtree(index_path)
6996

7097
print(f"Writing index to {index_path}")
7198
for row in tqdm(index.iter_rows()):
@@ -117,30 +144,32 @@ async def __call__(self, query: str, top_k: int):
117144
return await self.batch_handler(query, top_k)
118145

119146

120-
def construct_index_path(index_path: str):
147+
def construct_index_path(index_path: str, storages_mode, remote_ray_cluster: bool):
121148
"""
122149
Construct the index path based on the STORAGES_MODE environment variable.
123150
"""
124151

125-
STORAGES_MODE = os.environ.get("STORAGES_MODE", "local")
126-
if STORAGES_MODE == "local":
127-
exists_ray_cluster = os.getenv("RAY_CLUSTER", "False") == "True"
128-
if exists_ray_cluster:
152+
if storages_mode == "local":
153+
if remote_ray_cluster:
129154
return os.path.join(
130155
"/", index_path
131156
) # In the ray containers we mount local_storage/indexes/ as /indexes/
132157
else:
133158
return os.path.join("back", "back", "local_storage", index_path)
134-
elif STORAGES_MODE in ["s3", "do"]:
159+
elif storages_mode in ["s3", "do"]:
135160
bucket_name = os.environ.get("AWS_STORAGE_BUCKET_NAME")
136161
return f"s3://{bucket_name}/{index_path}"
137162

138163

139164
def launch_colbert(retriever_deploy_name, index_path):
140165
print(f"Launching ColBERT deployment with name: {retriever_deploy_name}")
141-
index_path = construct_index_path(index_path)
166+
167+
storages_mode = os.environ.get("STORAGES_MODE", "local")
168+
remote_ray_cluster = os.getenv("RAY_CLUSTER", "False") == "True"
169+
170+
index_path = construct_index_path(index_path, storages_mode, remote_ray_cluster)
142171
retriever_handle = ColBERTDeployment.options(
143172
name=retriever_deploy_name,
144-
).bind(index_path)
173+
).bind(index_path, remote_ray_cluster, storages_mode)
145174
print(f"Launched ColBERT deployment with name: {retriever_deploy_name}")
146175
return retriever_handle
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .ray_tasks import generate_embeddings, parse_pdf, generate_titles, create_colbert_index
1+
from .ray_tasks import generate_embeddings, parse_pdf, generate_titles, create_colbert_index, get_filesystem

back/back/apps/language_model/ray_tasks/ray_tasks.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ def get_filesystem(storages_mode):
8585

8686
print("Using Digital Ocean S3 filesystem")
8787

88-
fs_ref = ray.put(s3fs)
89-
return fs_ref
88+
return s3fs
9089

90+
# If None ray data autoinfers the filesystem
9191
return None
9292

9393

@@ -142,17 +142,19 @@ def get_num_gpus():
142142
# ------ ----
143143
# bytes binary
144144
# path string,
145-
fs_ref = ray.get(get_filesystem.remote(storages_mode))
146-
fs = ray.get(fs_ref)
147-
148-
# unwrap the filesystem object
149-
fs = fs.unwrap()
145+
fs_ref = get_filesystem.remote(storages_mode)
150146

151147
print('Reading index from local storage')
152148
local_index_path = 'local://' + os.path.join(os.getcwd(), local_index_path)
153149
index = ray.data.read_binary_files(local_index_path, include_paths=True)
154150

155151
print(f"Writing index to object storage {s3_index_path}")
152+
153+
fs = ray.get(fs_ref)
154+
if fs is not None:
155+
# unwrap the filesystem object
156+
fs = fs.unwrap()
157+
156158
# Then we can write the index to the cloud storage
157159
index.write_parquet(s3_index_path, filesystem=fs)
158160
print('Index written to object storage')

back/back/apps/language_model/tasks.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -344,9 +344,11 @@ def creates_index(rag_config):
344344
task_name = f"create_colbert_index_{rag_config.name}"
345345

346346
with connect_to_ray_cluster():
347-
347+
storages_mode = os.environ.get("STORAGES_MODE", "local")
348+
remote_ray_cluster = os.getenv("RAY_CLUSTER", "False") == "True"
349+
index_path = construct_index_path(s3_index_path, storages_mode, remote_ray_cluster)
348350
index_ref = ray_create_colbert_index.options(resources={"tasks": 1}, num_gpus=num_gpus, name=task_name).remote(
349-
colbert_name, bsize, device, construct_index_path(s3_index_path), storages_mode, contents_pk, contents
351+
colbert_name, bsize, device, index_path, storages_mode, contents_pk, contents
350352
)
351353

352354
# Delete all the contents from memory because they are not needed anymore and can be very large
@@ -447,17 +449,18 @@ def delete_index_files(s3_index_path):
447449
s3_index_path : str
448450
The unique index path.
449451
"""
450-
from django.core.files.storage import default_storage
452+
from back.config.storage_backends import select_private_storage
451453

452454
if s3_index_path:
455+
private_storage = select_private_storage()
453456
logger.info(f"Deleting index files from S3: {s3_index_path}")
454457
# List all files in the unique index path
455-
_, files = default_storage.listdir(s3_index_path)
458+
_, files = private_storage.listdir(s3_index_path)
456459
for file in files:
457460
# Construct the full path for each file
458461
file_path = os.path.join(s3_index_path, file)
459462
# Delete the file from S3
460-
default_storage.delete(file_path)
463+
private_storage.delete(file_path)
461464

462465
logger.info(f"Index files deleted from S3: {s3_index_path}")
463466

back/back/utils/ray_connection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ def initialize_or_check_ray():
9292
Initialize a Ray cluster locally or check if a remote Ray cluster is available.
9393
"""
9494
if not ray.is_initialized():
95-
exists_ray_cluster = (os.getenv('RAY_CLUSTER', 'False') == 'True')
96-
print(f"exists_ray_cluster: {exists_ray_cluster} {os.getenv('RAY_CLUSTER')}")
97-
if exists_ray_cluster:
95+
remote_ray_cluster = (os.getenv('RAY_CLUSTER', 'False') == 'True')
96+
print(f"remote_ray_cluster: {remote_ray_cluster} {os.getenv('RAY_CLUSTER')}")
97+
if remote_ray_cluster:
9898
if not check_remote_ray_cluster():
9999
logger.error(f"You provided a remote Ray Cluster address but the connection failed, these could be because of three reasons: ")
100100
# logger.error(f"1. The provided address is incorrect: {RAY_ADDRESS}")

0 commit comments

Comments (0)