Skip to content

Commit 4d2b861

Browse files
author
Andrew Zhai
committed
respect s3 settings
1 parent e64a595 commit 4d2b861

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

docarray/store/s3.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io
22
import logging
3+
import os
34
from pathlib import Path
45
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Type, TypeVar
56

@@ -21,6 +22,22 @@
2122

2223
SelfS3DocStore = TypeVar('SelfS3DocStore', bound='S3DocStore')
2324

25+
def get_transport_params():
    """Build smart_open ``transport_params`` for S3 from environment settings.

    Reads ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY`` and
    ``S3_ENDPOINT_URL`` from the environment and, when any are set, attaches
    a pre-configured boto3 S3 client under the ``"client"`` key.

    :return: a dict (possibly empty) that is always safe to pass as
        smart_open's ``transport_params=`` and safe for callers to add
        further keys to (e.g. ``multipart_upload``).
    """
    transport_params = {}
    access_key = os.environ.get("AWS_ACCESS_KEY_ID")
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
    endpoint_url = os.environ.get("S3_ENDPOINT_URL")

    if access_key or secret_key:
        # Explicit credentials from the environment take precedence over
        # boto3's default credential chain. endpoint_url=None falls back to
        # the standard AWS endpoint, so no branching is needed here.
        session = boto3.Session(
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )
        transport_params["client"] = session.client("s3", endpoint_url=endpoint_url)
    elif endpoint_url:
        # Fix: previously a custom endpoint was silently ignored unless
        # explicit credentials were also set. Honor it with a default
        # session (boto3 resolves credentials via its usual chain).
        transport_params["client"] = boto3.Session().client(
            "s3", endpoint_url=endpoint_url
        )

    # Fix: always return a dict. The previous version returned None when no
    # settings were present, which made callers that mutate the result
    # (push_stream does transport_params["multipart_upload"] = False) crash
    # with TypeError. smart_open treats an empty dict the same as None.
    return transport_params
2441

2542
class _BufferedCachingReader:
2643
"""A buffered reader that writes to a cache file while reading."""
@@ -32,7 +49,14 @@ def __init__(
3249
self._cache = None
3350
if cache_path:
3451
self._cache_path = cache_path.with_suffix('.tmp')
35-
self._cache = open(self._cache_path, 'wb')
52+
53+
transport_params = None
54+
if os.environ.get("S3_ENDPOINT_URL"):
55+
transport_params = {
56+
"endpoint_url": os.environ.get("S3_ENDPOINT_URL")
57+
}
58+
59+
self._cache = open(self._cache_path, 'wb', transport_params=get_transport_params())
3660
self.closed = False
3761

3862
def read(self, size: Optional[int] = -1) -> bytes:
@@ -147,13 +171,15 @@ def push_stream(
147171
binary_stream = _to_binary_stream(
148172
docs, protocol='pickle', compress=None, show_progress=show_progress
149173
)
174+
transport_params = get_transport_params()
175+
transport_params["multipart_upload"] = False
150176

151177
# Upload to S3
152178
with open(
153179
f"s3://{bucket}/{name}.docs",
154180
'wb',
155181
compression='.gz',
156-
transport_params={'multipart_upload': False},
182+
transport_params=transport_params,
157183
) as fout:
158184
while True:
159185
try:
@@ -206,9 +232,9 @@ def pull_stream(
206232

207233
save_name = name.replace('/', '_')
208234
cache_path = _get_cache_path() / f'{save_name}.docs'
209-
235+
transport_params = get_transport_params()
210236
source = _BufferedCachingReader(
211-
open(f"s3://{bucket}/{name}.docs", 'rb', compression='.gz'),
237+
open(f"s3://{bucket}/{name}.docs", 'rb', compression='.gz', transport_params=transport_params),
212238
cache_path=cache_path if local_cache else None,
213239
)
214240

@@ -221,7 +247,7 @@ def pull_stream(
221247
logging.info(
222248
f'Using cached file for {name} (size: {cache_path.stat().st_size})'
223249
)
224-
source = open(cache_path, 'rb')
250+
source = open(cache_path, 'rb', transport_params=transport_params)
225251

226252
return _from_binary_stream(
227253
docs_cls.doc_type,

0 commit comments

Comments
 (0)