From 58b5ab55bac2814e8f3ae173c25ee0ecda9da2e3 Mon Sep 17 00:00:00 2001 From: Eric Phillips Date: Fri, 1 May 2026 23:44:37 -0600 Subject: [PATCH] refactor all file serve logic into internal/cache --- deploy/pkgstash.toml.example | 8 ++- handlerPkgs.go | 35 +++++++----- internal/cache/cache.go | 103 +++-------------------------------- internal/cache/cache_test.go | 51 ++++++++++++++--- internal/cache/fetch.go | 72 ++++++++++++++++++++++++ internal/cache/helpers.go | 57 +++++++++++++++++++ internal/cache/refresh.go | 2 +- 7 files changed, 208 insertions(+), 120 deletions(-) create mode 100644 internal/cache/fetch.go create mode 100644 internal/cache/helpers.go diff --git a/deploy/pkgstash.toml.example b/deploy/pkgstash.toml.example index 506f030..c7295d4 100644 --- a/deploy/pkgstash.toml.example +++ b/deploy/pkgstash.toml.example @@ -1,7 +1,9 @@ -cache_root = "/home/ewpt3ch/dev/pacman-cache-server/tmprepo" -mirror_urls = ["https://us.mirrors.cicku.me/archlinux/", +cache_root = "/home/ewpt3ch/dev/pkgstash/tmprepo" +mirror_urls = [ "https://losangeles.mirror.pkgbuild.com/", - "https://mirror.givebytes.net/archlinux/"] + "https://mirror.givebytes.net/archlinux/", + "https://arch.mirror.constant.com/", + ] # array of upstream repos this server caches see pacman.conf # or pacman docs for more info mirrored_repos = ["core", "extra"] diff --git a/handlerPkgs.go b/handlerPkgs.go index 67cc4c2..09f8949 100644 --- a/handlerPkgs.go +++ b/handlerPkgs.go @@ -2,10 +2,11 @@ package main import ( "errors" + "io" "log" "net/http" - "os" "path/filepath" + "strconv" "strings" "github.com/ewpt3ch/pkgstash/internal/cache" @@ -27,24 +28,28 @@ func (s *Server) handlePackage(w http.ResponseWriter, req *http.Request) { repo := req.PathValue("repo") arch := req.PathValue("arch") file := req.PathValue("file") - repoPath := filepath.Join(repo, "os", arch, file) //path from mirror root to pkg or db file - cachePath := filepath.Join(s.cfg.CacheRoot, repoPath) //absolute path for local read of the file + repoPath := filepath.Join(repo, "os", arch, file) //path from mirror root to pkg or db file - if _, err := os.Stat(cachePath); err != nil { - err = s.c.Fetch(repoPath) - if err != nil { - var upstreamErr *cache.UpstreamError - if errors.As(err, &upstreamErr) { - log.Printf("upstream error: %v", err) - http.Error(w, "Not found upstream", upstreamErr.StatusCode) - return - } - log.Printf("fetch error: %v", err) - http.Error(w, "Failed to fetch from upstream", http.StatusBadGateway) + cachedFile, err := s.c.Fetch(repoPath) + if err != nil { + var upstreamErr *cache.UpstreamError + if errors.As(err, &upstreamErr) { + log.Printf("upstream error: %v", err) + http.Error(w, "Not found upstream", upstreamErr.StatusCode) return } + log.Printf("fetch error: %v", err) + http.Error(w, "Failed to fetch from upstream", http.StatusBadGateway) + return + } + defer cachedFile.Reader.Close() + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", "attachment; filename="+cachedFile.Filename) + w.Header().Set("Content-Length", strconv.FormatInt(cachedFile.Size, 10)) + _, err = io.Copy(w, cachedFile.Reader) + if err != nil { + log.Printf("error streaming file to client: %v", err) } - http.ServeFile(w, req, cachePath) } diff --git a/internal/cache/cache.go b/internal/cache/cache.go index c650097..1c297ad 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -3,11 +3,8 @@ package cache import ( "fmt" "io" - "log" "net" "net/http" - "os" - "path/filepath" "sync" "sync/atomic" "time" @@ -15,6 +12,8 @@ import ( "golang.org/x/sync/singleflight" ) +const userAgent = "pacman/7.1.0 (Linux x86_64) libalpm/16.0.1" + type Cache struct { cfg CacheConfig mirrorIdx atomic.Uint32 @@ -27,21 +26,25 @@ type CacheConfig struct { cacheRoot string mirrorURLs []string mirroredRepos []string - userAgent string DialTimeout time.Duration ResponseHeaderTimeout time.Duration ClientTimeout time.Duration } +type CacheFile struct { + Reader io.ReadCloser + Size int64 + Filename string +} + func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Cache { cfg := CacheConfig{ cacheRoot: cacheRoot, mirrorURLs: mirrorURLs, mirroredRepos: mirroredRepos, - userAgent: "pacman/7.1.0 (Linux x86_64) libalpm/16.0.1", DialTimeout: 5 * time.Second, ResponseHeaderTimeout: 10 * time.Second, - ClientTimeout: 15 * time.Second, + ClientTimeout: 0 * time.Second, } transport := &http.Transport{ @@ -60,15 +63,6 @@ func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Ca } } -func (c *Cache) Fetch(pkgPath string) error { - log.Printf("pkgPath from Fetch %v", pkgPath) - _, err, _ := c.sf.Do(pkgPath, func() (any, error) { - log.Print("calling fetch") - return nil, c.fetch(pkgPath) - }) - return err -} - type UpstreamError struct { StatusCode int } @@ -76,82 +70,3 @@ type UpstreamError struct { func (e *UpstreamError) Error() string { return fmt.Sprintf("upstream returned %d", e.StatusCode) } - -func (c *Cache) fetch(pkgName string) error { - // pkgName is relative to the localRoot - // ie pkgName includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst - - tempPkgName := pkgName + ".tmp" - tempPkgPath := filepath.Join(c.cfg.cacheRoot, tempPkgName) //full tmp write path - - // final file name and path - outPkg := filepath.Join(c.cfg.cacheRoot, pkgName) - - // declare vars outside loop - var resp *http.Response - var req *http.Request - var err error - // fetch pkgs from mirror with retry logic - for range len(c.cfg.mirrorURLs) { - pkgURL := c.nextMirror() + pkgName - log.Printf("fetching %v", pkgURL) - - // set the user agent - req, err = http.NewRequest("GET", pkgURL, nil) - if err != nil { - log.Printf("failed to create request: %v", err) - return &UpstreamError{StatusCode: http.StatusInternalServerError} - } - req.Header.Set("User-Agent", c.cfg.userAgent) - - resp, err = c.client.Do(req) - if err != nil { - log.Printf("error fetching %s: %v", pkgURL, err) - continue - } - if resp.StatusCode == http.StatusOK { - break - } - log.Printf("retrying on code %v", resp.StatusCode) - resp.Body.Close() - } - if resp == nil { - return fmt.Errorf("all mirrors exhausted") - } - defer resp.Body.Close() - - if err != nil { - log.Printf("exhauted all mirrors error: %v", err) - return err - } - - if resp.StatusCode != http.StatusOK { - log.Printf("exhauted all mirrors %v", resp.StatusCode) - return &UpstreamError{StatusCode: resp.StatusCode} - } - - // use a tmp file for the initial fetch in case it fails - outFile, err := os.Create(tempPkgPath) - if err != nil { - return err - } - defer outFile.Close() - - _, err = io.Copy(outFile, resp.Body) - if err != nil { - os.Remove(tempPkgPath) - return err - } - - // mv file to final location - if err := os.Rename(tempPkgPath, outPkg); err != nil { - os.Remove(tempPkgPath) - return err - } - return nil -} - -func (c *Cache) nextMirror() string { - idx := c.mirrorIdx.Add(1) - 1 - return c.cfg.mirrorURLs[idx%uint32(len(c.cfg.mirrorURLs))] -} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index d635ace..38e2cfd 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "io" "net/http" "net/http/httptest" "os" @@ -27,7 +28,43 @@ func newTestCache(t *testing.T, mirrorURLs []string) *Cache { return c } -func TestFetchFileExists(t *testing.T) { +func TestCacheHit(t *testing.T) { + const expected = "This is fake file contents" + + c := newTestCache(t, []string{"http://example.com/"}) + tmpFileName := "fakeFile" + tmpPath := filepath.Join(c.cfg.cacheRoot, tmpFileName) + err := os.WriteFile(tmpPath, []byte(expected), 0644) + if err != nil { + t.Fatalf("failed to create tempfile: %v", err) + } + + cachedFile, err := c.Fetch(tmpFileName) + if err != nil { + t.Fatalf("failed to fetch file: %v", err) + } + defer cachedFile.Reader.Close() + + if cachedFile.Filename != tmpFileName { + t.Errorf("expected filename %s got %s", tmpFileName, cachedFile.Filename) + } + + if int64(len(expected)) != cachedFile.Size { + t.Errorf("expected %d got %d", len(expected), cachedFile.Size) + } + + data, err := io.ReadAll(cachedFile.Reader) + if err != nil { + t.Fatalf("error reading file back: %v", err) + } + defer cachedFile.Reader.Close() + + if !bytes.Equal(data, []byte(expected)) { + t.Errorf("expected file to contain %s got %s", expected, data) + } +} + +func TestCacheMissExists(t *testing.T) { const expected = "This is fake file contents" svr := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -36,7 +73,7 @@ func TestFetchFileExists(t *testing.T) { c := newTestCache(t, []string{svr.URL + "/"}) - err := c.Fetch("fakefile") + _, err := c.Fetch("fakefile") if err != nil { t.Fatalf("Fetch failed %v", err) } @@ -59,7 +96,7 @@ func TestFetchNotFound(t *testing.T) { c := newTestCache(t, []string{svr.URL + "/"}) - err := c.Fetch("fakefile") + _, err := c.Fetch("fakefile") var upstreamErr *UpstreamError if !errors.As(err, &upstreamErr) { t.Fatalf("expected UpstreamError got %v", err) @@ -76,7 +113,7 @@ func TestFetchSrvError(t *testing.T) { c := newTestCache(t, []string{svr.URL + "/"}) - err := c.Fetch("fakefile") + _, err := c.Fetch("fakefile") var upstreamErr *UpstreamError if !errors.As(err, &upstreamErr) { t.Fatalf("expected UpstreamError fot %v", err) @@ -99,7 +136,7 @@ func TestFetchSrvDead(t *testing.T) { c := newTestCache(t, []string{svr.URL + "/"}) - err := c.Fetch("fakefile") + _, err := c.Fetch("fakefile") if err == nil { t.Fatal("expected err got nil") } @@ -127,7 +164,7 @@ func TestFetchRetryExists(t *testing.T) { c := newTestCache(t, fakeURLs) - err := c.Fetch("fakefile") + _, err := c.Fetch("fakefile") if err != nil { t.Fatalf("fetch failed: %v", err) } @@ -159,7 +196,7 @@ func TestFetchRetryNonExist(t *testing.T) { c := newTestCache(t, fakeURLs) - err := c.Fetch("fakefile") + _, err := c.Fetch("fakefile") var upstreamErr *UpstreamError if !errors.As(err, &upstreamErr) { t.Errorf("expected UpstreamError got %v", err) diff --git a/internal/cache/fetch.go b/internal/cache/fetch.go new file mode 100644 index 0000000..9311598 --- /dev/null +++ b/internal/cache/fetch.go @@ -0,0 +1,72 @@ +package cache + +import ( + "log" + "os" + "path/filepath" +) + +func (c *Cache) Fetch(relPath string) (*CacheFile, error) { + // return file directly if exists in cache + cf, err := getCachedFile(c.cfg.cacheRoot, relPath) + if err == nil { + return cf, nil + } + + // fetch file from upstream + _, err, _ = c.sf.Do(relPath, func() (any, error) { + log.Print("calling fetch") + return nil, c.fetch(relPath) + }) + if err != nil { + return nil, err + } + + cf, err = getCachedFile(c.cfg.cacheRoot, relPath) + if err != nil { + return nil, err + } + return cf, nil +} + +func (c *Cache) fetch(relPath string) error { + // relPath is relative to the localRoot + // ie relPath includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst + + // final file name and path + destPath := filepath.Join(c.cfg.cacheRoot, relPath) + + // declare vars outside loop + var err error + // fetch pkgs from mirror with retry logic + for range len(c.cfg.mirrorURLs) { + url := c.nextMirror() + relPath + err = downloadToDisk(url, destPath, c.client) + if err == nil { + break + } + } + if err != nil { + return err + } + return nil +} + +func getCachedFile(cacheRoot, relPath string) (*CacheFile, error) { + filePath := filepath.Join(cacheRoot, relPath) + info, err := os.Stat(filePath) + if err != nil { + return nil, err + } + + f, err := os.Open(filePath) + if err != nil { + return nil, err + } + + return &CacheFile{ + Reader: f, + Size: info.Size(), + Filename: filepath.Base(filePath), + }, nil +} diff --git a/internal/cache/helpers.go b/internal/cache/helpers.go new file mode 100644 index 0000000..8e79f59 --- /dev/null +++ b/internal/cache/helpers.go @@ -0,0 +1,57 @@ +package cache + +import ( + "io" + "log" + "net/http" + "os" +) + +func (c *Cache) nextMirror() string { + idx := c.mirrorIdx.Add(1) - 1 + return c.cfg.mirrorURLs[idx%uint32(len(c.cfg.mirrorURLs))] +} + +func downloadToDisk(url, destPath string, c http.Client) error { + log.Printf("fetching %v", url) + + // set the user agent + req, err := http.NewRequest("GET", url, nil) + if err != nil { + log.Printf("failed to create request: %v", err) + return &UpstreamError{StatusCode: http.StatusInternalServerError} + } + req.Header.Set("User-Agent", userAgent) + + resp, err := c.Do(req) + if err != nil { + log.Printf("error fetching %s: %v", url, err) + return err + } + if resp.StatusCode != 200 { + log.Printf("GET %s returned %d", url, resp.StatusCode) + return &UpstreamError{StatusCode: resp.StatusCode} + } + defer resp.Body.Close() + + // use a tmp file for the initial fetch in case it fails + tempPath := destPath + ".tmp" + tmpFile, err := os.Create(tempPath) + if err != nil { + return err + } + defer tmpFile.Close() + + _, err = io.Copy(tmpFile, resp.Body) + if err != nil { + os.Remove(tempPath) + return err + } + + // mv file to final location + if err := os.Rename(tempPath, destPath); err != nil { + os.Remove(tempPath) + return err + } + return nil +} diff --git a/internal/cache/refresh.go b/internal/cache/refresh.go index 2309222..b1655e2 100644 --- a/internal/cache/refresh.go +++ b/internal/cache/refresh.go @@ -19,7 +19,7 @@ func (c *Cache) Refresh() error { func (c *Cache) refreshDB(repo string) error { dbFile := repo + ".db.tar.gz" dbPath := filepath.Join(repo, "os/x86_64", dbFile) - err := c.Fetch(dbPath) + err := c.fetch(dbPath) if err != nil { return err }