use os.*Root instead of os to prevent filesystem traversal bugs and security leaks
This commit is contained in:
Vendored
+10
-4
@@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -16,6 +17,7 @@ const userAgent = "pacman/7.1.0 (Linux x86_64) libalpm/16.0.1"
|
||||
|
||||
type Cache struct {
|
||||
cfg CacheConfig
|
||||
cr *os.Root
|
||||
mirrorIdx atomic.Uint64
|
||||
sf singleflight.Group //prevents duplicate downloads
|
||||
mu sync.Mutex
|
||||
@@ -23,7 +25,6 @@ type Cache struct {
|
||||
}
|
||||
|
||||
type CacheConfig struct {
|
||||
cacheRoot string
|
||||
mirrorURLs []string
|
||||
mirroredRepos []string
|
||||
DialTimeout time.Duration
|
||||
@@ -37,9 +38,8 @@ type CacheFile struct {
|
||||
Filename string
|
||||
}
|
||||
|
||||
func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Cache {
|
||||
func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) (*Cache, error) {
|
||||
cfg := CacheConfig{
|
||||
cacheRoot: cacheRoot,
|
||||
mirrorURLs: mirrorURLs,
|
||||
mirroredRepos: mirroredRepos,
|
||||
DialTimeout: 5 * time.Second,
|
||||
@@ -54,13 +54,19 @@ func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Ca
|
||||
ResponseHeaderTimeout: cfg.ResponseHeaderTimeout,
|
||||
}
|
||||
|
||||
cr, err := os.OpenRoot(cacheRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Cache{
|
||||
cfg: cfg,
|
||||
cr: cr,
|
||||
client: http.Client{
|
||||
Timeout: cfg.ClientTimeout,
|
||||
Transport: transport,
|
||||
},
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
type UpstreamError struct {
|
||||
|
||||
Vendored
+8
-12
@@ -7,8 +7,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -23,7 +21,10 @@ func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
|
||||
func newTestCache(t *testing.T, mirrorURLs []string) *Cache {
|
||||
t.Helper()
|
||||
mirroredRepos := []string{"core", "extra"}
|
||||
c := NewCache(t.TempDir(), mirrorURLs, mirroredRepos)
|
||||
c, err := NewCache(t.TempDir(), mirrorURLs, mirroredRepos)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create cache: %v", err)
|
||||
}
|
||||
c.client.Timeout = 500 * time.Millisecond
|
||||
return c
|
||||
}
|
||||
@@ -33,8 +34,7 @@ func TestCacheHit(t *testing.T) {
|
||||
|
||||
c := newTestCache(t, []string{"http://example.com/"})
|
||||
tmpFileName := "fakeFile"
|
||||
tmpPath := filepath.Join(c.cfg.cacheRoot, tmpFileName)
|
||||
err := os.WriteFile(tmpPath, []byte(expected), 0644)
|
||||
err := c.cr.WriteFile(tmpFileName, []byte(expected), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create tempfile: %v", err)
|
||||
}
|
||||
@@ -78,9 +78,7 @@ func TestCacheMissExists(t *testing.T) {
|
||||
t.Fatalf("Fetch failed %v", err)
|
||||
}
|
||||
|
||||
fakefilepath := filepath.Join(c.cfg.cacheRoot, "fakefile")
|
||||
|
||||
data, err := os.ReadFile(fakefilepath)
|
||||
data, err := c.cr.ReadFile("fakefile")
|
||||
if err != nil {
|
||||
t.Fatalf("Error reading file back: %v", err)
|
||||
}
|
||||
@@ -141,8 +139,7 @@ func TestFetchSrvDead(t *testing.T) {
|
||||
t.Fatal("expected err got nil")
|
||||
}
|
||||
|
||||
var upstreamErr *UpstreamError
|
||||
if errors.As(err, &upstreamErr) {
|
||||
if _, ok := errors.AsType[*UpstreamError](err); ok {
|
||||
t.Error("expected network error not UpstreamError")
|
||||
}
|
||||
}
|
||||
@@ -169,8 +166,7 @@ func TestFetchRetryExists(t *testing.T) {
|
||||
t.Fatalf("fetch failed: %v", err)
|
||||
}
|
||||
|
||||
fakefilepath := filepath.Join(c.cfg.cacheRoot, "fakefile")
|
||||
data, err := os.ReadFile(fakefilepath)
|
||||
data, err := c.cr.ReadFile("fakefile")
|
||||
if err != nil {
|
||||
t.Fatalf("error reading file back: %v", err)
|
||||
}
|
||||
|
||||
Vendored
+7
-12
@@ -3,13 +3,12 @@ package cache
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func (c *Cache) Fetch(relPath string) (*CacheFile, error) {
|
||||
// return file directly if exists in cache
|
||||
cf, err := getCachedFile(c.cfg.cacheRoot, relPath)
|
||||
cf, err := c.getCachedFile(relPath)
|
||||
if err == nil {
|
||||
return cf, nil
|
||||
}
|
||||
@@ -23,7 +22,7 @@ func (c *Cache) Fetch(relPath string) (*CacheFile, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cf, err = getCachedFile(c.cfg.cacheRoot, relPath)
|
||||
cf, err = c.getCachedFile(relPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -34,15 +33,12 @@ func (c *Cache) fetch(relPath string) error {
|
||||
// relPath is relative to the localRoot
|
||||
// ie relPath includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst
|
||||
|
||||
// final file name and path
|
||||
destPath := filepath.Join(c.cfg.cacheRoot, relPath)
|
||||
|
||||
// declare vars outside loop
|
||||
var err error
|
||||
// fetch pkgs from mirror with retry logic
|
||||
for range len(c.cfg.mirrorURLs) {
|
||||
url := c.nextMirror() + relPath
|
||||
err = downloadToDisk(url, destPath, c.client)
|
||||
err = c.downloadToDisk(url, relPath)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
@@ -58,14 +54,13 @@ func (c *Cache) fetch(relPath string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getCachedFile(cacheRoot, relPath string) (*CacheFile, error) {
|
||||
filePath := filepath.Join(cacheRoot, relPath)
|
||||
info, err := os.Stat(filePath)
|
||||
func (c *Cache) getCachedFile(relPath string) (*CacheFile, error) {
|
||||
info, err := c.cr.Stat(relPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := os.Open(filePath)
|
||||
f, err := c.cr.Open(relPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -73,6 +68,6 @@ func getCachedFile(cacheRoot, relPath string) (*CacheFile, error) {
|
||||
return &CacheFile{
|
||||
Reader: f,
|
||||
Size: info.Size(),
|
||||
Filename: filepath.Base(filePath),
|
||||
Filename: filepath.Base(relPath),
|
||||
}, nil
|
||||
}
|
||||
|
||||
Vendored
+11
-11
@@ -4,7 +4,6 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
@@ -14,7 +13,7 @@ func (c *Cache) nextMirror() string {
|
||||
return c.cfg.mirrorURLs[idx%mirrorCount]
|
||||
}
|
||||
|
||||
func downloadToDisk(url, destPath string, c http.Client) error {
|
||||
func (c *Cache) downloadToDisk(url, relPath string) error {
|
||||
slog.Info("fetching", "url", url)
|
||||
|
||||
// set the user agent
|
||||
@@ -24,7 +23,7 @@ func downloadToDisk(url, destPath string, c http.Client) error {
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := c.Do(req)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
slog.Warn("fetch failed", "url", url, "err", err)
|
||||
return err
|
||||
@@ -36,14 +35,14 @@ func downloadToDisk(url, destPath string, c http.Client) error {
|
||||
defer resp.Body.Close()
|
||||
|
||||
// make sure the dir structure exists
|
||||
err = os.MkdirAll(filepath.Dir(destPath), 0750)
|
||||
err = c.cr.MkdirAll(filepath.Dir(relPath), 0750)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// use a tmp file for the initial fetch in case it fails
|
||||
tempPath := destPath + ".tmp"
|
||||
tmpFile, err := os.Create(tempPath)
|
||||
tmpPath := relPath + ".tmp"
|
||||
tmpFile, err := c.cr.Create(tmpPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -51,18 +50,19 @@ func downloadToDisk(url, destPath string, c http.Client) error {
|
||||
|
||||
_, err = io.Copy(tmpFile, resp.Body)
|
||||
if err != nil {
|
||||
removeErr := os.Remove(tempPath)
|
||||
removeErr := c.cr.Remove(tmpPath)
|
||||
if removeErr != nil {
|
||||
slog.Warn("failed to remove temp file", "path", tempPath, "err", removeErr)
|
||||
slog.Warn("failed to remove temp file", "path", tmpPath, "err", removeErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// mv file to final location
|
||||
if err := os.Rename(tempPath, destPath); err != nil {
|
||||
removeErr := os.Remove(tempPath)
|
||||
err = c.cr.Rename(tmpPath, relPath)
|
||||
if err != nil {
|
||||
removeErr := c.cr.Remove(tmpPath)
|
||||
if removeErr != nil {
|
||||
slog.Warn("failed to remove temp file", "path", tempPath, "err", removeErr)
|
||||
slog.Warn("failed to remove temp file", "path", tmpPath, "err", removeErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user