diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 123cecc..d79407a 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -16,29 +16,43 @@ import ( ) type Cache struct { - cacheRoot string - mirrorURLs []string - mirroredRepos []string - mirrorIdx atomic.Uint32 - sf singleflight.Group //prevents duplicate downloads - mu sync.Mutex - client http.Client + cfg CacheConfig + mirrorIdx atomic.Uint32 + sf singleflight.Group //prevents duplicate downloads + mu sync.Mutex + client http.Client +} + +type CacheConfig struct { + cacheRoot string + mirrorURLs []string + mirroredRepos []string + DialTimeout time.Duration + ResponseHeaderTimeout time.Duration + ClientTimeout time.Duration } func NewCache(cacheRoot string, mirrorURLs []string, mirroredRepos []string) *Cache { + cfg := CacheConfig{ + cacheRoot: cacheRoot, + mirrorURLs: mirrorURLs, + mirroredRepos: mirroredRepos, + DialTimeout: 5 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + ClientTimeout: 15 * time.Second, + } + transport := &http.Transport{ DialContext: (&net.Dialer{ - Timeout: 5 * time.Second, + Timeout: cfg.DialTimeout, }).DialContext, - ResponseHeaderTimeout: 10 * time.Second, + ResponseHeaderTimeout: cfg.ResponseHeaderTimeout, } return &Cache{ - cacheRoot: cacheRoot, - mirrorURLs: mirrorURLs, - mirroredRepos: mirroredRepos, + cfg: cfg, client: http.Client{ - Timeout: 15 * time.Second, + Timeout: cfg.ClientTimeout, Transport: transport, }, } @@ -65,8 +79,8 @@ func (c *Cache) fetch(pkgName string) error { // pkgName is relative to the localRoot // ie pkgName includes /{repo}/os/{arch}/ and the actual name linux-x.x.x.pkg.tar.zst tempPkgName := pkgName + ".tmp" - tempPkgPath := filepath.Join(c.cacheRoot, tempPkgName) //full tmp write path - outPkg := filepath.Join(c.cacheRoot, pkgName) + tempPkgPath := filepath.Join(c.cfg.cacheRoot, tempPkgName) //full tmp write path + outPkg := filepath.Join(c.cfg.cacheRoot, pkgName) pkgURL := c.nextMirror() + pkgName log.Printf("fetching %v", pkgURL) @@ -102,5 +116,5 @@ func (c *Cache) fetch(pkgName string) error { func (c *Cache) nextMirror() string { idx := c.mirrorIdx.Add(1) - 1 - return c.mirrorURLs[idx%uint32(len(c.mirrorURLs))] + return c.cfg.mirrorURLs[idx%uint32(len(c.cfg.mirrorURLs))] } diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 9b05c76..f61e8f9 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -22,7 +22,9 @@ func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { func newTestCache(t *testing.T, mirrorURL []string) *Cache { t.Helper() mirroredRepos := []string{"core", "extra"} - return NewCache(t.TempDir(), mirrorURL, mirroredRepos) + c := NewCache(t.TempDir(), mirrorURL, mirroredRepos) + c.client.Timeout = 500 * time.Millisecond + return c } func TestFetchFileExists(t *testing.T) { @@ -39,7 +41,7 @@ func TestFetchFileExists(t *testing.T) { t.Fatalf("Fetch failed %v", err) } - fakefilepath := filepath.Join(c.cacheRoot, "fakefile") + fakefilepath := filepath.Join(c.cfg.cacheRoot, "fakefile") data, err := os.ReadFile(fakefilepath) if err != nil { diff --git a/internal/cache/refresh.go b/internal/cache/refresh.go index b5bc4bc..2309222 100644 --- a/internal/cache/refresh.go +++ b/internal/cache/refresh.go @@ -8,7 +8,7 @@ func (c *Cache) Refresh() error { } defer c.mu.Unlock() - for _, repo := range c.mirroredRepos { + for _, repo := range c.cfg.mirroredRepos { if err := c.refreshDB(repo); err != nil { return err }