use os.*Root instead of os to prevent filesystem traversal bugs and security leaks
This commit is contained in:
+4
-1
@@ -34,7 +34,10 @@ func newTestServer(t *testing.T, mirrorHandler http.HandlerFunc) (*httptest.Serv
|
|||||||
mirror := httptest.NewServer(mirrorHandler)
|
mirror := httptest.NewServer(mirrorHandler)
|
||||||
t.Cleanup(func() { mirror.Close() })
|
t.Cleanup(func() { mirror.Close() })
|
||||||
|
|
||||||
c := cache.NewCache(t.TempDir(), []string{mirror.URL + "/"}, []string{"core"})
|
c, err := cache.NewCache(t.TempDir(), []string{mirror.URL + "/"}, []string{"core"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create cache: %v", err)
|
||||||
|
}
|
||||||
cfg := &Config{
|
cfg := &Config{
|
||||||
Port: "0",
|
Port: "0",
|
||||||
Auth: AuthConfig{Token: "testtoken"},
|
Auth: AuthConfig{Token: "testtoken"},
|
||||||
|
|||||||
Vendored
+10
-4
@@ -5,6 +5,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,6 +17,7 @@ const userAgent = "pacman/7.1.0 (Linux x86_64) libalpm/16.0.1"
|
|||||||
|
|
||||||
type Cache struct {
|
type Cache struct {
|
||||||
cfg CacheConfig
|
cfg CacheConfig
|
||||||
|
cr *os.Root
|
||||||
mirrorIdx atomic.Uint64
|
mirrorIdx atomic.Uint64
|
||||||
sf singleflight.Group //prevents duplicate downloads
|
sf singleflight.Group //prevents duplicate downloads
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -23,7 +25,6 @@ type Cache struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CacheConfig struct {
|
type CacheConfig struct {
|
||||||
cacheRoot string
|
|
||||||
mirrorURLs []string
|
mirrorURLs []string
|
||||||
mirroredRepos []string
|
mirroredRepos []string
|
||||||
DialTimeout time.Duration
|
DialTimeout time.Duration
|
||||||
@@ -37,9 +38,8 @@ type CacheFile struct {
|
|||||||
Filename string
|
Filename string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Cache {
|
func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) (*Cache, error) {
|
||||||
cfg := CacheConfig{
|
cfg := CacheConfig{
|
||||||
cacheRoot: cacheRoot,
|
|
||||||
mirrorURLs: mirrorURLs,
|
mirrorURLs: mirrorURLs,
|
||||||
mirroredRepos: mirroredRepos,
|
mirroredRepos: mirroredRepos,
|
||||||
DialTimeout: 5 * time.Second,
|
DialTimeout: 5 * time.Second,
|
||||||
@@ -54,13 +54,19 @@ func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Ca
|
|||||||
ResponseHeaderTimeout: cfg.ResponseHeaderTimeout,
|
ResponseHeaderTimeout: cfg.ResponseHeaderTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cr, err := os.OpenRoot(cacheRoot)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &Cache{
|
return &Cache{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
cr: cr,
|
||||||
client: http.Client{
|
client: http.Client{
|
||||||
Timeout: cfg.ClientTimeout,
|
Timeout: cfg.ClientTimeout,
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
},
|
},
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpstreamError struct {
|
type UpstreamError struct {
|
||||||
|
|||||||
Vendored
+8
-12
@@ -7,8 +7,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -23,7 +21,10 @@ func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
|
|||||||
func newTestCache(t *testing.T, mirrorURLs []string) *Cache {
|
func newTestCache(t *testing.T, mirrorURLs []string) *Cache {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
mirroredRepos := []string{"core", "extra"}
|
mirroredRepos := []string{"core", "extra"}
|
||||||
c := NewCache(t.TempDir(), mirrorURLs, mirroredRepos)
|
c, err := NewCache(t.TempDir(), mirrorURLs, mirroredRepos)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create cache: %v", err)
|
||||||
|
}
|
||||||
c.client.Timeout = 500 * time.Millisecond
|
c.client.Timeout = 500 * time.Millisecond
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
@@ -33,8 +34,7 @@ func TestCacheHit(t *testing.T) {
|
|||||||
|
|
||||||
c := newTestCache(t, []string{"http://example.com/"})
|
c := newTestCache(t, []string{"http://example.com/"})
|
||||||
tmpFileName := "fakeFile"
|
tmpFileName := "fakeFile"
|
||||||
tmpPath := filepath.Join(c.cfg.cacheRoot, tmpFileName)
|
err := c.cr.WriteFile(tmpFileName, []byte(expected), 0644)
|
||||||
err := os.WriteFile(tmpPath, []byte(expected), 0644)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create tempfile: %v", err)
|
t.Fatalf("failed to create tempfile: %v", err)
|
||||||
}
|
}
|
||||||
@@ -78,9 +78,7 @@ func TestCacheMissExists(t *testing.T) {
|
|||||||
t.Fatalf("Fetch failed %v", err)
|
t.Fatalf("Fetch failed %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fakefilepath := filepath.Join(c.cfg.cacheRoot, "fakefile")
|
data, err := c.cr.ReadFile("fakefile")
|
||||||
|
|
||||||
data, err := os.ReadFile(fakefilepath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error reading file back: %v", err)
|
t.Fatalf("Error reading file back: %v", err)
|
||||||
}
|
}
|
||||||
@@ -141,8 +139,7 @@ func TestFetchSrvDead(t *testing.T) {
|
|||||||
t.Fatal("expected err got nil")
|
t.Fatal("expected err got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
var upstreamErr *UpstreamError
|
if _, ok := errors.AsType[*UpstreamError](err); ok {
|
||||||
if errors.As(err, &upstreamErr) {
|
|
||||||
t.Error("expected network error not UpstreamError")
|
t.Error("expected network error not UpstreamError")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -169,8 +166,7 @@ func TestFetchRetryExists(t *testing.T) {
|
|||||||
t.Fatalf("fetch failed: %v", err)
|
t.Fatalf("fetch failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fakefilepath := filepath.Join(c.cfg.cacheRoot, "fakefile")
|
data, err := c.cr.ReadFile("fakefile")
|
||||||
data, err := os.ReadFile(fakefilepath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error reading file back: %v", err)
|
t.Fatalf("error reading file back: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Vendored
+7
-12
@@ -3,13 +3,12 @@ package cache
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Cache) Fetch(relPath string) (*CacheFile, error) {
|
func (c *Cache) Fetch(relPath string) (*CacheFile, error) {
|
||||||
// return file directly if exists in cache
|
// return file directly if exists in cache
|
||||||
cf, err := getCachedFile(c.cfg.cacheRoot, relPath)
|
cf, err := c.getCachedFile(relPath)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return cf, nil
|
return cf, nil
|
||||||
}
|
}
|
||||||
@@ -23,7 +22,7 @@ func (c *Cache) Fetch(relPath string) (*CacheFile, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cf, err = getCachedFile(c.cfg.cacheRoot, relPath)
|
cf, err = c.getCachedFile(relPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -34,15 +33,12 @@ func (c *Cache) fetch(relPath string) error {
|
|||||||
// relPath is relative to the localRoot
|
// relPath is relative to the localRoot
|
||||||
// ie relPath includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst
|
// ie relPath includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst
|
||||||
|
|
||||||
// final file name and path
|
|
||||||
destPath := filepath.Join(c.cfg.cacheRoot, relPath)
|
|
||||||
|
|
||||||
// declare vars outside loop
|
// declare vars outside loop
|
||||||
var err error
|
var err error
|
||||||
// fetch pkgs from mirror with retry logic
|
// fetch pkgs from mirror with retry logic
|
||||||
for range len(c.cfg.mirrorURLs) {
|
for range len(c.cfg.mirrorURLs) {
|
||||||
url := c.nextMirror() + relPath
|
url := c.nextMirror() + relPath
|
||||||
err = downloadToDisk(url, destPath, c.client)
|
err = c.downloadToDisk(url, relPath)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -58,14 +54,13 @@ func (c *Cache) fetch(relPath string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCachedFile(cacheRoot, relPath string) (*CacheFile, error) {
|
func (c *Cache) getCachedFile(relPath string) (*CacheFile, error) {
|
||||||
filePath := filepath.Join(cacheRoot, relPath)
|
info, err := c.cr.Stat(relPath)
|
||||||
info, err := os.Stat(filePath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := os.Open(filePath)
|
f, err := c.cr.Open(relPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -73,6 +68,6 @@ func getCachedFile(cacheRoot, relPath string) (*CacheFile, error) {
|
|||||||
return &CacheFile{
|
return &CacheFile{
|
||||||
Reader: f,
|
Reader: f,
|
||||||
Size: info.Size(),
|
Size: info.Size(),
|
||||||
Filename: filepath.Base(filePath),
|
Filename: filepath.Base(relPath),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Vendored
+11
-11
@@ -4,7 +4,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,7 +13,7 @@ func (c *Cache) nextMirror() string {
|
|||||||
return c.cfg.mirrorURLs[idx%mirrorCount]
|
return c.cfg.mirrorURLs[idx%mirrorCount]
|
||||||
}
|
}
|
||||||
|
|
||||||
func downloadToDisk(url, destPath string, c http.Client) error {
|
func (c *Cache) downloadToDisk(url, relPath string) error {
|
||||||
slog.Info("fetching", "url", url)
|
slog.Info("fetching", "url", url)
|
||||||
|
|
||||||
// set the user agent
|
// set the user agent
|
||||||
@@ -24,7 +23,7 @@ func downloadToDisk(url, destPath string, c http.Client) error {
|
|||||||
}
|
}
|
||||||
req.Header.Set("User-Agent", userAgent)
|
req.Header.Set("User-Agent", userAgent)
|
||||||
|
|
||||||
resp, err := c.Do(req)
|
resp, err := c.client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("fetch failed", "url", url, "err", err)
|
slog.Warn("fetch failed", "url", url, "err", err)
|
||||||
return err
|
return err
|
||||||
@@ -36,14 +35,14 @@ func downloadToDisk(url, destPath string, c http.Client) error {
|
|||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// make sure the dir structure exists
|
// make sure the dir structure exists
|
||||||
err = os.MkdirAll(filepath.Dir(destPath), 0750)
|
err = c.cr.MkdirAll(filepath.Dir(relPath), 0750)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// use a tmp file for the initial fetch in case it fails
|
// use a tmp file for the initial fetch in case it fails
|
||||||
tempPath := destPath + ".tmp"
|
tmpPath := relPath + ".tmp"
|
||||||
tmpFile, err := os.Create(tempPath)
|
tmpFile, err := c.cr.Create(tmpPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -51,18 +50,19 @@ func downloadToDisk(url, destPath string, c http.Client) error {
|
|||||||
|
|
||||||
_, err = io.Copy(tmpFile, resp.Body)
|
_, err = io.Copy(tmpFile, resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
removeErr := os.Remove(tempPath)
|
removeErr := c.cr.Remove(tmpPath)
|
||||||
if removeErr != nil {
|
if removeErr != nil {
|
||||||
slog.Warn("failed to remove temp file", "path", tempPath, "err", removeErr)
|
slog.Warn("failed to remove temp file", "path", tmpPath, "err", removeErr)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// mv file to final location
|
// mv file to final location
|
||||||
if err := os.Rename(tempPath, destPath); err != nil {
|
err = c.cr.Rename(tmpPath, relPath)
|
||||||
removeErr := os.Remove(tempPath)
|
if err != nil {
|
||||||
|
removeErr := c.cr.Remove(tmpPath)
|
||||||
if removeErr != nil {
|
if removeErr != nil {
|
||||||
slog.Warn("failed to remove temp file", "path", tempPath, "err", removeErr)
|
slog.Warn("failed to remove temp file", "path", tmpPath, "err", removeErr)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,7 +48,11 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
c := cache.NewCache(cfg.CacheRoot, cfg.MirrorURLs, cfg.MirroredRepos)
|
c, err := cache.NewCache(cfg.CacheRoot, cfg.MirrorURLs, cfg.MirroredRepos)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to create cache", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
c: c,
|
c: c,
|
||||||
|
|||||||
Reference in New Issue
Block a user