Skip to content

Commit 7022005

Browse files
committed
Minor fixes
1 parent cacc440 commit 7022005

5 files changed

Lines changed: 96 additions & 58 deletions

File tree

github/server/cache.go

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -211,47 +211,54 @@ func (s *Server) handleGetCacheEntryDownloadURL(w http.ResponseWriter, r *http.R
211211
scope = req.Metadata.Scope
212212
}
213213

214-
s.mu.RLock()
215-
defer s.mu.RUnlock()
214+
type match struct {
215+
id int64
216+
key string
217+
}
218+
219+
var found *match
216220

221+
s.mu.RLock()
217222
// 1. Exact match: scope + key + version
218223
exactKey := scope + "/" + req.Key + "/" + req.Version
219224
if entry, ok := s.caches[exactKey]; ok && entry.Finalized {
220-
downloadURL := s.makeSignedURL("GET", entry.ID)
221-
writeJSON(w, http.StatusOK, GetCacheEntryDownloadURLResponse{
222-
Ok: true,
223-
SignedDownloadURL: downloadURL,
224-
MatchedKey: entry.Key,
225-
})
226-
return
225+
found = &match{id: entry.ID, key: entry.Key}
227226
}
228227

229228
// 2. Prefix match with restore_keys
230-
for _, rk := range req.RestoreKeys {
231-
var best *CacheEntry
232-
for _, entry := range s.caches {
233-
if entry.Scope != scope || entry.Version != req.Version {
234-
continue
235-
}
236-
if !entry.Finalized {
237-
continue
238-
}
239-
if !strings.HasPrefix(entry.Key, rk) {
240-
continue
229+
if found == nil {
230+
for _, rk := range req.RestoreKeys {
231+
var best *CacheEntry
232+
for _, entry := range s.caches {
233+
if entry.Scope != scope || entry.Version != req.Version {
234+
continue
235+
}
236+
if !entry.Finalized {
237+
continue
238+
}
239+
if !strings.HasPrefix(entry.Key, rk) {
240+
continue
241+
}
242+
if best == nil || entry.CreatedAt.After(best.CreatedAt) {
243+
best = entry
244+
}
241245
}
242-
if best == nil || entry.CreatedAt.After(best.CreatedAt) {
243-
best = entry
246+
if best != nil {
247+
found = &match{id: best.ID, key: best.Key}
248+
break
244249
}
245250
}
246-
if best != nil {
247-
downloadURL := s.makeSignedURL("GET", best.ID)
248-
writeJSON(w, http.StatusOK, GetCacheEntryDownloadURLResponse{
249-
Ok: true,
250-
SignedDownloadURL: downloadURL,
251-
MatchedKey: best.Key,
252-
})
253-
return
254-
}
251+
}
252+
s.mu.RUnlock()
253+
254+
if found != nil {
255+
downloadURL := s.makeSignedURL("GET", found.id)
256+
writeJSON(w, http.StatusOK, GetCacheEntryDownloadURLResponse{
257+
Ok: true,
258+
SignedDownloadURL: downloadURL,
259+
MatchedKey: found.key,
260+
})
261+
return
255262
}
256263

257264
writeTwirpError(w, http.StatusNotFound, "not_found", "cache entry not found")

github/server/legacy.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ func (s *Server) handleLegacyCreate(w http.ResponseWriter, r *http.Request) {
7070

7171
// PUT /v3-upload/{containerId}?itemPath={path}
7272
func (s *Server) handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
73+
if !hasBearer(r) {
74+
http.Error(w, "unauthorized", http.StatusUnauthorized)
75+
return
76+
}
77+
7378
cidStr := r.PathValue("containerId")
7479
cid, err := strconv.ParseInt(cidStr, 10, 64)
7580
if err != nil {
@@ -114,7 +119,11 @@ func (s *Server) handleLegacyUpload(w http.ResponseWriter, r *http.Request) {
114119
return
115120
}
116121
if start > 0 {
117-
f.Seek(start, io.SeekStart)
122+
if _, err := f.Seek(start, io.SeekStart); err != nil {
123+
f.Close()
124+
http.Error(w, "storage error", http.StatusInternalServerError)
125+
return
126+
}
118127
}
119128
n, copyErr := io.Copy(f, r.Body)
120129
if err := f.Close(); err != nil && copyErr == nil {
@@ -199,6 +208,11 @@ func (s *Server) handleLegacyList(w http.ResponseWriter, r *http.Request) {
199208

200209
// GET /download-v3/{containerId}?itemPath={prefix}
201210
func (s *Server) handleLegacyListFiles(w http.ResponseWriter, r *http.Request) {
211+
if !hasBearer(r) {
212+
http.Error(w, "unauthorized", http.StatusUnauthorized)
213+
return
214+
}
215+
202216
cidStr := r.PathValue("containerId")
203217
cid, err := strconv.ParseInt(cidStr, 10, 64)
204218
if err != nil {
@@ -239,6 +253,11 @@ func (s *Server) handleLegacyListFiles(w http.ResponseWriter, r *http.Request) {
239253

240254
// GET /artifact/{path...}
241255
func (s *Server) handleLegacyDownload(w http.ResponseWriter, r *http.Request) {
256+
if !hasBearer(r) {
257+
http.Error(w, "unauthorized", http.StatusUnauthorized)
258+
return
259+
}
260+
242261
fullPath := r.PathValue("path")
243262
before, after, ok := strings.Cut(fullPath, "/")
244263
if !ok {

github/server/legacy_test.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ func TestLegacyFullCycle(t *testing.T) {
6060
// 2. Upload file
6161
req, _ := http.NewRequest("PUT", uploadURL+"?itemPath=data/file.txt", bytes.NewReader(fileContent))
6262
req.Header.Set("Content-Type", "application/octet-stream")
63+
req.Header.Set("Authorization", "Bearer test-token")
6364
uploadResp, err := http.DefaultClient.Do(req)
6465
if err != nil {
6566
t.Fatalf("upload: %v", err)
@@ -115,7 +116,9 @@ func TestLegacyFullCycle(t *testing.T) {
115116
contentLocation := filesResult.Value[0]["contentLocation"].(string)
116117

117118
// 6. Download file
118-
dlResp, err := http.Get(contentLocation)
119+
dlReq, _ := http.NewRequest("GET", contentLocation, nil)
120+
dlReq.Header.Set("Authorization", "Bearer test-token")
121+
dlResp, err := http.DefaultClient.Do(dlReq)
119122
if err != nil {
120123
t.Fatalf("download: %v", err)
121124
}
@@ -148,6 +151,7 @@ func TestLegacyChunkedUpload(t *testing.T) {
148151
// Upload chunk 1
149152
req, _ := http.NewRequest("PUT", uploadURL+"?itemPath=file.bin", bytes.NewReader(chunk1))
150153
req.Header.Set("Content-Range", fmt.Sprintf("bytes 0-%d/%d", len(chunk1)-1, total))
154+
req.Header.Set("Authorization", "Bearer test-token")
151155
r, _ := http.DefaultClient.Do(req)
152156
r.Body.Close()
153157
if r.StatusCode != http.StatusCreated {
@@ -157,6 +161,7 @@ func TestLegacyChunkedUpload(t *testing.T) {
157161
// Upload chunk 2
158162
req, _ = http.NewRequest("PUT", uploadURL+"?itemPath=file.bin", bytes.NewReader(chunk2))
159163
req.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", len(chunk1), total-1, total))
164+
req.Header.Set("Authorization", "Bearer test-token")
160165
r, _ = http.DefaultClient.Do(req)
161166
r.Body.Close()
162167
if r.StatusCode != http.StatusCreated {
@@ -184,7 +189,9 @@ func TestLegacyChunkedUpload(t *testing.T) {
184189
resp.Body.Close()
185190
contentLocation := filesResult.Value[0]["contentLocation"].(string)
186191

187-
dlResp, _ := http.Get(contentLocation)
192+
dlReq, _ := http.NewRequest("GET", contentLocation, nil)
193+
dlReq.Header.Set("Authorization", "Bearer test-token")
194+
dlResp, _ := http.DefaultClient.Do(dlReq)
188195
defer dlResp.Body.Close()
189196
data, _ := io.ReadAll(dlResp.Body)
190197
expected := append(chunk1, chunk2...)
@@ -215,6 +222,7 @@ func TestLegacyGzipRoundtrip(t *testing.T) {
215222
// Upload with Content-Encoding: gzip
216223
req, _ := http.NewRequest("PUT", uploadURL+"?itemPath=data.gz", bytes.NewReader(gzipData))
217224
req.Header.Set("Content-Encoding", "gzip")
225+
req.Header.Set("Authorization", "Bearer test-token")
218226
r, _ := http.DefaultClient.Do(req)
219227
r.Body.Close()
220228
if r.StatusCode != http.StatusCreated {
@@ -239,7 +247,9 @@ func TestLegacyGzipRoundtrip(t *testing.T) {
239247
// Use raw HTTP transport to avoid automatic decompression
240248
transport := &http.Transport{DisableCompression: true}
241249
client := &http.Client{Transport: transport}
242-
dlResp, _ := client.Get(filesResult.Value[0]["contentLocation"].(string))
250+
dlReq, _ := http.NewRequest("GET", filesResult.Value[0]["contentLocation"].(string), nil)
251+
dlReq.Header.Set("Authorization", "Bearer test-token")
252+
dlResp, _ := client.Do(dlReq)
243253
defer dlResp.Body.Close()
244254

245255
if dlResp.Header.Get("Content-Encoding") != "gzip" {
@@ -272,6 +282,7 @@ func TestLegacyMultipleFiles(t *testing.T) {
272282

273283
for path, content := range files {
274284
req, _ := http.NewRequest("PUT", uploadURL+"?itemPath="+path, bytes.NewReader([]byte(content)))
285+
req.Header.Set("Authorization", "Bearer test-token")
275286
r, _ := http.DefaultClient.Do(req)
276287
r.Body.Close()
277288
}
@@ -319,14 +330,14 @@ func TestLegacyNotFound(t *testing.T) {
319330
}
320331

321332
// Download non-existent container
322-
dlResp, _ := http.Get(ts.URL + "/download-v3/9999")
333+
dlResp := legacyRequest(t, ts, "GET", "/download-v3/9999", nil)
323334
dlResp.Body.Close()
324335
if dlResp.StatusCode != http.StatusNotFound {
325336
t.Fatalf("expected 404, got %d", dlResp.StatusCode)
326337
}
327338

328339
// Download non-existent file
329-
dlResp, _ = http.Get(ts.URL + "/artifact/9999/nofile.txt")
340+
dlResp = legacyRequest(t, ts, "GET", "/artifact/9999/nofile.txt", nil)
330341
dlResp.Body.Close()
331342
if dlResp.StatusCode != http.StatusNotFound {
332343
t.Fatalf("expected 404, got %d", dlResp.StatusCode)

github/server/server.go

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,16 @@ func (s *Server) handleCreateArtifact(w http.ResponseWriter, r *http.Request, ru
271271

272272
blobPath := s.artifactBlobPath(req.WorkflowRunBackendID, req.Name)
273273

274+
safeDir, err := utils.SafeJoinPath(s.storageDir, filepath.Base(req.WorkflowRunBackendID))
275+
if err != nil {
276+
writeTwirpError(w, http.StatusBadRequest, "invalid_argument", "invalid artifact path")
277+
return
278+
}
279+
if err := os.MkdirAll(safeDir, 0o755); err != nil {
280+
writeTwirpError(w, http.StatusInternalServerError, "internal", "failed to create storage directory")
281+
return
282+
}
283+
274284
s.mu.Lock()
275285
if existing, ok := s.artifacts[key]; ok && existing.Finalized {
276286
s.mu.Unlock()
@@ -298,16 +308,6 @@ func (s *Server) handleCreateArtifact(w http.ResponseWriter, r *http.Request, ru
298308
s.uploadMu[id] = &sync.Mutex{}
299309
s.mu.Unlock()
300310

301-
safeDir, err := utils.SafeJoinPath(s.storageDir, filepath.Base(req.WorkflowRunBackendID))
302-
if err != nil {
303-
writeTwirpError(w, http.StatusBadRequest, "invalid_argument", "invalid artifact path")
304-
return
305-
}
306-
if err := os.MkdirAll(safeDir, 0o755); err != nil {
307-
writeTwirpError(w, http.StatusInternalServerError, "internal", "failed to create storage directory")
308-
return
309-
}
310-
311311
uploadURL := s.makeSignedURL("PUT", id)
312312

313313
writeJSON(w, http.StatusOK, CreateArtifactResponse{
@@ -480,6 +480,16 @@ func (s *Server) handleMigrateArtifact(w http.ResponseWriter, r *http.Request, r
480480
key := req.WorkflowRunBackendID + "/" + req.Name
481481
blobPath := s.artifactBlobPath(req.WorkflowRunBackendID, req.Name)
482482

483+
safeDir, err := utils.SafeJoinPath(s.storageDir, filepath.Base(req.WorkflowRunBackendID))
484+
if err != nil {
485+
writeTwirpError(w, http.StatusBadRequest, "invalid_argument", "invalid artifact path")
486+
return
487+
}
488+
if err := os.MkdirAll(safeDir, 0o755); err != nil {
489+
writeTwirpError(w, http.StatusInternalServerError, "internal", "failed to create storage directory")
490+
return
491+
}
492+
483493
s.mu.Lock()
484494
if existing, ok := s.artifacts[key]; ok && existing.Finalized {
485495
s.mu.Unlock()
@@ -501,16 +511,6 @@ func (s *Server) handleMigrateArtifact(w http.ResponseWriter, r *http.Request, r
501511
s.uploadMu[id] = &sync.Mutex{}
502512
s.mu.Unlock()
503513

504-
safeDir, err := utils.SafeJoinPath(s.storageDir, filepath.Base(req.WorkflowRunBackendID))
505-
if err != nil {
506-
writeTwirpError(w, http.StatusBadRequest, "invalid_argument", "invalid artifact path")
507-
return
508-
}
509-
if err := os.MkdirAll(safeDir, 0o755); err != nil {
510-
writeTwirpError(w, http.StatusInternalServerError, "internal", "failed to create storage directory")
511-
return
512-
}
513-
514514
uploadURL := s.makeSignedURL("PUT", id)
515515

516516
writeJSON(w, http.StatusOK, MigrateArtifactResponse{

github/server/start.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package server
22

33
import (
44
"crypto/rand"
5+
"errors"
56
"fmt"
67
"net"
78
"net/http"
@@ -85,7 +86,7 @@ func StartServer(cfg Config) (*RunningServer, error) {
8586
}
8687

8788
go func() {
88-
if err := httpServer.Serve(listener); err != nil {
89+
if err := httpServer.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
8990
fmt.Fprintf(os.Stderr, "server error: %v\n", err)
9091
}
9192
}()

0 commit comments

Comments
 (0)