use os.*Root instead of os to prevent filesystem traversal bugs and security leaks

This commit is contained in:
2026-05-06 22:22:41 -06:00
parent 3505f0e059
commit 2ae5ee8285
6 changed files with 45 additions and 41 deletions
+4 -1
View File
@@ -34,7 +34,10 @@ func newTestServer(t *testing.T, mirrorHandler http.HandlerFunc) (*httptest.Serv
mirror := httptest.NewServer(mirrorHandler)
t.Cleanup(func() { mirror.Close() })
c := cache.NewCache(t.TempDir(), []string{mirror.URL + "/"}, []string{"core"})
c, err := cache.NewCache(t.TempDir(), []string{mirror.URL + "/"}, []string{"core"})
if err != nil {
t.Fatalf("failed to create cache: %v", err)
}
cfg := &Config{
Port: "0",
Auth: AuthConfig{Token: "testtoken"},
+10 -4
View File
@@ -5,6 +5,7 @@ import (
"io"
"net"
"net/http"
"os"
"sync"
"sync/atomic"
"time"
@@ -16,6 +17,7 @@ const userAgent = "pacman/7.1.0 (Linux x86_64) libalpm/16.0.1"
type Cache struct {
cfg CacheConfig
cr *os.Root
mirrorIdx atomic.Uint64
sf singleflight.Group //prevents duplicate downloads
mu sync.Mutex
@@ -23,7 +25,6 @@ type Cache struct {
}
type CacheConfig struct {
cacheRoot string
mirrorURLs []string
mirroredRepos []string
DialTimeout time.Duration
@@ -37,9 +38,8 @@ type CacheFile struct {
Filename string
}
func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Cache {
func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) (*Cache, error) {
cfg := CacheConfig{
cacheRoot: cacheRoot,
mirrorURLs: mirrorURLs,
mirroredRepos: mirroredRepos,
DialTimeout: 5 * time.Second,
@@ -54,13 +54,19 @@ func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Ca
ResponseHeaderTimeout: cfg.ResponseHeaderTimeout,
}
cr, err := os.OpenRoot(cacheRoot)
if err != nil {
return nil, err
}
return &Cache{
cfg: cfg,
cr: cr,
client: http.Client{
Timeout: cfg.ClientTimeout,
Transport: transport,
},
}
}, nil
}
type UpstreamError struct {
+8 -12
View File
@@ -7,8 +7,6 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
)
@@ -23,7 +21,10 @@ func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
func newTestCache(t *testing.T, mirrorURLs []string) *Cache {
t.Helper()
mirroredRepos := []string{"core", "extra"}
c := NewCache(t.TempDir(), mirrorURLs, mirroredRepos)
c, err := NewCache(t.TempDir(), mirrorURLs, mirroredRepos)
if err != nil {
t.Fatalf("failed to create cache: %v", err)
}
c.client.Timeout = 500 * time.Millisecond
return c
}
@@ -33,8 +34,7 @@ func TestCacheHit(t *testing.T) {
c := newTestCache(t, []string{"http://example.com/"})
tmpFileName := "fakeFile"
tmpPath := filepath.Join(c.cfg.cacheRoot, tmpFileName)
err := os.WriteFile(tmpPath, []byte(expected), 0644)
err := c.cr.WriteFile(tmpFileName, []byte(expected), 0644)
if err != nil {
t.Fatalf("failed to create tempfile: %v", err)
}
@@ -78,9 +78,7 @@ func TestCacheMissExists(t *testing.T) {
t.Fatalf("Fetch failed %v", err)
}
fakefilepath := filepath.Join(c.cfg.cacheRoot, "fakefile")
data, err := os.ReadFile(fakefilepath)
data, err := c.cr.ReadFile("fakefile")
if err != nil {
t.Fatalf("Error reading file back: %v", err)
}
@@ -141,8 +139,7 @@ func TestFetchSrvDead(t *testing.T) {
t.Fatal("expected err got nil")
}
var upstreamErr *UpstreamError
if errors.As(err, &upstreamErr) {
if _, ok := errors.AsType[*UpstreamError](err); ok {
t.Error("expected network error not UpstreamError")
}
}
@@ -169,8 +166,7 @@ func TestFetchRetryExists(t *testing.T) {
t.Fatalf("fetch failed: %v", err)
}
fakefilepath := filepath.Join(c.cfg.cacheRoot, "fakefile")
data, err := os.ReadFile(fakefilepath)
data, err := c.cr.ReadFile("fakefile")
if err != nil {
t.Fatalf("error reading file back: %v", err)
}
+7 -12
View File
@@ -3,13 +3,12 @@ package cache
import (
"errors"
"log/slog"
"os"
"path/filepath"
)
func (c *Cache) Fetch(relPath string) (*CacheFile, error) {
// return file directly if exists in cache
cf, err := getCachedFile(c.cfg.cacheRoot, relPath)
cf, err := c.getCachedFile(relPath)
if err == nil {
return cf, nil
}
@@ -23,7 +22,7 @@ func (c *Cache) Fetch(relPath string) (*CacheFile, error) {
return nil, err
}
cf, err = getCachedFile(c.cfg.cacheRoot, relPath)
cf, err = c.getCachedFile(relPath)
if err != nil {
return nil, err
}
@@ -34,15 +33,12 @@ func (c *Cache) fetch(relPath string) error {
// relPath is relative to the localRoot
// ie relPath includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst
// final file name and path
destPath := filepath.Join(c.cfg.cacheRoot, relPath)
// declare vars outside loop
var err error
// fetch pkgs from mirror with retry logic
for range len(c.cfg.mirrorURLs) {
url := c.nextMirror() + relPath
err = downloadToDisk(url, destPath, c.client)
err = c.downloadToDisk(url, relPath)
if err == nil {
break
}
@@ -58,14 +54,13 @@ func (c *Cache) fetch(relPath string) error {
return nil
}
func getCachedFile(cacheRoot, relPath string) (*CacheFile, error) {
filePath := filepath.Join(cacheRoot, relPath)
info, err := os.Stat(filePath)
func (c *Cache) getCachedFile(relPath string) (*CacheFile, error) {
info, err := c.cr.Stat(relPath)
if err != nil {
return nil, err
}
f, err := os.Open(filePath)
f, err := c.cr.Open(relPath)
if err != nil {
return nil, err
}
@@ -73,6 +68,6 @@ func getCachedFile(cacheRoot, relPath string) (*CacheFile, error) {
return &CacheFile{
Reader: f,
Size: info.Size(),
Filename: filepath.Base(filePath),
Filename: filepath.Base(relPath),
}, nil
}
+11 -11
View File
@@ -4,7 +4,6 @@ import (
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
)
@@ -14,7 +13,7 @@ func (c *Cache) nextMirror() string {
return c.cfg.mirrorURLs[idx%mirrorCount]
}
func downloadToDisk(url, destPath string, c http.Client) error {
func (c *Cache) downloadToDisk(url, relPath string) error {
slog.Info("fetching", "url", url)
// set the user agent
@@ -24,7 +23,7 @@ func downloadToDisk(url, destPath string, c http.Client) error {
}
req.Header.Set("User-Agent", userAgent)
resp, err := c.Do(req)
resp, err := c.client.Do(req)
if err != nil {
slog.Warn("fetch failed", "url", url, "err", err)
return err
@@ -36,14 +35,14 @@ func downloadToDisk(url, destPath string, c http.Client) error {
defer resp.Body.Close()
// make sure the dir structure exists
err = os.MkdirAll(filepath.Dir(destPath), 0750)
err = c.cr.MkdirAll(filepath.Dir(relPath), 0750)
if err != nil {
return err
}
// use a tmp file for the initial fetch in case it fails
tempPath := destPath + ".tmp"
tmpFile, err := os.Create(tempPath)
tmpPath := relPath + ".tmp"
tmpFile, err := c.cr.Create(tmpPath)
if err != nil {
return err
}
@@ -51,18 +50,19 @@ func downloadToDisk(url, destPath string, c http.Client) error {
_, err = io.Copy(tmpFile, resp.Body)
if err != nil {
removeErr := os.Remove(tempPath)
removeErr := c.cr.Remove(tmpPath)
if removeErr != nil {
slog.Warn("failed to remove temp file", "path", tempPath, "err", removeErr)
slog.Warn("failed to remove temp file", "path", tmpPath, "err", removeErr)
}
return err
}
// mv file to final location
if err := os.Rename(tempPath, destPath); err != nil {
removeErr := os.Remove(tempPath)
err = c.cr.Rename(tmpPath, relPath)
if err != nil {
removeErr := c.cr.Remove(tmpPath)
if removeErr != nil {
slog.Warn("failed to remove temp file", "path", tempPath, "err", removeErr)
slog.Warn("failed to remove temp file", "path", tmpPath, "err", removeErr)
}
return err
}
+5 -1
View File
@@ -48,7 +48,11 @@ func main() {
os.Exit(1)
}
c := cache.NewCache(cfg.CacheRoot, cfg.MirrorURLs, cfg.MirroredRepos)
c, err := cache.NewCache(cfg.CacheRoot, cfg.MirrorURLs, cfg.MirroredRepos)
if err != nil {
slog.Error("failed to create cache", "err", err)
os.Exit(1)
}
srv := &Server{
cfg: cfg,
c: c,