Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/cmd/lazystack/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"flag"
"fmt"
"os"
Expand Down Expand Up @@ -43,7 +44,9 @@ func main() {
}

if *doUpdate {
latest, downloadURL, checksumsURL, err := selfupdate.CheckLatest(version)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
latest, downloadURL, checksumsURL, err := selfupdate.CheckLatest(ctx, version)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
Expand All @@ -53,7 +56,7 @@ func main() {
return
}
fmt.Printf("Updating lazystack %s → %s...\n", version, latest)
if err := selfupdate.Apply(downloadURL, checksumsURL); err != nil {
if err := selfupdate.Apply(ctx, downloadURL, checksumsURL); err != nil {
fmt.Fprintf(os.Stderr, "Update failed: %v\n", err)
os.Exit(1)
}
Expand Down
2 changes: 1 addition & 1 deletion src/internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func (m Model) Init() tea.Cmd {
ver := m.version
ttl := m.updateCheckInterval
cmds = append(cmds, func() tea.Msg {
latest, dlURL, csURL, err := selfupdate.CheckLatestCached(ver, ttl)
latest, dlURL, csURL, err := selfupdate.CheckLatestCached(context.Background(), ver, ttl)
if err != nil || latest == "" {
return nil
}
Expand Down
2 changes: 2 additions & 0 deletions src/internal/app/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ func (m Model) viewContent() string {
return m.help.Render()
}

// Overlay priority chain — overlays are mutually exclusive by design
// (each action activates at most one), so first-match ordering is safe.
if m.activeModal == modalConfirm {
return m.confirm.View()
}
Expand Down
7 changes: 4 additions & 3 deletions src/internal/selfupdate/cache.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package selfupdate

import (
"context"
"encoding/json"
"errors"
"os"
Expand Down Expand Up @@ -31,7 +32,7 @@ var CachePath = func() string {

// checkFn is the function used to query for the latest version.
// It is a variable so tests can override it.
var checkFn = CheckLatest
var checkFn func(context.Context, string) (string, string, string, error) = CheckLatest // overridable for tests

// LoadCache reads the cached update-check result from disk.
// Returns nil, nil if the file does not exist.
Expand Down Expand Up @@ -90,7 +91,7 @@ func SaveCache(entry CacheEntry) error {
// if the cache is missing, expired (older than ttl), or was written
// for a different binary version (i.e. the user upgraded via their
// package manager).
func CheckLatestCached(currentVersion string, ttl time.Duration) (latest, downloadURL, checksumsURL string, err error) {
func CheckLatestCached(ctx context.Context, currentVersion string, ttl time.Duration) (latest, downloadURL, checksumsURL string, err error) {
shared.Debugf("[selfupdate] CheckLatestCached: start currentVersion=%s ttl=%s", currentVersion, ttl)
cache, _ := LoadCache()
if cache != nil && cache.CurrentVersion == currentVersion && time.Since(cache.CheckedAt) < ttl {
Expand All @@ -104,7 +105,7 @@ func CheckLatestCached(currentVersion string, ttl time.Duration) (latest, downlo
}

shared.Debugf("[selfupdate] CheckLatestCached: cache miss or expired, querying API")
latest, downloadURL, checksumsURL, err = checkFn(currentVersion)
latest, downloadURL, checksumsURL, err = checkFn(ctx, currentVersion)
if err != nil {
shared.Debugf("[selfupdate] CheckLatestCached: error from API: %v", err)
return "", "", "", err
Expand Down
13 changes: 7 additions & 6 deletions src/internal/selfupdate/cache_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package selfupdate

import (
"context"
"encoding/json"
"os"
"path/filepath"
Expand Down Expand Up @@ -80,13 +81,13 @@ func TestCheckLatestCached_UsesCacheWithinTTL(t *testing.T) {
// Override checkFn to track if CheckLatest is called
called := false
origCheckFn := checkFn
checkFn = func(ver string) (string, string, string, error) {
checkFn = func(_ context.Context, ver string) (string, string, string, error) {
called = true
return "", "", "", nil
}
defer func() { checkFn = origCheckFn }()

latest, _, _, err := CheckLatestCached("v1.0.0", 24*time.Hour)
latest, _, _, err := CheckLatestCached(context.Background(), "v1.0.0", 24*time.Hour)
if err != nil {
t.Fatalf("CheckLatestCached: %v", err)
}
Expand Down Expand Up @@ -116,13 +117,13 @@ func TestCheckLatestCached_RefreshesExpiredCache(t *testing.T) {

called := false
origCheckFn := checkFn
checkFn = func(ver string) (string, string, string, error) {
checkFn = func(_ context.Context, ver string) (string, string, string, error) {
called = true
return "v2.0.0", "https://example.com/bin2", "https://example.com/SHA256SUMS2", nil
}
defer func() { checkFn = origCheckFn }()

latest, _, _, err := CheckLatestCached("v1.0.0", 24*time.Hour)
latest, _, _, err := CheckLatestCached(context.Background(), "v1.0.0", 24*time.Hour)
if err != nil {
t.Fatalf("CheckLatestCached: %v", err)
}
Expand Down Expand Up @@ -152,13 +153,13 @@ func TestCheckLatestCached_InvalidatesOnVersionChange(t *testing.T) {

called := false
origCheckFn := checkFn
checkFn = func(ver string) (string, string, string, error) {
checkFn = func(_ context.Context, ver string) (string, string, string, error) {
called = true
return "v2.0.0", "https://example.com/bin", "", nil
}
defer func() { checkFn = origCheckFn }()

_, _, _, err := CheckLatestCached("v1.5.0", 24*time.Hour)
_, _, _, err := CheckLatestCached(context.Background(), "v1.5.0", 24*time.Hour)
if err != nil {
t.Fatalf("CheckLatestCached: %v", err)
}
Expand Down
152 changes: 60 additions & 92 deletions src/internal/selfupdate/selfupdate.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package selfupdate

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -12,50 +14,71 @@ import (
"runtime"
"strconv"
"strings"
"time"

"github.com/larkly/lazystack/internal/shared"
)

const releaseAPI = "https://api.github.com/repos/larkly/lazystack/releases/latest"

// httpClient is used for API/metadata requests (30s timeout).
var httpClient = &http.Client{Timeout: 30 * time.Second}

// downloadClient is used for binary downloads (5 minute timeout).
var downloadClient = &http.Client{Timeout: 5 * time.Minute}

// githubRelease is the subset of the GitHub release API response we need.
type githubRelease struct {
TagName string `json:"tag_name"`
Assets []githubAsset `json:"assets"`
}

// githubAsset is a single asset in a GitHub release.
type githubAsset struct {
Name string `json:"name"`
BrowserDownloadURL string `json:"browser_download_url"`
}

// CheckLatest checks GitHub for a newer release. Returns empty strings if
// already up to date. Returns an error if currentVersion is "dev".
func CheckLatest(currentVersion string) (latest, downloadURL, checksumsURL string, err error) {
func CheckLatest(ctx context.Context, currentVersion string) (latest, downloadURL, checksumsURL string, err error) {
shared.Debugf("[selfupdate] CheckLatest: start currentVersion=%s", currentVersion)
if currentVersion == "dev" {
shared.Debugf("[selfupdate] CheckLatest: error dev build")
return "", "", "", errors.New("cannot check for updates on a dev build; build with -ldflags \"-X main.version=vX.Y.Z\"")
}

body, err := httpGet(releaseAPI)
body, err := httpGet(ctx, releaseAPI)
if err != nil {
shared.Debugf("[selfupdate] CheckLatest: error fetching release: %v", err)
return "", "", "", fmt.Errorf("fetching latest release: %w", err)
}

tagName := jsonString(body, "tag_name")
if tagName == "" {
shared.Debugf("[selfupdate] CheckLatest: error parsing tag_name")
var release githubRelease
if err := json.Unmarshal(body, &release); err != nil {
shared.Debugf("[selfupdate] CheckLatest: error parsing release JSON: %v", err)
return "", "", "", fmt.Errorf("parsing release response: %w", err)
}

if release.TagName == "" {
shared.Debugf("[selfupdate] CheckLatest: error empty tag_name")
return "", "", "", errors.New("could not parse tag_name from release response")
}
shared.Debugf("[selfupdate] CheckLatest: found tagName=%s", tagName)
shared.Debugf("[selfupdate] CheckLatest: found tagName=%s", release.TagName)

if !isNewer(tagName, currentVersion) {
if !isNewer(release.TagName, currentVersion) {
shared.Debugf("[selfupdate] CheckLatest: already up to date")
return "", "", "", nil
}

assetName := fmt.Sprintf("lazystack-%s-%s", runtime.GOOS, runtime.GOARCH)
shared.Debugf("[selfupdate] CheckLatest: looking for asset %s", assetName)
assets := jsonArray(body, "assets")
for _, asset := range assets {
name := jsonString(asset, "name")
url := jsonString(asset, "browser_download_url")
if name == assetName {
downloadURL = url
for _, asset := range release.Assets {
if asset.Name == assetName {
downloadURL = asset.BrowserDownloadURL
}
if name == "SHA256SUMS" {
checksumsURL = url
if asset.Name == "SHA256SUMS" {
checksumsURL = asset.BrowserDownloadURL
}
}

Expand All @@ -64,13 +87,13 @@ func CheckLatest(currentVersion string) (latest, downloadURL, checksumsURL strin
return "", "", "", fmt.Errorf("no asset found for %s", assetName)
}

shared.Debugf("[selfupdate] CheckLatest: success latest=%s downloadURL=%s", tagName, downloadURL)
return tagName, downloadURL, checksumsURL, nil
shared.Debugf("[selfupdate] CheckLatest: success latest=%s downloadURL=%s", release.TagName, downloadURL)
return release.TagName, downloadURL, checksumsURL, nil
}

// Apply downloads the binary from downloadURL, optionally verifies its checksum
// using checksumsURL, and replaces the current executable.
func Apply(downloadURL, checksumsURL string) error {
func Apply(ctx context.Context, downloadURL, checksumsURL string) error {
shared.Debugf("[selfupdate] Apply: start downloadURL=%s", downloadURL)
exePath, err := os.Executable()
if err != nil {
Expand All @@ -96,7 +119,12 @@ func Apply(downloadURL, checksumsURL string) error {
}()

shared.Debugf("[selfupdate] Apply: downloading binary")
resp, err := http.Get(downloadURL)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil)
if err != nil {
shared.Debugf("[selfupdate] Apply: error creating request: %v", err)
return fmt.Errorf("creating download request: %w", err)
}
resp, err := downloadClient.Do(req)
if err != nil {
shared.Debugf("[selfupdate] Apply: error downloading: %v", err)
return fmt.Errorf("downloading binary: %w", err)
Expand All @@ -119,7 +147,7 @@ func Apply(downloadURL, checksumsURL string) error {

if checksumsURL != "" {
shared.Debugf("[selfupdate] Apply: verifying checksum")
if err := verifyChecksum(checksumsURL, got); err != nil {
if err := verifyChecksum(ctx, checksumsURL, got); err != nil {
shared.Debugf("[selfupdate] Apply: error checksum verification: %v", err)
return err
}
Expand All @@ -140,14 +168,14 @@ func Apply(downloadURL, checksumsURL string) error {
return nil
}

func verifyChecksum(checksumsURL, gotHash string) error {
body, err := httpGet(checksumsURL)
func verifyChecksum(ctx context.Context, checksumsURL, gotHash string) error {
body, err := httpGet(ctx, checksumsURL)
if err != nil {
return fmt.Errorf("downloading checksums: %w", err)
}

assetName := fmt.Sprintf("lazystack-%s-%s", runtime.GOOS, runtime.GOARCH)
for _, line := range strings.Split(body, "\n") {
for _, line := range strings.Split(string(body), "\n") {
parts := strings.Fields(line)
if len(parts) == 2 && parts[1] == assetName {
if parts[0] != gotHash {
Expand All @@ -160,20 +188,20 @@ func verifyChecksum(checksumsURL, gotHash string) error {
return fmt.Errorf("no checksum found for %s in SHA256SUMS", assetName)
}

func httpGet(url string) (string, error) {
resp, err := http.Get(url)
func httpGet(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := httpClient.Do(req)
if err != nil {
return "", err
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d from %s", resp.StatusCode, url)
return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, url)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(b), nil
return io.ReadAll(resp.Body)
}

// isNewer returns true if latest is a higher semver than current.
Expand Down Expand Up @@ -215,63 +243,3 @@ func parseVersion(v string) []int {
}
return nums
}

// Minimal JSON helpers — avoids encoding/json for simple field extraction.

func jsonString(json, key string) string {
needle := fmt.Sprintf("%q", key)
idx := strings.Index(json, needle)
if idx < 0 {
return ""
}
rest := json[idx+len(needle):]
// skip `: `
rest = strings.TrimLeft(rest, " \t\n\r:")
if len(rest) == 0 || rest[0] != '"' {
return ""
}
rest = rest[1:]
end := strings.Index(rest, "\"")
if end < 0 {
return ""
}
return rest[:end]
}

func jsonArray(json, key string) []string {
needle := fmt.Sprintf("%q", key)
idx := strings.Index(json, needle)
if idx < 0 {
return nil
}
rest := json[idx+len(needle):]
rest = strings.TrimLeft(rest, " \t\n\r:")
if len(rest) == 0 || rest[0] != '[' {
return nil
}
rest = rest[1:]

var items []string
depth := 0
start := -1
for i := 0; i < len(rest); i++ {
switch rest[i] {
case '{':
if depth == 0 {
start = i
}
depth++
case '}':
depth--
if depth == 0 && start >= 0 {
items = append(items, rest[start:i+1])
start = -1
}
case ']':
if depth == 0 {
return items
}
}
}
return items
}
Loading