mirror of
https://github.com/restic/rest-server.git
synced 2025-10-19 15:43:21 +00:00
Merge 73df4faa19
into d24ffc13d8
This commit is contained in:
commit
713259af6c
5 changed files with 100 additions and 34 deletions
1
AUTHORS
1
AUTHORS
|
@ -14,3 +14,4 @@ Wayne Scott <wsc9tt@gmail.com>
|
||||||
Zlatko Čalušić <zcalusic@bitsync.net>
|
Zlatko Čalušić <zcalusic@bitsync.net>
|
||||||
cgonzalez <chgonzalezg@gmail.com>
|
cgonzalez <chgonzalezg@gmail.com>
|
||||||
n0npax <marcin@niemira.net>
|
n0npax <marcin@niemira.net>
|
||||||
|
textaligncenter <67056612+textaligncenter@users.noreply.github.com>
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
@ -46,6 +49,8 @@ func init() {
|
||||||
flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support")
|
flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support")
|
||||||
flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path")
|
flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path")
|
||||||
flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path")
|
flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path")
|
||||||
|
flags.BoolVar(&server.MTLS, "mtls", server.MTLS, "turn on client certificate support")
|
||||||
|
flags.StringVar(&server.CACert, "cacert", server.CACert, "mTLS CA certificate path")
|
||||||
flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication")
|
flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication")
|
||||||
flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload,
|
flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload,
|
||||||
"do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device")
|
"do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device")
|
||||||
|
@ -57,12 +62,12 @@ func init() {
|
||||||
|
|
||||||
var version = "0.11.0"
|
var version = "0.11.0"
|
||||||
|
|
||||||
func tlsSettings() (bool, string, string, error) {
|
func tlsSettings() (bool, bool, string, string, string, error) {
|
||||||
var key, cert string
|
var key, cert string
|
||||||
if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") {
|
if (!server.TLS && !server.MTLS) && (server.TLSKey != "" || server.TLSCert != "") {
|
||||||
return false, "", "", errors.New("requires enabled TLS")
|
return false, false, "", "", "", errors.New("requires enabled TLS or mTLS")
|
||||||
} else if !server.TLS {
|
} else if !server.TLS {
|
||||||
return false, "", "", nil
|
return false, false, "", "", "", nil
|
||||||
}
|
}
|
||||||
if server.TLSKey != "" {
|
if server.TLSKey != "" {
|
||||||
key = server.TLSKey
|
key = server.TLSKey
|
||||||
|
@ -74,7 +79,11 @@ func tlsSettings() (bool, string, string, error) {
|
||||||
} else {
|
} else {
|
||||||
cert = filepath.Join(server.Path, "public_key")
|
cert = filepath.Join(server.Path, "public_key")
|
||||||
}
|
}
|
||||||
return server.TLS, key, cert, nil
|
|
||||||
|
if server.MTLS && server.CACert == "" {
|
||||||
|
return false, false, "", "", "", errors.New("missing cacert")
|
||||||
|
}
|
||||||
|
return server.TLS, server.MTLS, server.CACert, key, cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runRoot(cmd *cobra.Command, args []string) error {
|
func runRoot(cmd *cobra.Command, args []string) error {
|
||||||
|
@ -125,7 +134,7 @@ func runRoot(cmd *cobra.Command, args []string) error {
|
||||||
log.Println("Private repositories disabled")
|
log.Println("Private repositories disabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
enabledTLS, privateKey, publicKey, err := tlsSettings()
|
enabledTLS, enabledMTLS, caCert, privateKey, publicKey, err := tlsSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -138,8 +147,31 @@ func runRoot(cmd *cobra.Command, args []string) error {
|
||||||
if !enabledTLS {
|
if !enabledTLS {
|
||||||
err = http.Serve(listener, handler)
|
err = http.Serve(listener, handler)
|
||||||
} else {
|
} else {
|
||||||
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
|
if enabledMTLS {
|
||||||
err = http.ServeTLS(listener, handler, publicKey, privateKey)
|
log.Printf("mTLS enabled, private key %s, pubkey %s, cacert %s", privateKey, publicKey, caCert)
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
caCertPem, err := ioutil.ReadFile(caCert)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("unable to read cacert")
|
||||||
|
}
|
||||||
|
caCertPool.AppendCertsFromPEM(caCertPem)
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
ClientCAs: caCertPool,
|
||||||
|
ClientAuth: tls.VerifyClientCertIfGiven,
|
||||||
|
}
|
||||||
|
tlsConfig.BuildNameToCertificate()
|
||||||
|
|
||||||
|
server := &http.Server{
|
||||||
|
Addr: server.Listen,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
server.Handler = handler
|
||||||
|
err = server.ServeTLS(listener, publicKey, privateKey)
|
||||||
|
} else {
|
||||||
|
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
|
||||||
|
err = http.ServeTLS(listener, handler, publicKey, privateKey)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -13,6 +13,7 @@ func TestTLSSettings(t *testing.T) {
|
||||||
type expected struct {
|
type expected struct {
|
||||||
TLSKey string
|
TLSKey string
|
||||||
TLSCert string
|
TLSCert string
|
||||||
|
CAcert string
|
||||||
Error bool
|
Error bool
|
||||||
}
|
}
|
||||||
type passed struct {
|
type passed struct {
|
||||||
|
@ -20,30 +21,29 @@ func TestTLSSettings(t *testing.T) {
|
||||||
TLS bool
|
TLS bool
|
||||||
TLSKey string
|
TLSKey string
|
||||||
TLSCert string
|
TLSCert string
|
||||||
|
MTLS bool
|
||||||
|
CACert string
|
||||||
}
|
}
|
||||||
|
|
||||||
var tests = []struct {
|
var tests = []struct {
|
||||||
passed passed
|
passed passed
|
||||||
expected expected
|
expected expected
|
||||||
}{
|
}{
|
||||||
{passed{TLS: false}, expected{"", "", false}},
|
{passed{TLS: false}, expected{"", "", "", false}},
|
||||||
{passed{TLS: true}, expected{
|
{passed{TLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}},
|
||||||
filepath.Join(os.TempDir(), "restic/private_key"),
|
{passed{Path: "/tmp", TLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}},
|
||||||
filepath.Join(os.TempDir(), "restic/public_key"),
|
{passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}},
|
||||||
false,
|
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
|
||||||
}},
|
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}},
|
||||||
{passed{
|
{passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
|
||||||
Path: os.TempDir(),
|
|
||||||
TLS: true,
|
{passed{TLS: false, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "/etc/restic/cacert", false}},
|
||||||
}, expected{
|
{passed{TLS: true, MTLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}},
|
||||||
filepath.Join(os.TempDir(), "private_key"),
|
{passed{Path: "/tmp", TLS: true, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}},
|
||||||
filepath.Join(os.TempDir(), "public_key"),
|
{passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert", MTLS: true}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}},
|
||||||
false,
|
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
|
||||||
}},
|
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}},
|
||||||
{passed{Path: os.TempDir(), TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}},
|
{passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
|
||||||
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
|
|
||||||
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}},
|
|
||||||
{passed{Path: os.TempDir(), TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
@ -57,7 +57,7 @@ func TestTLSSettings(t *testing.T) {
|
||||||
server.TLSKey = test.passed.TLSKey
|
server.TLSKey = test.passed.TLSKey
|
||||||
server.TLSCert = test.passed.TLSCert
|
server.TLSCert = test.passed.TLSCert
|
||||||
|
|
||||||
gotTLS, gotKey, gotCert, err := tlsSettings()
|
gotTLS, gotMTLS, gotCAcert, gotKey, gotCert, err := tlsSettings()
|
||||||
if err != nil && !test.expected.Error {
|
if err != nil && !test.expected.Error {
|
||||||
t.Fatalf("tls_settings returned err (%v)", err)
|
t.Fatalf("tls_settings returned err (%v)", err)
|
||||||
}
|
}
|
||||||
|
@ -71,6 +71,7 @@ func TestTLSSettings(t *testing.T) {
|
||||||
if gotTLS != test.passed.TLS {
|
if gotTLS != test.passed.TLS {
|
||||||
t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS)
|
t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS)
|
||||||
}
|
}
|
||||||
|
|
||||||
wantKey := test.expected.TLSKey
|
wantKey := test.expected.TLSKey
|
||||||
if gotKey != wantKey {
|
if gotKey != wantKey {
|
||||||
t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey)
|
t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey)
|
||||||
|
@ -81,6 +82,14 @@ func TestTLSSettings(t *testing.T) {
|
||||||
t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert)
|
t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if gotMTLS != test.passed.MTLS {
|
||||||
|
t.Errorf("mTLS enabled, want (%v), got (%v)", test.passed.MTLS, gotMTLS)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantCAcert := test.expected.CAcert
|
||||||
|
if gotCAcert != wantCAcert {
|
||||||
|
t.Errorf("wrong CACertPath path, want (%v), got (%v)", wantCAcert, gotCAcert)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,9 +18,11 @@ type Server struct {
|
||||||
Listen string
|
Listen string
|
||||||
Log string
|
Log string
|
||||||
CPUProfile string
|
CPUProfile string
|
||||||
|
CACert string
|
||||||
TLSKey string
|
TLSKey string
|
||||||
TLSCert string
|
TLSCert string
|
||||||
TLS bool
|
TLS bool
|
||||||
|
MTLS bool
|
||||||
NoAuth bool
|
NoAuth bool
|
||||||
AppendOnly bool
|
AppendOnly bool
|
||||||
PrivateRepos bool
|
PrivateRepos bool
|
||||||
|
|
36
mux.go
36
mux.go
|
@ -1,7 +1,6 @@
|
||||||
package restserver
|
package restserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
@ -33,12 +32,35 @@ func (s *Server) checkAuth(r *http.Request) (username string, ok bool) {
|
||||||
if s.NoAuth {
|
if s.NoAuth {
|
||||||
return username, true
|
return username, true
|
||||||
}
|
}
|
||||||
var password string
|
|
||||||
username, password, ok = r.BasicAuth()
|
username, ok = s.validateBasicAuth(r)
|
||||||
if !ok || !s.htpasswdFile.Validate(username, password) {
|
if ok {
|
||||||
return "", false
|
return username, true
|
||||||
}
|
}
|
||||||
return username, true
|
|
||||||
|
username, ok = validateClientCert(r)
|
||||||
|
if ok {
|
||||||
|
return username, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return username, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) validateBasicAuth(r *http.Request) (string, bool) {
|
||||||
|
username, password, ok := r.BasicAuth()
|
||||||
|
return username, ok && s.htpasswdFile.Validate(username, password)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateClientCert(r *http.Request) (string, bool) {
|
||||||
|
if r.TLS != nil {
|
||||||
|
for _, cert := range r.TLS.PeerCertificates {
|
||||||
|
username := cert.Subject.CommonName
|
||||||
|
if username != "" {
|
||||||
|
return username, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) wrapMetricsAuth(f http.HandlerFunc) http.HandlerFunc {
|
func (s *Server) wrapMetricsAuth(f http.HandlerFunc) http.HandlerFunc {
|
||||||
|
@ -62,7 +84,7 @@ func NewHandler(server *Server) (http.Handler, error) {
|
||||||
var err error
|
var err error
|
||||||
server.htpasswdFile, err = NewHtpasswdFromFile(filepath.Join(server.Path, ".htpasswd"))
|
server.htpasswdFile, err = NewHtpasswdFromFile(filepath.Join(server.Path, ".htpasswd"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err)
|
// return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue