11import io
22import logging
3+ import os
34from pathlib import Path
45from typing import TYPE_CHECKING , Dict , Iterator , List , Optional , Type , TypeVar
56
2122
2223SelfS3DocStore = TypeVar ('SelfS3DocStore' , bound = 'S3DocStore' )
2324
def get_transport_params():
    """Build smart_open ``transport_params`` for S3 from the environment.

    Reads ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY`` and
    ``S3_ENDPOINT_URL``. When any of them is set, a boto3 S3 client is
    constructed and returned under the ``"client"`` key, which smart_open
    uses instead of its default session.

    Returns:
        Optional[dict]: ``{"client": <boto3 S3 client>}`` when any relevant
        environment variable is set, otherwise ``None`` so that smart_open
        falls back to its default boto3 credential chain. NOTE(review):
        callers must handle the ``None`` case before mutating the result.
    """
    access_key = os.environ.get("AWS_ACCESS_KEY_ID")
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
    endpoint_url = os.environ.get("S3_ENDPOINT_URL")

    # Nothing configured: let smart_open use its default credentials chain.
    if not (access_key or secret_key or endpoint_url):
        return None

    # Fix: previously S3_ENDPOINT_URL was silently ignored unless one of the
    # key variables was also set. A Session with key args of None still falls
    # back to boto3's default credential chain, so building it unconditionally
    # here is backward-compatible.
    session = boto3.Session(
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )
    client_kwargs = {"endpoint_url": endpoint_url} if endpoint_url else {}
    return {"client": session.client("s3", **client_kwargs)}
2441
2542class _BufferedCachingReader :
2643 """A buffered reader that writes to a cache file while reading."""
@@ -32,7 +49,14 @@ def __init__(
3249 self ._cache = None
3350 if cache_path :
3451 self ._cache_path = cache_path .with_suffix ('.tmp' )
35- self ._cache = open (self ._cache_path , 'wb' )
52+
53+ transport_params = None
54+ if os .environ .get ("S3_ENDPOINT_URL" ):
55+ transport_params = {
56+ "endpoint_url" : os .environ .get ("S3_ENDPOINT_URL" )
57+ }
58+
59+ self ._cache = open (self ._cache_path , 'wb' , transport_params = get_transport_params ())
3660 self .closed = False
3761
3862 def read (self , size : Optional [int ] = - 1 ) -> bytes :
@@ -147,13 +171,15 @@ def push_stream(
147171 binary_stream = _to_binary_stream (
148172 docs , protocol = 'pickle' , compress = None , show_progress = show_progress
149173 )
174+ transport_params = get_transport_params ()
175+ transport_params ["multipart_upload" ] = False
150176
151177 # Upload to S3
152178 with open (
153179 f"s3://{ bucket } /{ name } .docs" ,
154180 'wb' ,
155181 compression = '.gz' ,
156- transport_params = { 'multipart_upload' : False } ,
182+ transport_params = transport_params ,
157183 ) as fout :
158184 while True :
159185 try :
@@ -206,9 +232,9 @@ def pull_stream(
206232
207233 save_name = name .replace ('/' , '_' )
208234 cache_path = _get_cache_path () / f'{ save_name } .docs'
209-
235+ transport_params = get_transport_params ()
210236 source = _BufferedCachingReader (
211- open (f"s3://{ bucket } /{ name } .docs" , 'rb' , compression = '.gz' ),
237+ open (f"s3://{ bucket } /{ name } .docs" , 'rb' , compression = '.gz' , transport_params = transport_params ),
212238 cache_path = cache_path if local_cache else None ,
213239 )
214240
@@ -221,7 +247,7 @@ def pull_stream(
221247 logging .info (
222248 f'Using cached file for { name } (size: { cache_path .stat ().st_size } )'
223249 )
224- source = open (cache_path , 'rb' )
250+ source = open (cache_path , 'rb' , transport_params = transport_params )
225251
226252 return _from_binary_stream (
227253 docs_cls .doc_type ,
0 commit comments