diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 0491dd9..c45d5c1 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -1,6 +1,7 @@ package cache import ( + "fmt" "io" "log/slog" "net" @@ -33,9 +34,11 @@ type CacheConfig struct { } type inFlight struct { - tmpPath string - done chan struct{} - err error + tmpPath string + headerReady chan struct{} + contentLength int64 + done chan struct{} + err error } type CacheFile struct { @@ -97,9 +100,21 @@ func (c *Cache) Fetch(relPath string) (*CacheFile, error) { return nil, err } + var size int64 + select { + case <-flight.headerReady: + size = flight.contentLength + err = flight.err + case <-time.After(5 * time.Second): + return nil, fmt.Errorf("upstream header timeout") + } + if err != nil { + return nil, err + } + return &CacheFile{ Reader: &tailer{f: file, flight: flight}, - Size: -1, + Size: size, Filename: filepath.Base(relPath), }, nil diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 92a985c..a56b5e0 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -2,6 +2,7 @@ package cache import ( "bytes" + "errors" "fmt" "io" "log/slog" @@ -37,6 +38,14 @@ func newTestCache(t *testing.T, mirrorURLs []string) *Cache { return c } +func newTestFlight(tmpPath string) *inFlight { + return &inFlight{ + tmpPath: tmpPath, + headerReady: make(chan struct{}), + done: make(chan struct{}), + } +} + func TestFetch(t *testing.T) { // test happy paths on fetch, the error paths all return through // the handler so need to be tested from the handler @@ -172,6 +181,7 @@ func TestGetStreamMultiplClient(t *testing.T) { } func TestDownloadWrangle(t *testing.T) { + const expected = "This is fake file contents" t.Run("Download error propagates to flight.err", func(t *testing.T) { const expected = "This is fake file contents" svr := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -182,10 +192,7 @@ func TestDownloadWrangle(t *testing.T) { c := newTestCache(t, []string{svr.URL + "/"}) relPath := "fakefile" tmpPath := "fakefile.tmp" - flight := &inFlight{ - tmpPath: tmpPath, - done: make(chan struct{}), - } + flight := newTestFlight(tmpPath) tmpFile, err := c.cr.Create(tmpPath) require.NoError(t, err, "failed open test file") @@ -200,14 +207,10 @@ func TestDownloadWrangle(t *testing.T) { }) t.Run("Network error propagates to flight.err", func(t *testing.T) { - const expected = "This is fake file contents" c := newTestCache(t, []string{"http://127.0.0.1/"}) relPath := "fakefile" tmpPath := "fakefile.tmp" - flight := &inFlight{ - tmpPath: tmpPath, - done: make(chan struct{}), - } + flight := newTestFlight(tmpPath) tmpFile, err := c.cr.Create(tmpPath) require.NoError(t, err, "failed open test file") @@ -222,7 +225,6 @@ func TestDownloadWrangle(t *testing.T) { }) t.Run("Retry works across mirror", func(t *testing.T) { - const expected = "This is fake file contents" svrMiss := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { //nolint:errcheck //ephemeral no need to check w.WriteHeader(http.StatusNotFound) @@ -239,10 +241,7 @@ func TestDownloadWrangle(t *testing.T) { }) relPath := "fakefile" tmpPath := "fakefile.tmp" - flight := &inFlight{ - tmpPath: tmpPath, - done: make(chan struct{}), - } + flight := newTestFlight(tmpPath) tmpFile, err := c.cr.Create(tmpPath) require.NoError(t, err, "failed open test file") @@ -257,7 +256,6 @@ func TestDownloadWrangle(t *testing.T) { }) t.Run("Cleanup runs on failure", func(t *testing.T) { - const expected = "This is fake file contents" svr := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { //nolint:errcheck //ephemeral no need to check w.WriteHeader(http.StatusNotFound) @@ -266,21 +264,54 @@ func TestDownloadWrangle(t *testing.T) { c := newTestCache(t, []string{svr.URL + "/"}) relPath := "fakefile" tmpPath := "fakefile.tmp" - flight := &inFlight{ - tmpPath: tmpPath, - done: make(chan struct{}), - } + flight := newTestFlight(tmpPath) tmpFile, err := c.cr.Create(tmpPath) require.NoError(t, err, "failed open test file") c.downloadWrangle(relPath, flight, tmpFile) _, err = os.Stat(tmpPath) assert.ErrorIs(t, err, os.ErrNotExist) + select { + case <-flight.headerReady: + //closed + default: + t.Error("headerReady not closes") + } + select { + case <-flight.done: + //closed + default: + t.Error("done not closed") + } c.inFlightMu.Lock() _, ok := c.inFlight[relPath] c.inFlightMu.Unlock() assert.False(t, ok, "expected inFlight entry to be removed") }) + + t.Run("Size propagates to flight", func(t *testing.T) { + svr := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + //nolint:errcheck //ephemeral no need to check + fmt.Fprintf(w, "%s", expected) + })) + c := newTestCache(t, []string{svr.URL + "/"}) + relPath := "fakefile" + tmpPath := "fakefile.tmp" + flight := newTestFlight(tmpPath) + tmpFile, err := c.cr.Create(tmpPath) + require.NoError(t, err, "failed open test file") + + c.downloadWrangle(relPath, flight, tmpFile) + var size int64 + select { + case <-flight.headerReady: + size = flight.contentLength + case <-time.After(time.Second): + t.Fatal("content-length never got set") + } + assert.Equal(t, int64(len(expected)), size) + + }) } func TestTailer(t *testing.T) { @@ -314,7 +345,7 @@ func TestTailer(t *testing.T) { require.NoError(t, err) go func() { - for _ = range 3 { + for range 3 { fmt.Fprintf(wf, "%s", expected) time.Sleep(100 * time.Millisecond) } @@ -331,7 +362,25 @@ func TestTailer(t *testing.T) { require.NoError(t, err) assert.Equal(t, []byte(strings.Repeat(expected, 3)), data) }) - // Test: blocks until done - // Test: propagates flight.err - // Test: return true EOF + + t.Run("propagate flight.err", func(t *testing.T) { + expectedErr := errors.New("upstream failed") + tmpPath := filepath.Join(t.TempDir(), filename) + + err := os.WriteFile(tmpPath, []byte{}, 0660) + require.NoError(t, err) + + f, err := os.Open(tmpPath) + require.NoError(t, err) + + flight := &inFlight{ + done: make(chan struct{}), + err: expectedErr, + } + close(flight.done) + + tr := &tailer{f: f, flight: flight} + _, err = io.ReadAll(tr) + assert.ErrorIs(t, err, expectedErr) + }) } diff --git a/internal/cache/download.go b/internal/cache/download.go index 0546cf6..471ade2 100644 --- a/internal/cache/download.go +++ b/internal/cache/download.go @@ -19,7 +19,7 @@ func (c *Cache) downloadWrangle(relPath string, flight *inFlight, tmpFile *os.Fi // fetch pkgs from mirror with retry logic for range len(c.cfg.mirrorURLs) { url := c.nextMirror() + relPath - err = c.downloadToDisk(url, tmpFile) + err = c.downloadToDisk(url, flight, tmpFile) if err == nil { break } @@ -53,7 +53,7 @@ func (c *Cache) downloadWrangle(relPath string, flight *inFlight, tmpFile *os.Fi slog.Debug("file moved to final location", "err", err) } -func (c *Cache) downloadToDisk(url string, tmpFile *os.File) error { +func (c *Cache) downloadToDisk(url string, flight *inFlight, tmpFile *os.File) error { slog.Info("fetching", "url", url) // set the user agent @@ -78,6 +78,10 @@ func (c *Cache) downloadToDisk(url string, tmpFile *os.File) error { } }() + size := resp.ContentLength + flight.contentLength = size + close(flight.headerReady) + _, err = io.Copy(tmpFile, resp.Body) if err != nil { return err @@ -92,6 +96,16 @@ func (c *Cache) cleanupFlight(key string, f *inFlight) { c.inFlightMu.Lock() delete(c.inFlight, key) c.inFlightMu.Unlock() - slog.Debug("closing done channel") - close(f.done) + slog.Debug("closing channels") + safeClose(f.headerReady) + safeClose(f.done) +} + +func safeClose(ch chan struct{}) { + select { + case <-ch: + // already closed + default: + close(ch) + } } diff --git a/internal/cache/get_stream.go b/internal/cache/get_stream.go index c13d5c4..e1812da 100644 --- a/internal/cache/get_stream.go +++ b/internal/cache/get_stream.go @@ -34,8 +34,10 @@ func (c *Cache) getStream(relPath string) (*inFlight, *os.File, error) { } flight := &inFlight{ - tmpPath: tmpPath, - done: make(chan struct{}), + contentLength: 0, + headerReady: make(chan struct{}), + tmpPath: tmpPath, + done: make(chan struct{}), } c.inFlight[relPath] = flight