From 2ae5ee82858c7af9d1201f5e4ee6f8cc24c89266 Mon Sep 17 00:00:00 2001 From: Eric Phillips Date: Wed, 6 May 2026 22:22:41 -0600 Subject: [PATCH] use os.*Root instead of os to prevent filesystem traversal bugs and security leaks --- handler_test.go | 5 ++++- internal/cache/cache.go | 14 ++++++++++---- internal/cache/cache_test.go | 20 ++++++++------------ internal/cache/fetch.go | 19 +++++++------------ internal/cache/helpers.go | 22 +++++++++++----------- main.go | 6 +++++- 6 files changed, 45 insertions(+), 41 deletions(-) diff --git a/handler_test.go b/handler_test.go index d42dbd9..b37406b 100644 --- a/handler_test.go +++ b/handler_test.go @@ -34,7 +34,10 @@ func newTestServer(t *testing.T, mirrorHandler http.HandlerFunc) (*httptest.Serv mirror := httptest.NewServer(mirrorHandler) t.Cleanup(func() { mirror.Close() }) - c := cache.NewCache(t.TempDir(), []string{mirror.URL + "/"}, []string{"core"}) + c, err := cache.NewCache(t.TempDir(), []string{mirror.URL + "/"}, []string{"core"}) + if err != nil { + t.Fatalf("failed to create cache: %v", err) + } cfg := &Config{ Port: "0", Auth: AuthConfig{Token: "testtoken"}, diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 60b409c..9690e5f 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -5,6 +5,7 @@ import ( "io" "net" "net/http" + "os" "sync" "sync/atomic" "time" @@ -16,6 +17,7 @@ const userAgent = "pacman/7.1.0 (Linux x86_64) libalpm/16.0.1" type Cache struct { cfg CacheConfig + cr *os.Root mirrorIdx atomic.Uint64 sf singleflight.Group //prevents duplicate downloads mu sync.Mutex @@ -23,7 +25,6 @@ type Cache struct { } type CacheConfig struct { - cacheRoot string mirrorURLs []string mirroredRepos []string DialTimeout time.Duration @@ -37,9 +38,8 @@ type CacheFile struct { Filename string } -func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Cache { +func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) (*Cache, error) { cfg := CacheConfig{ - cacheRoot: cacheRoot, mirrorURLs: mirrorURLs, mirroredRepos: mirroredRepos, DialTimeout: 5 * time.Second, @@ -54,13 +54,19 @@ func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Ca ResponseHeaderTimeout: cfg.ResponseHeaderTimeout, } + cr, err := os.OpenRoot(cacheRoot) + if err != nil { + return nil, err + } + return &Cache{ cfg: cfg, + cr: cr, client: http.Client{ Timeout: cfg.ClientTimeout, Transport: transport, }, - } + }, nil } type UpstreamError struct { diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 38e2cfd..2f4f78f 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -7,8 +7,6 @@ import ( "io" "net/http" "net/http/httptest" - "os" - "path/filepath" "testing" "time" ) @@ -23,7 +21,10 @@ func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { func newTestCache(t *testing.T, mirrorURLs []string) *Cache { t.Helper() mirroredRepos := []string{"core", "extra"} - c := NewCache(t.TempDir(), mirrorURLs, mirroredRepos) + c, err := NewCache(t.TempDir(), mirrorURLs, mirroredRepos) + if err != nil { + t.Fatalf("failed to create cache: %v", err) + } c.client.Timeout = 500 * time.Millisecond return c } @@ -33,8 +34,7 @@ func TestCacheHit(t *testing.T) { c := newTestCache(t, []string{"http://example.com/"}) tmpFileName := "fakeFile" - tmpPath := filepath.Join(c.cfg.cacheRoot, tmpFileName) - err := os.WriteFile(tmpPath, []byte(expected), 0644) + err := c.cr.WriteFile(tmpFileName, []byte(expected), 0644) if err != nil { t.Fatalf("failed to create tempfile: %v", err) } @@ -78,9 +78,7 @@ func TestCacheMissExists(t *testing.T) { t.Fatalf("Fetch failed %v", err) } - fakefilepath := filepath.Join(c.cfg.cacheRoot, "fakefile") - - data, err := os.ReadFile(fakefilepath) + data, err := c.cr.ReadFile("fakefile") if err != nil { t.Fatalf("Error reading file back: %v", err) } @@ -141,8 +139,7 @@ func TestFetchSrvDead(t *testing.T) { t.Fatal("expected err got nil") } - var upstreamErr *UpstreamError - if errors.As(err, &upstreamErr) { + if _, ok := errors.AsType[*UpstreamError](err); ok { t.Error("expected network error not UpstreamError") } } @@ -169,8 +166,7 @@ func TestFetchRetryExists(t *testing.T) { t.Fatalf("fetch failed: %v", err) } - fakefilepath := filepath.Join(c.cfg.cacheRoot, "fakefile") - data, err := os.ReadFile(fakefilepath) + data, err := c.cr.ReadFile("fakefile") if err != nil { t.Fatalf("error reading file back: %v", err) } diff --git a/internal/cache/fetch.go b/internal/cache/fetch.go index 674b70c..55b6d30 100644 --- a/internal/cache/fetch.go +++ b/internal/cache/fetch.go @@ -3,13 +3,12 @@ package cache import ( "errors" "log/slog" - "os" "path/filepath" ) func (c *Cache) Fetch(relPath string) (*CacheFile, error) { // return file directly if exists in cache - cf, err := getCachedFile(c.cfg.cacheRoot, relPath) + cf, err := c.getCachedFile(relPath) if err == nil { return cf, nil } @@ -23,7 +22,7 @@ func (c *Cache) Fetch(relPath string) (*CacheFile, error) { return nil, err } - cf, err = getCachedFile(c.cfg.cacheRoot, relPath) + cf, err = c.getCachedFile(relPath) if err != nil { return nil, err } @@ -34,15 +33,12 @@ func (c *Cache) fetch(relPath string) error { // relPath is relative to the localRoot // ie relPath includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst - // final file name and path - destPath := filepath.Join(c.cfg.cacheRoot, relPath) - // declare vars outside loop var err error // fetch pkgs from mirror with retry logic for range len(c.cfg.mirrorURLs) { url := c.nextMirror() + relPath - err = downloadToDisk(url, destPath, c.client) + err = c.downloadToDisk(url, relPath) if err == nil { break } @@ -58,14 +54,13 @@ func (c *Cache) fetch(relPath string) error { return nil } -func getCachedFile(cacheRoot, relPath string) (*CacheFile, error) { - filePath := filepath.Join(cacheRoot, relPath) - info, err := os.Stat(filePath) +func (c *Cache) getCachedFile(relPath string) (*CacheFile, error) { + info, err := c.cr.Stat(relPath) if err != nil { return nil, err } - f, err := os.Open(filePath) + f, err := c.cr.Open(relPath) if err != nil { return nil, err } @@ -73,6 +68,6 @@ func getCachedFile(cacheRoot, relPath string) (*CacheFile, error) { return &CacheFile{ Reader: f, Size: info.Size(), - Filename: filepath.Base(filePath), + Filename: filepath.Base(relPath), }, nil } diff --git a/internal/cache/helpers.go b/internal/cache/helpers.go index ccd4bde..5d9c2ca 100644 --- a/internal/cache/helpers.go +++ b/internal/cache/helpers.go @@ -4,7 +4,6 @@ import ( "io" "log/slog" "net/http" - "os" "path/filepath" ) @@ -14,7 +13,7 @@ func (c *Cache) nextMirror() string { return c.cfg.mirrorURLs[idx%mirrorCount] } -func downloadToDisk(url, destPath string, c http.Client) error { +func (c *Cache) downloadToDisk(url, relPath string) error { slog.Info("fetching", "url", url) // set the user agent @@ -24,7 +23,7 @@ func downloadToDisk(url, destPath string, c http.Client) error { } req.Header.Set("User-Agent", userAgent) - resp, err := c.Do(req) + resp, err := c.client.Do(req) if err != nil { slog.Warn("fetch failed", "url", url, "err", err) return err @@ -36,14 +35,14 @@ func downloadToDisk(url, destPath string, c http.Client) error { defer resp.Body.Close() // make sure the dir structure exists - err = os.MkdirAll(filepath.Dir(destPath), 0750) + err = c.cr.MkdirAll(filepath.Dir(relPath), 0750) if err != nil { return err } // use a tmp file for the initial fetch in case it fails - tempPath := destPath + ".tmp" - tmpFile, err := os.Create(tempPath) + tmpPath := relPath + ".tmp" + tmpFile, err := c.cr.Create(tmpPath) if err != nil { return err } @@ -51,18 +50,19 @@ func downloadToDisk(url, destPath string, c http.Client) error { _, err = io.Copy(tmpFile, resp.Body) if err != nil { - removeErr := os.Remove(tempPath) + removeErr := c.cr.Remove(tmpPath) if removeErr != nil { - slog.Warn("failed to remove temp file", "path", tempPath, "err", removeErr) + slog.Warn("failed to remove temp file", "path", tmpPath, "err", removeErr) } return err } // mv file to final location - if err := os.Rename(tempPath, destPath); err != nil { - removeErr := os.Remove(tempPath) + err = c.cr.Rename(tmpPath, relPath) + if err != nil { + removeErr := c.cr.Remove(tmpPath) if removeErr != nil { - slog.Warn("failed to remove temp file", "path", tempPath, "err", removeErr) + slog.Warn("failed to remove temp file", "path", tmpPath, "err", removeErr) } return err } diff --git a/main.go b/main.go index 5ae5727..15a0e9b 100644 --- a/main.go +++ b/main.go @@ -48,7 +48,11 @@ func main() { os.Exit(1) } - c := cache.NewCache(cfg.CacheRoot, cfg.MirrorURLs, cfg.MirroredRepos) + c, err := cache.NewCache(cfg.CacheRoot, cfg.MirrorURLs, cfg.MirroredRepos) + if err != nil { + slog.Error("failed to create cache", "err", err) + os.Exit(1) + } srv := &Server{ cfg: cfg, c: c,