moved server main to cmd/server
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
CacheRoot string `toml:"cache_root"`
|
||||
MirrorURLs []string `toml:"mirror_urls"`
|
||||
MirroredRepos []string `toml:"mirrored_repos"`
|
||||
Port string `toml:"port"`
|
||||
Auth AuthConfig `toml:"auth"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
Token string `toml:"token"`
|
||||
}
|
||||
|
||||
/* Function kept for reference for future logic
|
||||
func NewConfig() *Config {
|
||||
return &Config{
|
||||
CacheRoot: "/home/ewpt3ch/dev/pacman-cache-server/tmprepo",
|
||||
MirrorURLs: "https://us.mirrors.cicku.me/archlinux/",
|
||||
Port: "8090",
|
||||
Auth: AuthConfig{Token: "FakeToken"},
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
func ReadConfig(path string) (*Config, error) {
|
||||
|
||||
var cfg Config
|
||||
_, err := toml.DecodeFile(path, &cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error loading config from %s: %w", path, err)
|
||||
}
|
||||
|
||||
if err = cfg.validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) validate() error {
|
||||
if c.CacheRoot == "" {
|
||||
return fmt.Errorf("cache root is required")
|
||||
}
|
||||
if len(c.MirrorURLs) == 0 {
|
||||
return fmt.Errorf("at least one mirror is required")
|
||||
}
|
||||
if len(c.MirroredRepos) == 0 {
|
||||
return fmt.Errorf("at least one repo is required")
|
||||
}
|
||||
if c.Port == "" {
|
||||
return fmt.Errorf("port required")
|
||||
}
|
||||
if c.Auth.Token == "" {
|
||||
return fmt.Errorf("auth token is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func writeConfigFile(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "config.toml")
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed write test config: %v", err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func TestReadConfig(t *testing.T) {
|
||||
path := writeConfigFile(t, `
|
||||
cache_root = "srv/cache"
|
||||
mirror_urls = ["https://mirror.example.com"]
|
||||
mirrored_repos = ["core", "extra"]
|
||||
port = "8090"
|
||||
|
||||
[auth]
|
||||
token = "testtoken"
|
||||
`)
|
||||
|
||||
cfg, err := ReadConfig(path)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no err on read got: %v", err)
|
||||
}
|
||||
if cfg.Port != "8090" {
|
||||
t.Errorf("expected port 8090 got %s", cfg.Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingCacheRoot(t *testing.T) {
|
||||
path := writeConfigFile(t, `
|
||||
mirror_urls = ["https://mirror.example.com"]
|
||||
mirrored_repos = ["core", "extra"]
|
||||
port = "8090"
|
||||
|
||||
[auth]
|
||||
token = "testtoken"
|
||||
`)
|
||||
|
||||
_, err := ReadConfig(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected err got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingMirrorUrls(t *testing.T) {
|
||||
path := writeConfigFile(t, `
|
||||
cache_root = "srv/cache"
|
||||
mirrored_repos = ["core", "extra"]
|
||||
port = "8090"
|
||||
|
||||
[auth]
|
||||
token = "testtoken"
|
||||
`)
|
||||
|
||||
_, err := ReadConfig(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected err got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingMirroredRepos(t *testing.T) {
|
||||
path := writeConfigFile(t, `
|
||||
cache_root = "srv/cache"
|
||||
mirror_urls = ["https://mirror.example.com"]
|
||||
port = "8090"
|
||||
|
||||
[auth]
|
||||
token = "testtoken"
|
||||
`)
|
||||
|
||||
_, err := ReadConfig(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected err got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingPort(t *testing.T) {
|
||||
path := writeConfigFile(t, `
|
||||
cache_root = "srv/cache"
|
||||
mirror_urls = ["https://mirror.example.com"]
|
||||
mirrored_repos = ["core", "extra"]
|
||||
|
||||
[auth]
|
||||
token = "testtoken"
|
||||
`)
|
||||
|
||||
_, err := ReadConfig(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected err got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingAuthToken(t *testing.T) {
|
||||
path := writeConfigFile(t, `
|
||||
cache_root = "srv/cache"
|
||||
mirror_urls = ["https://mirror.example.com"]
|
||||
mirrored_repos = ["core", "extra"]
|
||||
port = "8090"
|
||||
|
||||
[auth]
|
||||
`)
|
||||
|
||||
_, err := ReadConfig(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected err got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingFile(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "nonexistant.toml")
|
||||
|
||||
_, err := ReadConfig(path)
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatal("expected err got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidToml(t *testing.T) {
|
||||
path := writeConfigFile(t, `
|
||||
cache_root = [srv/cache]
|
||||
`)
|
||||
|
||||
_, err := ReadConfig(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected err got nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (s *Server) handlerRefresh(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Header.Get("Authorization") != "Bearer "+s.cfg.Auth.Token {
|
||||
ip := req.Header.Get("X-Real-IP")
|
||||
if ip == "" {
|
||||
ip = req.RemoteAddr
|
||||
}
|
||||
slog.Warn("unauthorized request", "ip", ip, "path", req.URL.Path, "method", req.Method)
|
||||
respondWithError(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.c.Refresh(); err != nil {
|
||||
slog.Error("refresh failed", "err", err)
|
||||
http.Error(w, "refresh failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (s *Server) handlerLogLevel(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Header.Get("Authorization") != "Bearer "+s.cfg.Auth.Token {
|
||||
ip := req.Header.Get("X-Real-IP")
|
||||
if ip == "" {
|
||||
ip = req.RemoteAddr
|
||||
}
|
||||
slog.Warn("unauthorized request", "ip", ip, "path", req.URL.Path, "method", req.Method)
|
||||
respondWithError(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
type reqParameters struct {
|
||||
NewLevel string `json:"loglevel"`
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(req.Body)
|
||||
reqParams := reqParameters{}
|
||||
err := decoder.Decode(&reqParams)
|
||||
if err != nil {
|
||||
slog.Debug("json decode erro", "err", err)
|
||||
respondWithError(w, http.StatusBadRequest, "invalid request")
|
||||
return
|
||||
}
|
||||
|
||||
switch strings.ToLower(reqParams.NewLevel) {
|
||||
case "debug":
|
||||
s.logLevel.Set(slog.LevelDebug)
|
||||
case "info":
|
||||
s.logLevel.Set(slog.LevelInfo)
|
||||
case "warn":
|
||||
s.logLevel.Set(slog.LevelWarn)
|
||||
case "error":
|
||||
s.logLevel.Set(slog.LevelError)
|
||||
default:
|
||||
respondWithError(w, http.StatusBadRequest, "invalid log level")
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("log level changed", "level", reqParams.NewLevel)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ewpt3ch/pkgstash/internal/cache"
|
||||
)
|
||||
|
||||
func (s *Server) handlerPackage(w http.ResponseWriter, req *http.Request) {
|
||||
|
||||
// db files are not signed so we ignore as to not spam mirrors
|
||||
if strings.HasSuffix(req.PathValue("file"), ".db.sig") {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// record the useragent from requestor
|
||||
slog.Debug("Requestors User Agent", "UA", req.Header.Get("User-Agent"))
|
||||
|
||||
// build file paths from the request, they follow archlinux repo
|
||||
// <mirrorroot>/[core, extra, etc]/os/[x86_64, arm, etc]/package.pkg.tar.zst[.sig]
|
||||
repo := req.PathValue("repo")
|
||||
arch := req.PathValue("arch")
|
||||
file := req.PathValue("file")
|
||||
repoPath := filepath.Join(repo, "os", arch, file) //path from mirror root to requested file
|
||||
|
||||
cachedFile, err := s.c.Fetch(repoPath)
|
||||
if err != nil {
|
||||
if upstreamErr, ok := errors.AsType[*cache.UpstreamError](err); ok {
|
||||
slog.Warn("upstream error", "err", upstreamErr.Error())
|
||||
http.Error(w, "Not found upstream", upstreamErr.StatusCode)
|
||||
return
|
||||
}
|
||||
slog.Warn("fetch error", "err", err)
|
||||
http.Error(w, "Failed to fetch from upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := cachedFile.Reader.Close(); closeErr != nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+cachedFile.Filename)
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(cachedFile.Size, 10))
|
||||
_, err = io.Copy(w, cachedFile.Reader)
|
||||
if err != nil {
|
||||
slog.Warn("streaming error", "err", err)
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/ewpt3ch/pkgstash/internal/cache"
|
||||
)
|
||||
|
||||
var (
|
||||
//nolint:errcheck //ephemeral no need to check
|
||||
mirrorOK = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "fake pkg data") })
|
||||
mirror404 = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) })
|
||||
)
|
||||
|
||||
func mirrorOKWithCounter() (http.HandlerFunc, *atomic.Int32) {
|
||||
var calls atomic.Int32
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls.Add(1)
|
||||
//nolint:errcheck //ephemeral no need to check
|
||||
fmt.Fprint(w, "fake pkg data")
|
||||
}), &calls
|
||||
}
|
||||
|
||||
const testUrlBase = "/core/os/x86_64"
|
||||
|
||||
func newTestServer(t *testing.T, mirrorHandler http.HandlerFunc) (*httptest.Server, *Server) {
|
||||
t.Helper()
|
||||
|
||||
mirror := httptest.NewServer(mirrorHandler)
|
||||
t.Cleanup(func() { mirror.Close() })
|
||||
|
||||
c, err := cache.NewCache(t.TempDir(), []string{mirror.URL + "/"}, []string{"core"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create cache: %v", err)
|
||||
}
|
||||
cfg := &Config{
|
||||
Port: "0",
|
||||
Auth: AuthConfig{Token: "testtoken"},
|
||||
}
|
||||
logLevel := new(slog.LevelVar)
|
||||
srv := &Server{
|
||||
cfg: cfg,
|
||||
c: c,
|
||||
logLevel: logLevel,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /{repo}/os/{arch}/{file}", srv.handlerPackage)
|
||||
mux.HandleFunc("POST /api/refresh", srv.handlerRefresh)
|
||||
mux.HandleFunc("POST /api/loglevel", srv.handlerLogLevel)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
t.Cleanup(ts.Close)
|
||||
return ts, srv
|
||||
}
|
||||
|
||||
func TestHandlerPkgsExist(t *testing.T) {
|
||||
const expected = "fake pkg data"
|
||||
const expectedFile = "attachment; filename=somepkg.tar.zst"
|
||||
|
||||
ts, _ := newTestServer(t, mirrorOK)
|
||||
|
||||
resp, err := http.Get(ts.URL + testUrlBase + "/somepkg.tar.zst")
|
||||
if err != nil {
|
||||
t.Fatalf("GET failed: %v", err)
|
||||
}
|
||||
|
||||
respFile := resp.Header.Get("Content-Disposition")
|
||||
|
||||
if resp.ContentLength != int64(len(expected)) {
|
||||
t.Errorf("expected %d got %d", len(expected), resp.ContentLength)
|
||||
}
|
||||
if respFile != expectedFile {
|
||||
t.Errorf("expected %s got %s", expectedFile, respFile)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("err reading body: %v", err)
|
||||
}
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s got %s", expected, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerPkgsMiss(t *testing.T) {
|
||||
|
||||
ts, _ := newTestServer(t, mirror404)
|
||||
|
||||
resp, err := http.Get(ts.URL + testUrlBase + "/somepkg.tar.zst")
|
||||
if err != nil {
|
||||
t.Fatalf("GET failed %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected 404 got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerPkgsDBSig(t *testing.T) {
|
||||
|
||||
handler, calls := mirrorOKWithCounter()
|
||||
ts, _ := newTestServer(t, handler)
|
||||
|
||||
resp, err := http.Get(ts.URL + testUrlBase + "/core.db.sig")
|
||||
if err != nil {
|
||||
t.Fatalf("GET failed %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected 404 got %d", resp.StatusCode)
|
||||
}
|
||||
if calls.Load() != 0 {
|
||||
t.Error("expected no upstream calls for .db.sig")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerRefreshUnauthorized(t *testing.T) {
|
||||
|
||||
ts, _ := newTestServer(t, mirrorOK)
|
||||
|
||||
req, err := http.NewRequest("POST", ts.URL+"/api/refresh", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer badtoken")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("POST failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected %d got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerRefreshOK(t *testing.T) {
|
||||
|
||||
ts, _ := newTestServer(t, mirrorOK)
|
||||
|
||||
req, err := http.NewRequest("POST", ts.URL+"/api/refresh", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer testtoken")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("POST failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("expected %d got %d", http.StatusNoContent, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerLogLevelValid(t *testing.T) {
|
||||
|
||||
ts, srv := newTestServer(t, mirrorOK)
|
||||
|
||||
body := strings.NewReader(`{"loglevel": "debug"}`)
|
||||
req, err := http.NewRequest("POST", ts.URL+"/api/loglevel", body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer testtoken")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("POST failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("expected %d got %d", http.StatusNoContent, resp.StatusCode)
|
||||
}
|
||||
got := srv.logLevel.Level()
|
||||
if got != slog.LevelDebug {
|
||||
t.Errorf("expected DEBUG got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerLogLevelInvalid(t *testing.T) {
|
||||
|
||||
ts, srv := newTestServer(t, mirrorOK)
|
||||
|
||||
body := strings.NewReader(`{"loglevel": "what"}`)
|
||||
req, err := http.NewRequest("POST", ts.URL+"/api/loglevel", body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer testtoken")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("POST failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected %d got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
got := srv.logLevel.Level()
|
||||
if got != slog.LevelInfo {
|
||||
t.Errorf("expected INFO got %s", got)
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ewpt3ch/pkgstash/internal/cache"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
cfg *Config
|
||||
c *cache.Cache
|
||||
logLevel *slog.LevelVar
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
// get options from cli flags
|
||||
var configPath string
|
||||
flag.StringVar(&configPath, "config", "", "path to config file")
|
||||
logFlag := flag.String("loglevel", "INFO", "loglevel: DEBUG, INFO, WARN, ERROR")
|
||||
flag.Parse()
|
||||
|
||||
// set config from flag if available
|
||||
if len(configPath) == 0 {
|
||||
configPath = "/etc/pkgstash/pkgstash.toml"
|
||||
}
|
||||
|
||||
//set log level from flag if available
|
||||
logLevel := new(slog.LevelVar)
|
||||
if err := logLevel.UnmarshalText([]byte(*logFlag)); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "invalid log level %q, defaulting to INFO\n", *logFlag)
|
||||
logLevel.Set(slog.LevelInfo)
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: logLevel,
|
||||
}))
|
||||
slog.SetDefault(logger)
|
||||
|
||||
cfg, err := ReadConfig(configPath)
|
||||
if err != nil {
|
||||
slog.Error("failed to read config", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
c, err := cache.NewCache(cfg.CacheRoot, cfg.MirrorURLs, cfg.MirroredRepos)
|
||||
if err != nil {
|
||||
slog.Error("failed to create cache", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer c.Close() //nolint:errcheck // best effort cleanup on exit
|
||||
|
||||
srv := &Server{
|
||||
cfg: cfg,
|
||||
c: c,
|
||||
logLevel: logLevel,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /{repo}/os/{arch}/{file}", srv.handlerPackage)
|
||||
mux.HandleFunc("POST /api/refresh", srv.handlerRefresh)
|
||||
mux.HandleFunc("POST /api/loglevel", srv.handlerLogLevel)
|
||||
|
||||
if err := srv.c.Refresh(); err != nil {
|
||||
slog.Error("failed to refesh db files", "err", err)
|
||||
//nolint:errcheck //already exiting
|
||||
_ = c.Close() // best effort cleanup on exit
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
httpServe := &http.Server{
|
||||
Addr: ":" + srv.cfg.Port,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
// gracefully quit the server
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
slog.Info("serving pkgstash", "root", cfg.CacheRoot, "port", cfg.Port)
|
||||
if err = httpServe.ListenAndServe(); err != http.ErrServerClosed {
|
||||
slog.Error("server failed", "err", err)
|
||||
_ = c.Close() // best effort cleanup on exit
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
<-quit
|
||||
slog.Info("shutting down")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := httpServe.Shutdown(ctx); err != nil {
|
||||
slog.Error("shutdown failed", "err", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func respondWithError(w http.ResponseWriter, code int, msg string) {
|
||||
type returnVals struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
respondWithJSON(w, code, returnVals{Error: msg})
|
||||
}
|
||||
|
||||
func respondWithJSON(w http.ResponseWriter, code int, payload any) {
|
||||
dat, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
_, _ = w.Write(dat)
|
||||
}
|
||||
Reference in New Issue
Block a user