moved server main to cmd/server

This commit is contained in:
2026-05-07 10:51:21 -06:00
parent 0deb1961fe
commit 5eafd202af
8 changed files with 1 additions and 0 deletions
+64
View File
@@ -0,0 +1,64 @@
package main
import (
"fmt"
"github.com/BurntSushi/toml"
)
type Config struct {
CacheRoot string `toml:"cache_root"`
MirrorURLs []string `toml:"mirror_urls"`
MirroredRepos []string `toml:"mirrored_repos"`
Port string `toml:"port"`
Auth AuthConfig `toml:"auth"`
}
type AuthConfig struct {
Token string `toml:"token"`
}
/* Function kept for reference for future logic
func NewConfig() *Config {
return &Config{
CacheRoot: "/home/ewpt3ch/dev/pacman-cache-server/tmprepo",
MirrorURLs: "https://us.mirrors.cicku.me/archlinux/",
Port: "8090",
Auth: AuthConfig{Token: "FakeToken"},
}
}
*/
func ReadConfig(path string) (*Config, error) {
var cfg Config
_, err := toml.DecodeFile(path, &cfg)
if err != nil {
return nil, fmt.Errorf("error loading config from %s: %w", path, err)
}
if err = cfg.validate(); err != nil {
return nil, fmt.Errorf("invalid config: %w", err)
}
return &cfg, nil
}
func (c *Config) validate() error {
if c.CacheRoot == "" {
return fmt.Errorf("cache root is required")
}
if len(c.MirrorURLs) == 0 {
return fmt.Errorf("at least one mirror is required")
}
if len(c.MirroredRepos) == 0 {
return fmt.Errorf("at least one repo is required")
}
if c.Port == "" {
return fmt.Errorf("port required")
}
if c.Auth.Token == "" {
return fmt.Errorf("auth token is required")
}
return nil
}
+138
View File
@@ -0,0 +1,138 @@
package main
import (
"errors"
"os"
"path/filepath"
"testing"
)
func writeConfigFile(t *testing.T, content string) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "config.toml")
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
t.Fatalf("failed write test config: %v", err)
}
return path
}
func TestReadConfig(t *testing.T) {
path := writeConfigFile(t, `
cache_root = "srv/cache"
mirror_urls = ["https://mirror.example.com"]
mirrored_repos = ["core", "extra"]
port = "8090"
[auth]
token = "testtoken"
`)
cfg, err := ReadConfig(path)
if err != nil {
t.Fatalf("expected no err on read got: %v", err)
}
if cfg.Port != "8090" {
t.Errorf("expected port 8090 got %s", cfg.Port)
}
}
func TestMissingCacheRoot(t *testing.T) {
path := writeConfigFile(t, `
mirror_urls = ["https://mirror.example.com"]
mirrored_repos = ["core", "extra"]
port = "8090"
[auth]
token = "testtoken"
`)
_, err := ReadConfig(path)
if err == nil {
t.Fatal("expected err got nil")
}
}
func TestMissingMirrorUrls(t *testing.T) {
path := writeConfigFile(t, `
cache_root = "srv/cache"
mirrored_repos = ["core", "extra"]
port = "8090"
[auth]
token = "testtoken"
`)
_, err := ReadConfig(path)
if err == nil {
t.Fatal("expected err got nil")
}
}
func TestMissingMirroredRepos(t *testing.T) {
path := writeConfigFile(t, `
cache_root = "srv/cache"
mirror_urls = ["https://mirror.example.com"]
port = "8090"
[auth]
token = "testtoken"
`)
_, err := ReadConfig(path)
if err == nil {
t.Fatal("expected err got nil")
}
}
func TestMissingPort(t *testing.T) {
path := writeConfigFile(t, `
cache_root = "srv/cache"
mirror_urls = ["https://mirror.example.com"]
mirrored_repos = ["core", "extra"]
[auth]
token = "testtoken"
`)
_, err := ReadConfig(path)
if err == nil {
t.Fatal("expected err got nil")
}
}
func TestMissingAuthToken(t *testing.T) {
path := writeConfigFile(t, `
cache_root = "srv/cache"
mirror_urls = ["https://mirror.example.com"]
mirrored_repos = ["core", "extra"]
port = "8090"
[auth]
`)
_, err := ReadConfig(path)
if err == nil {
t.Fatal("expected err got nil")
}
}
func TestMissingFile(t *testing.T) {
path := filepath.Join(t.TempDir(), "nonexistant.toml")
_, err := ReadConfig(path)
if !errors.Is(err, os.ErrNotExist) {
t.Fatal("expected err got nil")
}
}
func TestInvalidToml(t *testing.T) {
path := writeConfigFile(t, `
cache_root = [srv/cache]
`)
_, err := ReadConfig(path)
if err == nil {
t.Fatal("expected err got nil")
}
}
+69
View File
@@ -0,0 +1,69 @@
package main
import (
"encoding/json"
"log/slog"
"net/http"
"strings"
)
func (s *Server) handlerRefresh(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Authorization") != "Bearer "+s.cfg.Auth.Token {
ip := req.Header.Get("X-Real-IP")
if ip == "" {
ip = req.RemoteAddr
}
slog.Warn("unauthorized request", "ip", ip, "path", req.URL.Path, "method", req.Method)
respondWithError(w, http.StatusUnauthorized, "unauthorized")
return
}
if err := s.c.Refresh(); err != nil {
slog.Error("refresh failed", "err", err)
http.Error(w, "refresh failed", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
func (s *Server) handlerLogLevel(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Authorization") != "Bearer "+s.cfg.Auth.Token {
ip := req.Header.Get("X-Real-IP")
if ip == "" {
ip = req.RemoteAddr
}
slog.Warn("unauthorized request", "ip", ip, "path", req.URL.Path, "method", req.Method)
respondWithError(w, http.StatusUnauthorized, "unauthorized")
return
}
type reqParameters struct {
NewLevel string `json:"loglevel"`
}
decoder := json.NewDecoder(req.Body)
reqParams := reqParameters{}
err := decoder.Decode(&reqParams)
if err != nil {
slog.Debug("json decode erro", "err", err)
respondWithError(w, http.StatusBadRequest, "invalid request")
return
}
switch strings.ToLower(reqParams.NewLevel) {
case "debug":
s.logLevel.Set(slog.LevelDebug)
case "info":
s.logLevel.Set(slog.LevelInfo)
case "warn":
s.logLevel.Set(slog.LevelWarn)
case "error":
s.logLevel.Set(slog.LevelError)
default:
respondWithError(w, http.StatusBadRequest, "invalid log level")
return
}
slog.Info("log level changed", "level", reqParams.NewLevel)
w.WriteHeader(http.StatusNoContent)
}
+58
View File
@@ -0,0 +1,58 @@
package main
import (
"errors"
"io"
"log/slog"
"net/http"
"path/filepath"
"strconv"
"strings"
"github.com/ewpt3ch/pkgstash/internal/cache"
)
func (s *Server) handlerPackage(w http.ResponseWriter, req *http.Request) {
// db files are not signed so we ignore as to not spam mirrors
if strings.HasSuffix(req.PathValue("file"), ".db.sig") {
w.WriteHeader(http.StatusNotFound)
return
}
// record the useragent from requestor
slog.Debug("Requestors User Agent", "UA", req.Header.Get("User-Agent"))
// build file paths from the request, they follow archlinux repo
// <mirrorroot>/[core, extra, etc]/os/[x86_64, arm, etc]/package.pkg.tar.zst[.sig]
repo := req.PathValue("repo")
arch := req.PathValue("arch")
file := req.PathValue("file")
repoPath := filepath.Join(repo, "os", arch, file) //path from mirror root to requested file
cachedFile, err := s.c.Fetch(repoPath)
if err != nil {
if upstreamErr, ok := errors.AsType[*cache.UpstreamError](err); ok {
slog.Warn("upstream error", "err", upstreamErr.Error())
http.Error(w, "Not found upstream", upstreamErr.StatusCode)
return
}
slog.Warn("fetch error", "err", err)
http.Error(w, "Failed to fetch from upstream", http.StatusBadGateway)
return
}
defer func() {
if closeErr := cachedFile.Reader.Close(); closeErr != nil {
err = closeErr
}
}()
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Disposition", "attachment; filename="+cachedFile.Filename)
w.Header().Set("Content-Length", strconv.FormatInt(cachedFile.Size, 10))
_, err = io.Copy(w, cachedFile.Reader)
if err != nil {
slog.Warn("streaming error", "err", err)
}
}
+206
View File
@@ -0,0 +1,206 @@
package main
import (
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"github.com/ewpt3ch/pkgstash/internal/cache"
)
var (
//nolint:errcheck //ephemeral no need to check
mirrorOK = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "fake pkg data") })
mirror404 = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) })
)
func mirrorOKWithCounter() (http.HandlerFunc, *atomic.Int32) {
var calls atomic.Int32
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
calls.Add(1)
//nolint:errcheck //ephemeral no need to check
fmt.Fprint(w, "fake pkg data")
}), &calls
}
const testUrlBase = "/core/os/x86_64"
func newTestServer(t *testing.T, mirrorHandler http.HandlerFunc) (*httptest.Server, *Server) {
t.Helper()
mirror := httptest.NewServer(mirrorHandler)
t.Cleanup(func() { mirror.Close() })
c, err := cache.NewCache(t.TempDir(), []string{mirror.URL + "/"}, []string{"core"})
if err != nil {
t.Fatalf("failed to create cache: %v", err)
}
cfg := &Config{
Port: "0",
Auth: AuthConfig{Token: "testtoken"},
}
logLevel := new(slog.LevelVar)
srv := &Server{
cfg: cfg,
c: c,
logLevel: logLevel,
}
mux := http.NewServeMux()
mux.HandleFunc("GET /{repo}/os/{arch}/{file}", srv.handlerPackage)
mux.HandleFunc("POST /api/refresh", srv.handlerRefresh)
mux.HandleFunc("POST /api/loglevel", srv.handlerLogLevel)
ts := httptest.NewServer(mux)
t.Cleanup(ts.Close)
return ts, srv
}
func TestHandlerPkgsExist(t *testing.T) {
const expected = "fake pkg data"
const expectedFile = "attachment; filename=somepkg.tar.zst"
ts, _ := newTestServer(t, mirrorOK)
resp, err := http.Get(ts.URL + testUrlBase + "/somepkg.tar.zst")
if err != nil {
t.Fatalf("GET failed: %v", err)
}
respFile := resp.Header.Get("Content-Disposition")
if resp.ContentLength != int64(len(expected)) {
t.Errorf("expected %d got %d", len(expected), resp.ContentLength)
}
if respFile != expectedFile {
t.Errorf("expected %s got %s", expectedFile, respFile)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("err reading body: %v", err)
}
if string(data) != expected {
t.Errorf("expected %s got %s", expected, string(data))
}
}
func TestHandlerPkgsMiss(t *testing.T) {
ts, _ := newTestServer(t, mirror404)
resp, err := http.Get(ts.URL + testUrlBase + "/somepkg.tar.zst")
if err != nil {
t.Fatalf("GET failed %v", err)
}
if resp.StatusCode != http.StatusNotFound {
t.Errorf("expected 404 got %d", resp.StatusCode)
}
}
func TestHandlerPkgsDBSig(t *testing.T) {
handler, calls := mirrorOKWithCounter()
ts, _ := newTestServer(t, handler)
resp, err := http.Get(ts.URL + testUrlBase + "/core.db.sig")
if err != nil {
t.Fatalf("GET failed %v", err)
}
if resp.StatusCode != http.StatusNotFound {
t.Errorf("expected 404 got %d", resp.StatusCode)
}
if calls.Load() != 0 {
t.Error("expected no upstream calls for .db.sig")
}
}
func TestHandlerRefreshUnauthorized(t *testing.T) {
ts, _ := newTestServer(t, mirrorOK)
req, err := http.NewRequest("POST", ts.URL+"/api/refresh", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer badtoken")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("POST failed: %v", err)
}
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected %d got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestHandlerRefreshOK(t *testing.T) {
ts, _ := newTestServer(t, mirrorOK)
req, err := http.NewRequest("POST", ts.URL+"/api/refresh", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer testtoken")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("POST failed: %v", err)
}
if resp.StatusCode != http.StatusNoContent {
t.Errorf("expected %d got %d", http.StatusNoContent, resp.StatusCode)
}
}
func TestHandlerLogLevelValid(t *testing.T) {
ts, srv := newTestServer(t, mirrorOK)
body := strings.NewReader(`{"loglevel": "debug"}`)
req, err := http.NewRequest("POST", ts.URL+"/api/loglevel", body)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer testtoken")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("POST failed: %v", err)
}
if resp.StatusCode != http.StatusNoContent {
t.Errorf("expected %d got %d", http.StatusNoContent, resp.StatusCode)
}
got := srv.logLevel.Level()
if got != slog.LevelDebug {
t.Errorf("expected DEBUG got %s", got)
}
}
func TestHandlerLogLevelInvalid(t *testing.T) {
ts, srv := newTestServer(t, mirrorOK)
body := strings.NewReader(`{"loglevel": "what"}`)
req, err := http.NewRequest("POST", ts.URL+"/api/loglevel", body)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer testtoken")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("POST failed: %v", err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected %d got %d", http.StatusBadRequest, resp.StatusCode)
}
got := srv.logLevel.Level()
if got != slog.LevelInfo {
t.Errorf("expected INFO got %s", got)
}
}
+107
View File
@@ -0,0 +1,107 @@
package main
import (
"context"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/ewpt3ch/pkgstash/internal/cache"
)
type Server struct {
cfg *Config
c *cache.Cache
logLevel *slog.LevelVar
}
func main() {
// get options from cli flags
var configPath string
flag.StringVar(&configPath, "config", "", "path to config file")
logFlag := flag.String("loglevel", "INFO", "loglevel: DEBUG, INFO, WARN, ERROR")
flag.Parse()
// set config from flag if available
if len(configPath) == 0 {
configPath = "/etc/pkgstash/pkgstash.toml"
}
//set log level from flag if available
logLevel := new(slog.LevelVar)
if err := logLevel.UnmarshalText([]byte(*logFlag)); err != nil {
fmt.Fprintf(os.Stderr, "invalid log level %q, defaulting to INFO\n", *logFlag)
logLevel.Set(slog.LevelInfo)
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: logLevel,
}))
slog.SetDefault(logger)
cfg, err := ReadConfig(configPath)
if err != nil {
slog.Error("failed to read config", "err", err)
os.Exit(1)
}
c, err := cache.NewCache(cfg.CacheRoot, cfg.MirrorURLs, cfg.MirroredRepos)
if err != nil {
slog.Error("failed to create cache", "err", err)
os.Exit(1)
}
defer c.Close() //nolint:errcheck // best effort cleanup on exit
srv := &Server{
cfg: cfg,
c: c,
logLevel: logLevel,
}
mux := http.NewServeMux()
mux.HandleFunc("GET /{repo}/os/{arch}/{file}", srv.handlerPackage)
mux.HandleFunc("POST /api/refresh", srv.handlerRefresh)
mux.HandleFunc("POST /api/loglevel", srv.handlerLogLevel)
if err := srv.c.Refresh(); err != nil {
slog.Error("failed to refesh db files", "err", err)
//nolint:errcheck //already exiting
_ = c.Close() // best effort cleanup on exit
os.Exit(1)
}
httpServe := &http.Server{
Addr: ":" + srv.cfg.Port,
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
}
// gracefully quit the server
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
go func() {
slog.Info("serving pkgstash", "root", cfg.CacheRoot, "port", cfg.Port)
if err = httpServe.ListenAndServe(); err != http.ErrServerClosed {
slog.Error("server failed", "err", err)
_ = c.Close() // best effort cleanup on exit
os.Exit(1)
}
}()
<-quit
slog.Info("shutting down")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := httpServe.Shutdown(ctx); err != nil {
slog.Error("shutdown failed", "err", err)
}
}
+25
View File
@@ -0,0 +1,25 @@
package main
import (
"encoding/json"
"net/http"
)
func respondWithError(w http.ResponseWriter, code int, msg string) {
type returnVals struct {
Error string `json:"error"`
}
respondWithJSON(w, code, returnVals{Error: msg})
}
func respondWithJSON(w http.ResponseWriter, code int, payload any) {
dat, err := json.Marshal(payload)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
_, _ = w.Write(dat)
}