diff --git a/cmd/server/config.go b/cmd/server/config.go index c8170a6..27ce2d0 100644 --- a/cmd/server/config.go +++ b/cmd/server/config.go @@ -2,20 +2,23 @@ package main import ( "fmt" + "os" + "strings" "github.com/BurntSushi/toml" ) -type Config struct { - CacheRoot string `toml:"cache_root"` - MirrorURLs []string `toml:"mirror_urls"` - MirroredRepos []string `toml:"mirrored_repos"` - Port string `toml:"port"` - Auth AuthConfig `toml:"auth"` +type rawConfig struct { + Config + EnvFile string `toml:"env_file"` } -type AuthConfig struct { - Token string `toml:"token"` +type Config struct { + CacheRoot string `toml:"cache_root"` + MirrorURLs []string `toml:"mirror_urls"` + MirroredRepos []string `toml:"mirrored_repos"` + Port string `toml:"port"` + Token string } /* Function kept for reference for future logic @@ -31,12 +34,19 @@ func NewConfig() *Config { func ReadConfig(path string) (*Config, error) { - var cfg Config - _, err := toml.DecodeFile(path, &cfg) + var rawcfg rawConfig + _, err := toml.DecodeFile(path, &rawcfg) if err != nil { return nil, fmt.Errorf("error loading config from %s: %w", path, err) } + cfg := rawcfg.Config + + err = cfg.loadToken(rawcfg.EnvFile) + if err != nil { + return nil, fmt.Errorf("error getting token: %v", err) + } + if err = cfg.validate(); err != nil { return nil, fmt.Errorf("invalid config: %w", err) } @@ -44,6 +54,32 @@ func ReadConfig(path string) (*Config, error) { return &cfg, nil } +func (c *Config) loadToken(path string) error { + + info, err := os.Stat(path) + if err != nil { + return fmt.Errorf("failed to stat env file: %v", err) + } + + if info.Mode().Perm() != 0600 { + return fmt.Errorf("env file perms not secure, expected 0600 got %o", info.Mode().Perm()) + } + + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("failed to read env file: %v", err) + } + + key, val, ok := strings.Cut(strings.TrimSpace(string(data)), "=") + if !ok || key != "PKGSTASH_TOKEN" { + return fmt.Errorf("invalid env file format") + } + + c.Token = val + return nil + +} + func (c *Config) validate() error { if c.CacheRoot == "" { return fmt.Errorf("cache root is required") @@ -57,7 +93,7 @@ func (c *Config) validate() error { if c.Port == "" { return fmt.Errorf("port required") } - if c.Auth.Token == "" { + if c.Token == "" || c.Token == "changeme" { return fmt.Errorf("auth token is required") } return nil diff --git a/cmd/server/handler_api.go b/cmd/server/handler_api.go index b22db5b..6cea328 100644 --- a/cmd/server/handler_api.go +++ b/cmd/server/handler_api.go @@ -8,7 +8,7 @@ import ( ) func (s *Server) handlerRefresh(w http.ResponseWriter, req *http.Request) { - if req.Header.Get("Authorization") != "Bearer "+s.cfg.Auth.Token { + if req.Header.Get("Authorization") != "Bearer "+s.cfg.Token { ip := req.Header.Get("X-Real-IP") if ip == "" { ip = req.RemoteAddr @@ -27,7 +27,7 @@ func (s *Server) handlerRefresh(w http.ResponseWriter, req *http.Request) { } func (s *Server) handlerLogLevel(w http.ResponseWriter, req *http.Request) { - if req.Header.Get("Authorization") != "Bearer "+s.cfg.Auth.Token { + if req.Header.Get("Authorization") != "Bearer "+s.cfg.Token { ip := req.Header.Get("X-Real-IP") if ip == "" { ip = req.RemoteAddr