refactor all file serve logic into internal/cache

This commit is contained in:
2026-05-01 23:44:37 -06:00
parent 6a6006483f
commit 58b5ab55ba
7 changed files with 208 additions and 120 deletions
+5 -3
View File
@@ -1,7 +1,9 @@
cache_root = "/home/ewpt3ch/dev/pacman-cache-server/tmprepo"
mirror_urls = ["https://us.mirrors.cicku.me/archlinux/",
cache_root = "/home/ewpt3ch/dev/pkgstash/tmprepo"
mirror_urls = [
"https://losangeles.mirror.pkgbuild.com/",
"https://mirror.givebytes.net/archlinux/"]
"https://mirror.givebytes.net/archlinux/",
"https://arch.mirror.constant.com/",
]
# array of upstream repos this server caches see pacman.conf
# or pacman docs for more info <core, extra, multilib>
mirrored_repos = ["core", "extra"]
+10 -5
View File
@@ -2,10 +2,11 @@ package main
import (
"errors"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/ewpt3ch/pkgstash/internal/cache"
@@ -28,10 +29,8 @@ func (s *Server) handlePackage(w http.ResponseWriter, req *http.Request) {
arch := req.PathValue("arch")
file := req.PathValue("file")
repoPath := filepath.Join(repo, "os", arch, file) //path from mirror root to pkg or db file
cachePath := filepath.Join(s.cfg.CacheRoot, repoPath) //absolute path for local read of the file
if _, err := os.Stat(cachePath); err != nil {
err = s.c.Fetch(repoPath)
cachedFile, err := s.c.Fetch(repoPath)
if err != nil {
var upstreamErr *cache.UpstreamError
if errors.As(err, &upstreamErr) {
@@ -43,8 +42,14 @@ func (s *Server) handlePackage(w http.ResponseWriter, req *http.Request) {
http.Error(w, "Failed to fetch from upstream", http.StatusBadGateway)
return
}
defer cachedFile.Reader.Close()
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Disposition", "attachment; filename="+cachedFile.Filename)
w.Header().Set("Content-Length", strconv.FormatInt(cachedFile.Size, 10))
_, err = io.Copy(w, cachedFile.Reader)
if err != nil {
log.Printf("error streaming file to client: %v", err)
}
http.ServeFile(w, req, cachePath)
}
+9 -94
View File
@@ -3,11 +3,8 @@ package cache
import (
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"path/filepath"
"sync"
"sync/atomic"
"time"
@@ -15,6 +12,8 @@ import (
"golang.org/x/sync/singleflight"
)
const userAgent = "pacman/7.1.0 (Linux x86_64) libalpm/16.0.1"
type Cache struct {
cfg CacheConfig
mirrorIdx atomic.Uint32
@@ -27,21 +26,25 @@ type CacheConfig struct {
cacheRoot string
mirrorURLs []string
mirroredRepos []string
userAgent string
DialTimeout time.Duration
ResponseHeaderTimeout time.Duration
ClientTimeout time.Duration
}
type CacheFile struct {
Reader io.ReadCloser
Size int64
Filename string
}
func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Cache {
cfg := CacheConfig{
cacheRoot: cacheRoot,
mirrorURLs: mirrorURLs,
mirroredRepos: mirroredRepos,
userAgent: "pacman/7.1.0 (Linux x86_64) libalpm/16.0.1",
DialTimeout: 5 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
ClientTimeout: 15 * time.Second,
ClientTimeout: 0 * time.Second,
}
transport := &http.Transport{
@@ -60,15 +63,6 @@ func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Ca
}
}
func (c *Cache) Fetch(pkgPath string) error {
log.Printf("pkgPath from Fetch %v", pkgPath)
_, err, _ := c.sf.Do(pkgPath, func() (any, error) {
log.Print("calling fetch")
return nil, c.fetch(pkgPath)
})
return err
}
type UpstreamError struct {
StatusCode int
}
@@ -76,82 +70,3 @@ type UpstreamError struct {
func (e *UpstreamError) Error() string {
return fmt.Sprintf("upstream returned %d", e.StatusCode)
}
func (c *Cache) fetch(pkgName string) error {
// pkgName is relative to the localRoot
// ie pkgName includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst
tempPkgName := pkgName + ".tmp"
tempPkgPath := filepath.Join(c.cfg.cacheRoot, tempPkgName) //full tmp write path
// final file name and path
outPkg := filepath.Join(c.cfg.cacheRoot, pkgName)
// declare vars outside loop
var resp *http.Response
var req *http.Request
var err error
// fetch pkgs from mirror with retry logic
for range len(c.cfg.mirrorURLs) {
pkgURL := c.nextMirror() + pkgName
log.Printf("fetching %v", pkgURL)
// set the user agent
req, err = http.NewRequest("GET", pkgURL, nil)
if err != nil {
log.Printf("failed to create request: %v", err)
return &UpstreamError{StatusCode: http.StatusInternalServerError}
}
req.Header.Set("User-Agent", c.cfg.userAgent)
resp, err = c.client.Do(req)
if err != nil {
log.Printf("error fetching %s: %v", pkgURL, err)
continue
}
if resp.StatusCode == http.StatusOK {
break
}
log.Printf("retrying on code %v", resp.StatusCode)
resp.Body.Close()
}
if resp == nil {
return fmt.Errorf("all mirrors exhausted")
}
defer resp.Body.Close()
if err != nil {
log.Printf("exhauted all mirrors error: %v", err)
return err
}
if resp.StatusCode != http.StatusOK {
log.Printf("exhauted all mirrors %v", resp.StatusCode)
return &UpstreamError{StatusCode: resp.StatusCode}
}
// use a tmp file for the initial fetch in case it fails
outFile, err := os.Create(tempPkgPath)
if err != nil {
return err
}
defer outFile.Close()
_, err = io.Copy(outFile, resp.Body)
if err != nil {
os.Remove(tempPkgPath)
return err
}
// mv file to final location
if err := os.Rename(tempPkgPath, outPkg); err != nil {
os.Remove(tempPkgPath)
return err
}
return nil
}
func (c *Cache) nextMirror() string {
idx := c.mirrorIdx.Add(1) - 1
return c.cfg.mirrorURLs[idx%uint32(len(c.cfg.mirrorURLs))]
}
+44 -7
View File
@@ -4,6 +4,7 @@ import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
@@ -27,7 +28,43 @@ func newTestCache(t *testing.T, mirrorURLs []string) *Cache {
return c
}
func TestFetchFileExists(t *testing.T) {
func TestCacheHit(t *testing.T) {
const expected = "This is fake file contents"
c := newTestCache(t, []string{"http://example.com/"})
tmpFileName := "fakeFile"
tmpPath := filepath.Join(c.cfg.cacheRoot, tmpFileName)
err := os.WriteFile(tmpPath, []byte(expected), 0644)
if err != nil {
t.Fatalf("failed to create tempfile: %v", err)
}
cachedFile, err := c.Fetch(tmpFileName)
if err != nil {
t.Fatalf("failed to fetch file: %v", err)
}
defer cachedFile.Reader.Close()
if cachedFile.Filename != tmpFileName {
t.Errorf("expected filename %s got %s", tmpFileName, cachedFile.Filename)
}
if int64(len(expected)) != cachedFile.Size {
t.Errorf("expected %d got %d", len(expected), cachedFile.Size)
}
data, err := io.ReadAll(cachedFile.Reader)
if err != nil {
t.Fatalf("error reading file back: %v", err)
}
defer cachedFile.Reader.Close()
if !bytes.Equal(data, []byte(expected)) {
t.Errorf("expected file to contain %s got %s", expected, data)
}
}
func TestCacheMissExists(t *testing.T) {
const expected = "This is fake file contents"
svr := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -36,7 +73,7 @@ func TestFetchFileExists(t *testing.T) {
c := newTestCache(t, []string{svr.URL + "/"})
err := c.Fetch("fakefile")
_, err := c.Fetch("fakefile")
if err != nil {
t.Fatalf("Fetch failed %v", err)
}
@@ -59,7 +96,7 @@ func TestFetchNotFound(t *testing.T) {
c := newTestCache(t, []string{svr.URL + "/"})
err := c.Fetch("fakefile")
_, err := c.Fetch("fakefile")
var upstreamErr *UpstreamError
if !errors.As(err, &upstreamErr) {
t.Fatalf("expected UpstreamError got %v", err)
@@ -76,7 +113,7 @@ func TestFetchSrvError(t *testing.T) {
c := newTestCache(t, []string{svr.URL + "/"})
err := c.Fetch("fakefile")
_, err := c.Fetch("fakefile")
var upstreamErr *UpstreamError
if !errors.As(err, &upstreamErr) {
t.Fatalf("expected UpstreamError fot %v", err)
@@ -99,7 +136,7 @@ func TestFetchSrvDead(t *testing.T) {
c := newTestCache(t, []string{svr.URL + "/"})
err := c.Fetch("fakefile")
_, err := c.Fetch("fakefile")
if err == nil {
t.Fatal("expected err got nil")
}
@@ -127,7 +164,7 @@ func TestFetchRetryExists(t *testing.T) {
c := newTestCache(t, fakeURLs)
err := c.Fetch("fakefile")
_, err := c.Fetch("fakefile")
if err != nil {
t.Fatalf("fetch failed: %v", err)
}
@@ -159,7 +196,7 @@ func TestFetchRetryNonExist(t *testing.T) {
c := newTestCache(t, fakeURLs)
err := c.Fetch("fakefile")
_, err := c.Fetch("fakefile")
var upstreamErr *UpstreamError
if !errors.As(err, &upstreamErr) {
t.Errorf("expected UpstreamError got %v", err)
+72
View File
@@ -0,0 +1,72 @@
package cache
import (
"log"
"os"
"path/filepath"
)
func (c *Cache) Fetch(relPath string) (*CacheFile, error) {
// return file directly if exists in cache
cf, err := getCachedFile(c.cfg.cacheRoot, relPath)
if err == nil {
return cf, nil
}
// fetch file from upstream
_, err, _ = c.sf.Do(relPath, func() (any, error) {
log.Print("calling fetch")
return nil, c.fetch(relPath)
})
if err != nil {
return nil, err
}
cf, err = getCachedFile(c.cfg.cacheRoot, relPath)
if err != nil {
return nil, err
}
return cf, nil
}
func (c *Cache) fetch(relPath string) error {
// relPath is relative to the localRoot
// ie relPath includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst
// final file name and path
destPath := filepath.Join(c.cfg.cacheRoot, relPath)
// declare vars outside loop
var err error
// fetch pkgs from mirror with retry logic
for range len(c.cfg.mirrorURLs) {
url := c.nextMirror() + relPath
err = downloadToDisk(url, destPath, c.client)
if err == nil {
break
}
}
if err != nil {
return err
}
return nil
}
func getCachedFile(cacheRoot, relPath string) (*CacheFile, error) {
filePath := filepath.Join(cacheRoot, relPath)
info, err := os.Stat(filePath)
if err != nil {
return nil, err
}
f, err := os.Open(filePath)
if err != nil {
return nil, err
}
return &CacheFile{
Reader: f,
Size: info.Size(),
Filename: filepath.Base(filePath),
}, nil
}
+57
View File
@@ -0,0 +1,57 @@
package cache
import (
"io"
"log"
"net/http"
"os"
)
func (c *Cache) nextMirror() string {
idx := c.mirrorIdx.Add(1) - 1
return c.cfg.mirrorURLs[idx%uint32(len(c.cfg.mirrorURLs))]
}
func downloadToDisk(url, destPath string, c http.Client) error {
log.Printf("fetching %v", url)
// set the user agent
req, err := http.NewRequest("GET", url, nil)
if err != nil {
log.Printf("failed to create request: %v", err)
return &UpstreamError{StatusCode: http.StatusInternalServerError}
}
req.Header.Set("User-Agent", userAgent)
resp, err := c.Do(req)
if err != nil {
log.Printf("error fetching %s: %v", url, err)
return err
}
if resp.StatusCode != 200 {
log.Printf("GET %s returned %d", url, resp.StatusCode)
return &UpstreamError{StatusCode: resp.StatusCode}
}
defer resp.Body.Close()
// use a tmp file for the initial fetch in case it fails
tempPath := destPath + ".tmp"
tmpFile, err := os.Create(tempPath)
if err != nil {
return err
}
defer tmpFile.Close()
_, err = io.Copy(tmpFile, resp.Body)
if err != nil {
os.Remove(tempPath)
return err
}
// mv file to final location
if err := os.Rename(tempPath, destPath); err != nil {
os.Remove(tempPath)
return err
}
return nil
}
+1 -1
View File
@@ -19,7 +19,7 @@ func (c *Cache) Refresh() error {
func (c *Cache) refreshDB(repo string) error {
dbFile := repo + ".db.tar.gz"
dbPath := filepath.Join(repo, "os/x86_64", dbFile)
err := c.Fetch(dbPath)
err := c.fetch(dbPath)
if err != nil {
return err
}