From b0a9a0452ee0b52af08f5a31048eafe60fee05ac Mon Sep 17 00:00:00 2001 From: akmet <10135260+akmet@users.noreply.github.com> Date: Fri, 1 Nov 2024 02:31:43 +0100 Subject: [PATCH] Add support for proxy-based authentication --- changelog/unreleased/pull-307 | 8 ++++ cmd/rest-server/main.go | 9 ++++- cmd/rest-server/main_test.go | 6 +++ handlers.go | 1 + mux.go | 17 ++++++--- mux_test.go | 72 +++++++++++++++++++++++++++++++++++ 6 files changed, 106 insertions(+), 7 deletions(-) create mode 100644 changelog/unreleased/pull-307 create mode 100644 mux_test.go diff --git a/changelog/unreleased/pull-307 b/changelog/unreleased/pull-307 new file mode 100644 index 0000000..2aa948e --- /dev/null +++ b/changelog/unreleased/pull-307 @@ -0,0 +1,8 @@ +Enhancement: Add support for proxy-based authentication + +The server now supports authentication via a proxy header specified with the --proxy-auth flag (e.g., --proxy-auth=X-Forwarded-User). +When this flag is set, the server will authenticate users based on the given header and disable BasicAuth. +Note that --proxy-auth is ignored if --no-auth is set, as --no-auth disables all authentication. + +https://github.com/restic/rest-server/issues/174 +https://github.com/restic/rest-server/pull/307 \ No newline at end of file diff --git a/cmd/rest-server/main.go b/cmd/rest-server/main.go index cabc565..f488b91 100644 --- a/cmd/rest-server/main.go +++ b/cmd/rest-server/main.go @@ -61,8 +61,9 @@ func newRestServerApp() *restServerApp { flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support") flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path") flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path") - flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable .htpasswd authentication") + flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable authentication") flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"/.htpasswd)\"") + flags.StringVar(&rv.Server.ProxyAuthUsername, "proxy-auth-username", rv.Server.ProxyAuthUsername, "specifies the HTTP header containing the username for proxy-based authentication") flags.BoolVar(&rv.Server.NoVerifyUpload, "no-verify-upload", rv.Server.NoVerifyUpload, "do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device") flags.BoolVar(&rv.Server.AppendOnly, "append-only", rv.Server.AppendOnly, "enable append only mode") @@ -130,7 +131,11 @@ func (app *restServerApp) runRoot(_ *cobra.Command, _ []string) error { if app.Server.NoAuth { log.Println("Authentication disabled") } else { - log.Println("Authentication enabled") + if app.Server.ProxyAuthUsername == "" { + log.Println("Authentication enabled") + } else { + log.Println("Proxy Authentication enabled.") + } } handler, err := restserver.NewHandler(&app.Server) diff --git a/cmd/rest-server/main_test.go b/cmd/rest-server/main_test.go index 1171530..7d08409 100644 --- a/cmd/rest-server/main_test.go +++ b/cmd/rest-server/main_test.go @@ -118,6 +118,12 @@ func TestGetHandler(t *testing.T) { t.Errorf("NoAuth=true: expected no error, got %v", err) } + // With NoAuth = false, no .htpasswd and ProxyAuth = X-Remote-User + _, err = getHandler(&restserver.Server{Path: dir, ProxyAuthUsername: "X-Remote-User"}) + if err != nil { + t.Errorf("NoAuth=false, ProxyAuthUsername = X-Remote-User: expected no error, got %v", err) + } + // With NoAuth = false and custom .htpasswd htpFile, err := os.CreateTemp(dir, "custom") if err != nil { diff --git a/handlers.go b/handlers.go index cde0637..12e760f 100644 --- a/handlers.go +++ b/handlers.go @@ -24,6 +24,7 @@ type Server struct { TLSCert string TLS bool NoAuth bool + ProxyAuthUsername string AppendOnly bool PrivateRepos bool Prometheus bool diff --git a/mux.go b/mux.go index 77fcdb4..9c604b3 100644 --- a/mux.go +++ b/mux.go @@ -41,10 +41,17 @@ func (s *Server) checkAuth(r *http.Request) (username string, ok bool) { if s.NoAuth { return username, true } - var password string - username, password, ok = r.BasicAuth() - if !ok || !s.htpasswdFile.Validate(username, password) { - return "", false + if s.ProxyAuthUsername != "" { + username = r.Header.Get(s.ProxyAuthUsername) + if username == "" { + return "", false + } + } else { + var password string + username, password, ok = r.BasicAuth() + if !ok || !s.htpasswdFile.Validate(username, password) { + return "", false + } } return username, true } @@ -66,7 +73,7 @@ func (s *Server) wrapMetricsAuth(f http.HandlerFunc) http.HandlerFunc { // NewHandler returns the master HTTP multiplexer/router. func NewHandler(server *Server) (http.Handler, error) { - if !server.NoAuth { + if !server.NoAuth && server.ProxyAuthUsername == "" { var err error if server.HtpasswdPath == "" { server.HtpasswdPath = filepath.Join(server.Path, ".htpasswd") diff --git a/mux_test.go b/mux_test.go new file mode 100644 index 0000000..f2ff0c6 --- /dev/null +++ b/mux_test.go @@ -0,0 +1,72 @@ +package restserver + +import ( + "net/http/httptest" + "testing" +) + +func TestCheckAuth(t *testing.T) { + tests := []struct { + name string + server *Server + requestHeaders map[string]string + basicAuth bool + basicUser string + basicPassword string + expectedUser string + expectedOk bool + }{ + { + name: "NoAuth enabled", + server: &Server{ + NoAuth: true, + }, + expectedOk: true, + }, + { + name: "Proxy Auth successful", + server: &Server{ + ProxyAuthUsername: "X-Remote-User", + }, + requestHeaders: map[string]string{ + "X-Remote-User": "restic", + }, + expectedUser: "restic", + expectedOk: true, + }, + { + name: "Proxy Auth empty header", + server: &Server{ + ProxyAuthUsername: "X-Remote-User", + }, + requestHeaders: map[string]string{ + "X-Remote-User": "", + }, + expectedOk: false, + }, + { + name: "Proxy Auth missing header", + server: &Server{ + ProxyAuthUsername: "X-Remote-User", + }, + expectedOk: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + for header, value := range tt.requestHeaders { + req.Header.Set(header, value) + } + if tt.basicAuth { + req.SetBasicAuth(tt.basicUser, tt.basicPassword) + } + + username, ok := tt.server.checkAuth(req) + if username != tt.expectedUser || ok != tt.expectedOk { + t.Errorf("expected (%v, %v), got (%v, %v)", tt.expectedUser, tt.expectedOk, username, ok) + } + }) + } +}