diff --git a/htpasswd.go b/htpasswd.go index 49fe06e..52ef657 100644 --- a/htpasswd.go +++ b/htpasswd.go @@ -1,7 +1,7 @@ package restserver /* -Copied from: github.com/bitly/oauth2_proxy +Original version copied from: github.com/bitly/oauth2_proxy MIT License @@ -28,52 +28,130 @@ import ( "crypto/sha1" "encoding/base64" "encoding/csv" - "io" "log" "os" + "sync" + "time" ) // Lookup passwords in a htpasswd file. The entries must have been created with -s for SHA encryption. // HtpasswdFile is a map for usernames to passwords. type HtpasswdFile struct { + mutex sync.Mutex + path string + stat os.FileInfo + throttle chan struct{} Users map[string]string } // NewHtpasswdFromFile reads the users and passwords from a htpasswd file and returns them. If an error is encountered, // it is returned, together with a nil-Pointer for the HtpasswdFile. func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { - r, err := os.Open(path) + stat, err := os.Stat(path) if err != nil { return nil, err } - defer r.Close() - return NewHtpasswd(r) + + h := &HtpasswdFile{ + mutex: sync.Mutex{}, + path: path, + stat: stat, + throttle: make(chan struct{}), + } + + if err := h.Reload(); err != nil { + return nil, err + } + + // Start a goroutine that limits reload checks to once a second at most + go h.throttleTimer() + + return h, nil } -// NewHtpasswd reads the users and passwords from a htpasswd datastream in file and returns them. If an error is -// encountered, it is returned, together with a nil-Pointer for the HtpasswdFile. -func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { - cr := csv.NewReader(file) +// throttleTimer sends at most one message per second to throttle file change checks. +func (h *HtpasswdFile) throttleTimer() { + var check struct{} + for { + time.Sleep(1 * time.Second) + h.throttle <- check + } +} + +// Reload reloads the htpasswd file. If the reload fails, the Users map is not changed and the error is returned. +func (h *HtpasswdFile) Reload() error { + r, err := os.Open(h.path) + if err != nil { + return err + } + defer r.Close() + + cr := csv.NewReader(r) cr.Comma = ':' cr.Comment = '#' cr.TrimLeadingSpace = true records, err := cr.ReadAll() if err != nil { - return nil, err + return err } - h := &HtpasswdFile{Users: make(map[string]string)} + users := make(map[string]string) for _, record := range records { - h.Users[record[0]] = record[1] + users[record[0]] = record[1] } - return h, nil + + // Replace the Users map + h.mutex.Lock() + h.Users = users + h.mutex.Unlock() + return nil +} + +// ReloadCheck checks at most once per second if the file changed and will reload the file if it did. +// It logs errors and successful reloads, and returns an error if any was encountered. +func (h *HtpasswdFile) ReloadCheck() error { + select { + case <-h.throttle: + stat, err := os.Stat(h.path) + if err != nil { + log.Printf("Could not stat htpasswd file: %v", err) + return err + } + + reload := false + + h.mutex.Lock() + if stat.ModTime() != h.stat.ModTime() || stat.Size() != h.stat.Size() { + reload = true + h.stat = stat + } + h.mutex.Unlock() + + if reload { + err := h.Reload() + if err == nil { + log.Printf("Reloaded htpasswd file") + } else { + log.Printf("Could not reload htpasswd file: %v", err) + return err + } + } + default: + // No need to check + } + return nil } // Validate returns true if password matches the stored password for user. If no password for user is stored, or the // password is wrong, false is returned. func (h *HtpasswdFile) Validate(user string, password string) bool { + _ = h.ReloadCheck() + + h.mutex.Lock() realPassword, exists := h.Users[user] + h.mutex.Unlock() + if !exists { return false }