new test for streaming and struct changes and additions
This commit is contained in:
Vendored
+9
-4
@@ -10,8 +10,6 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
const userAgent = "pacman/7.1.0 (Linux x86_64) libalpm/16.0.1"
|
||||
@@ -20,9 +18,10 @@ type Cache struct {
|
||||
cfg CacheConfig
|
||||
cr *os.Root
|
||||
mirrorIdx atomic.Uint64
|
||||
sf singleflight.Group //prevents duplicate downloads
|
||||
mu sync.Mutex
|
||||
refreshMu sync.Mutex
|
||||
client http.Client
|
||||
inFlight map[string]*inFlight
|
||||
inFlightMu sync.Mutex
|
||||
}
|
||||
|
||||
type CacheConfig struct {
|
||||
@@ -33,6 +32,11 @@ type CacheConfig struct {
|
||||
ClientTimeout time.Duration
|
||||
}
|
||||
|
||||
type inFlight struct {
|
||||
done chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
type CacheFile struct {
|
||||
Reader io.ReadCloser
|
||||
Size int64
|
||||
@@ -71,6 +75,7 @@ func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) (*C
|
||||
Timeout: cfg.ClientTimeout,
|
||||
Transport: transport,
|
||||
},
|
||||
inFlight: make(map[string]*inFlight),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
Vendored
+65
@@ -233,3 +233,68 @@ func TestCreateSymlinks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStreamMultiplClient(t *testing.T) {
|
||||
firstBytesSend := make(chan struct{})
|
||||
const expectedOne = "This is fake file contents"
|
||||
const expectedTwo = "More fake file contents"
|
||||
expected := expectedOne + expectedTwo
|
||||
|
||||
svr := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
//nolint:errcheck //ephemeral no need to check
|
||||
fmt.Fprint(w, expectedOne)
|
||||
w.(http.Flusher).Flush()
|
||||
close(firstBytesSend)
|
||||
time.Sleep(2 * time.Second)
|
||||
fmt.Fprint(w, expectedTwo)
|
||||
}))
|
||||
|
||||
c := newTestCache(t, []string{svr.URL + "/"})
|
||||
c.client.Timeout = 10 * time.Second
|
||||
|
||||
type fetchResult struct {
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
results := make(chan fetchResult, 2)
|
||||
|
||||
for range 2 {
|
||||
go func() {
|
||||
cf, err := c.Fetch("fakefile")
|
||||
if err != nil {
|
||||
results <- fetchResult{err: err}
|
||||
return
|
||||
}
|
||||
defer cf.Reader.Close()
|
||||
data, err := io.ReadAll(cf.Reader)
|
||||
results <- fetchResult{data: data, err: err}
|
||||
}()
|
||||
}
|
||||
|
||||
<-firstBytesSend
|
||||
c.inFlightMu.Lock()
|
||||
_, ok := c.inFlight["fakefile"]
|
||||
c.inFlightMu.Unlock()
|
||||
if !ok {
|
||||
t.Errorf("no matching key in map: %v", c.inFlight)
|
||||
}
|
||||
|
||||
for range 2 {
|
||||
result := <-results
|
||||
if result.err != nil {
|
||||
t.Errorf("a fetch failed: %v", result.err)
|
||||
}
|
||||
if !bytes.Equal(result.data, []byte(expected)) {
|
||||
t.Errorf("expected result to contain %s got %s", expected, result.data)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := c.cr.ReadFile("fakefile")
|
||||
if err != nil {
|
||||
t.Fatalf("Error reading file back: %v", err)
|
||||
}
|
||||
if !bytes.Equal(data, []byte(expected)) {
|
||||
t.Errorf("expected file to contain %s got %s", expected, data)
|
||||
}
|
||||
}
|
||||
|
||||
Vendored
+5
-15
@@ -7,6 +7,9 @@ import (
|
||||
)
|
||||
|
||||
func (c *Cache) Fetch(relPath string) (*CacheFile, error) {
|
||||
// relPath is relative to the localRoot
|
||||
// ie relPath includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst
|
||||
|
||||
// return file directly if exists in cache
|
||||
cf, err := c.getCachedFile(relPath)
|
||||
if err == nil {
|
||||
@@ -14,24 +17,11 @@ func (c *Cache) Fetch(relPath string) (*CacheFile, error) {
|
||||
}
|
||||
|
||||
// fetch file from upstream
|
||||
_, err, _ = c.sf.Do(relPath, func() (any, error) {
|
||||
slog.Debug("calling fetch", "file", relPath)
|
||||
return nil, c.fetch(relPath)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cf, err = c.getCachedFile(relPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cf, nil
|
||||
return nil, c.getStream(relPath)
|
||||
}
|
||||
|
||||
func (c *Cache) fetch(relPath string) error {
|
||||
// relPath is relative to the localRoot
|
||||
// ie relPath includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst
|
||||
func (c *Cache) getStream(relPath string) error {
|
||||
|
||||
// declare vars outside loop
|
||||
var err error
|
||||
|
||||
Vendored
+3
-3
@@ -3,10 +3,10 @@ package cache
|
||||
import "path/filepath"
|
||||
|
||||
func (c *Cache) Refresh() error {
|
||||
if !c.mu.TryLock() {
|
||||
if !c.refreshMu.TryLock() {
|
||||
return nil
|
||||
}
|
||||
defer c.mu.Unlock()
|
||||
defer c.refreshMu.Unlock()
|
||||
|
||||
for _, repo := range c.cfg.mirroredRepos {
|
||||
if err := c.refreshDB(repo); err != nil {
|
||||
@@ -19,7 +19,7 @@ func (c *Cache) Refresh() error {
|
||||
func (c *Cache) refreshDB(repo string) error {
|
||||
dbFile := repo + ".db.tar.gz"
|
||||
dbPath := filepath.Join(repo, "os/x86_64", dbFile)
|
||||
err := c.fetch(dbPath)
|
||||
err := c.getStream(dbPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user