44from ray import serve
55
66from ragatouille import RAGPretrainedModel
7-
7+ from back . apps . language_model . ray_tasks import get_filesystem
88from chat_rag .inf_retrieval .reference_checker import clean_relevant_references
99
1010
@@ -22,21 +22,22 @@ class ColBERTDeployment:
2222 ColBERTDeployment class for serving a ColBERT retriever in a Ray Serve deployment in a Ray cluster.
2323 """
2424
25- def __init__ (self , index_path ):
25+ def __init__ (self , index_path , remote_ray_cluster , storages_mode ):
2626 print (f"Initializing ColBERTDeployment" )
2727
28- index_path = self ._read_index (index_path )
28+ index_path = self ._read_index (index_path , remote_ray_cluster , storages_mode )
2929
3030 self .retriever = RAGPretrainedModel .from_index (index_path )
3131
3232 # Test query for loading the searcher for the first time
3333 self .retriever .search ("test query" , k = 1 )
3434 print (f"ColBERTDeployment initialized with index_path={ index_path } " )
3535
36- def _read_index (self , index_path ):
36+ def _read_index (self , index_path , remote_ray_cluster , storages_mode ):
3737 """
3838 If the index_path is an S3 path, read the index from object storage and write it to the local storage.
3939 """
40+ fs_ref = get_filesystem .remote (storages_mode )
4041
4142 from tqdm import tqdm
4243
@@ -61,11 +62,37 @@ def write_file(row, base_path):
6162 print (f"Reading index from { index_path } " )
6263
6364 if "s3://" in index_path :
64- index = ray .data .read_binary_files (index_path , include_paths = True )
65- print (f"Read { index .count ()} files from S3" )
65+
66+ from pyarrow .fs import FileSelector
67+
68+ fs = ray .get (fs_ref )
69+
70+ if fs is not None :
71+ # unwrap the filesystem object
72+ fs = fs .unwrap ()
73+
74+ files = fs .get_file_info (FileSelector (index_path .split ('s3://' )[1 ]))
75+
76+ file_paths = [file .path for file in files ]
77+
78+ # print total index size in MB
79+ print (f"Downloading index with size: { sum (file .size for file in files ) / 1024 / 1024 :.2f} MB" )
80+
81+ index = ray .data .read_parquet_bulk (file_paths , filesystem = fs )
82+ print (f"Downloaded { index .count ()} files from S3" )
6683
6784 index_name = os .path .basename (index_path )
68- index_path = os .path .join ("/" , "indexes" , index_name )
85+ index_path = os .path .join ("indexes" , index_name )
86+
87+ # If a remote Ray cluster is running on containers
88+ if remote_ray_cluster :
89+ index_path = os .path .join ("/" , index_path )
90+
91+ # if the directory exists, delete it
92+ if os .path .exists (index_path ):
93+ print (f"Deleting existing index at { index_path } " )
94+ import shutil
95+ shutil .rmtree (index_path )
6996
7097 print (f"Writing index to { index_path } " )
7198 for row in tqdm (index .iter_rows ()):
@@ -117,30 +144,32 @@ async def __call__(self, query: str, top_k: int):
117144 return await self .batch_handler (query , top_k )
118145
119146
120- def construct_index_path (index_path : str ):
147+ def construct_index_path (index_path : str , storages_mode , remote_ray_cluster : bool ):
121148 """
122149 Construct the index path based on the storage mode ("local", "s3" or "do").
123150 """
124151
125- STORAGES_MODE = os .environ .get ("STORAGES_MODE" , "local" )
126- if STORAGES_MODE == "local" :
127- exists_ray_cluster = os .getenv ("RAY_CLUSTER" , "False" ) == "True"
128- if exists_ray_cluster :
152+ if storages_mode == "local" :
153+ if remote_ray_cluster :
129154 return os .path .join (
130155 "/" , index_path
131156 ) # In the ray containers we mount local_storage/indexes/ as /indexes/
132157 else :
133158 return os .path .join ("back" , "back" , "local_storage" , index_path )
134- elif STORAGES_MODE in ["s3" , "do" ]:
159+ elif storages_mode in ["s3" , "do" ]:
135160 bucket_name = os .environ .get ("AWS_STORAGE_BUCKET_NAME" )
136161 return f"s3://{ bucket_name } /{ index_path } "
137162
138163
139164def launch_colbert (retriever_deploy_name , index_path ):
140165 print (f"Launching ColBERT deployment with name: { retriever_deploy_name } " )
141- index_path = construct_index_path (index_path )
166+
167+ storages_mode = os .environ .get ("STORAGES_MODE" , "local" )
168+ remote_ray_cluster = os .getenv ("RAY_CLUSTER" , "False" ) == "True"
169+
170+ index_path = construct_index_path (index_path , storages_mode , remote_ray_cluster )
142171 retriever_handle = ColBERTDeployment .options (
143172 name = retriever_deploy_name ,
144- ).bind (index_path )
173+ ).bind (index_path , remote_ray_cluster , storages_mode )
145174 print (f"Launched ColBERT deployment with name: { retriever_deploy_name } " )
146175 return retriever_handle
0 commit comments