mirror of
https://github.com/restic/rest-server.git
synced 2025-10-19 15:43:21 +00:00
Merge 73df4faa19
into d24ffc13d8
This commit is contained in:
commit
713259af6c
5 changed files with 100 additions and 34 deletions
1
AUTHORS
1
AUTHORS
|
@ -14,3 +14,4 @@ Wayne Scott <wsc9tt@gmail.com>
|
|||
Zlatko Čalušić <zcalusic@bitsync.net>
|
||||
cgonzalez <chgonzalezg@gmail.com>
|
||||
n0npax <marcin@niemira.net>
|
||||
textaligncenter <67056612+textaligncenter@users.noreply.github.com>
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -46,6 +49,8 @@ func init() {
|
|||
flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support")
|
||||
flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path")
|
||||
flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path")
|
||||
flags.BoolVar(&server.MTLS, "mtls", server.MTLS, "turn on client certificate support")
|
||||
flags.StringVar(&server.CACert, "cacert", server.CACert, "mTLS CA certificate path")
|
||||
flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication")
|
||||
flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload,
|
||||
"do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device")
|
||||
|
@ -57,12 +62,12 @@ func init() {
|
|||
|
||||
var version = "0.11.0"
|
||||
|
||||
func tlsSettings() (bool, string, string, error) {
|
||||
func tlsSettings() (bool, bool, string, string, string, error) {
|
||||
var key, cert string
|
||||
if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") {
|
||||
return false, "", "", errors.New("requires enabled TLS")
|
||||
if (!server.TLS && !server.MTLS) && (server.TLSKey != "" || server.TLSCert != "") {
|
||||
return false, false, "", "", "", errors.New("requires enabled TLS or mTLS")
|
||||
} else if !server.TLS {
|
||||
return false, "", "", nil
|
||||
return false, false, "", "", "", nil
|
||||
}
|
||||
if server.TLSKey != "" {
|
||||
key = server.TLSKey
|
||||
|
@ -74,7 +79,11 @@ func tlsSettings() (bool, string, string, error) {
|
|||
} else {
|
||||
cert = filepath.Join(server.Path, "public_key")
|
||||
}
|
||||
return server.TLS, key, cert, nil
|
||||
|
||||
if server.MTLS && server.CACert == "" {
|
||||
return false, false, "", "", "", errors.New("missing cacert")
|
||||
}
|
||||
return server.TLS, server.MTLS, server.CACert, key, cert, nil
|
||||
}
|
||||
|
||||
func runRoot(cmd *cobra.Command, args []string) error {
|
||||
|
@ -125,7 +134,7 @@ func runRoot(cmd *cobra.Command, args []string) error {
|
|||
log.Println("Private repositories disabled")
|
||||
}
|
||||
|
||||
enabledTLS, privateKey, publicKey, err := tlsSettings()
|
||||
enabledTLS, enabledMTLS, caCert, privateKey, publicKey, err := tlsSettings()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -138,8 +147,31 @@ func runRoot(cmd *cobra.Command, args []string) error {
|
|||
if !enabledTLS {
|
||||
err = http.Serve(listener, handler)
|
||||
} else {
|
||||
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
|
||||
err = http.ServeTLS(listener, handler, publicKey, privateKey)
|
||||
if enabledMTLS {
|
||||
log.Printf("mTLS enabled, private key %s, pubkey %s, cacert %s", privateKey, publicKey, caCert)
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPem, err := ioutil.ReadFile(caCert)
|
||||
if err != nil {
|
||||
return errors.New("unable to read cacert")
|
||||
}
|
||||
caCertPool.AppendCertsFromPEM(caCertPem)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
ClientCAs: caCertPool,
|
||||
ClientAuth: tls.VerifyClientCertIfGiven,
|
||||
}
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
|
||||
server := &http.Server{
|
||||
Addr: server.Listen,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
server.Handler = handler
|
||||
err = server.ServeTLS(listener, publicKey, privateKey)
|
||||
} else {
|
||||
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
|
||||
err = http.ServeTLS(listener, handler, publicKey, privateKey)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
|
|
|
@ -13,6 +13,7 @@ func TestTLSSettings(t *testing.T) {
|
|||
type expected struct {
|
||||
TLSKey string
|
||||
TLSCert string
|
||||
CAcert string
|
||||
Error bool
|
||||
}
|
||||
type passed struct {
|
||||
|
@ -20,30 +21,29 @@ func TestTLSSettings(t *testing.T) {
|
|||
TLS bool
|
||||
TLSKey string
|
||||
TLSCert string
|
||||
MTLS bool
|
||||
CACert string
|
||||
}
|
||||
|
||||
var tests = []struct {
|
||||
passed passed
|
||||
expected expected
|
||||
}{
|
||||
{passed{TLS: false}, expected{"", "", false}},
|
||||
{passed{TLS: true}, expected{
|
||||
filepath.Join(os.TempDir(), "restic/private_key"),
|
||||
filepath.Join(os.TempDir(), "restic/public_key"),
|
||||
false,
|
||||
}},
|
||||
{passed{
|
||||
Path: os.TempDir(),
|
||||
TLS: true,
|
||||
}, expected{
|
||||
filepath.Join(os.TempDir(), "private_key"),
|
||||
filepath.Join(os.TempDir(), "public_key"),
|
||||
false,
|
||||
}},
|
||||
{passed{Path: os.TempDir(), TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}},
|
||||
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
|
||||
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}},
|
||||
{passed{Path: os.TempDir(), TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
|
||||
{passed{TLS: false}, expected{"", "", "", false}},
|
||||
{passed{TLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}},
|
||||
{passed{Path: "/tmp", TLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}},
|
||||
{passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}},
|
||||
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
|
||||
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}},
|
||||
{passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
|
||||
|
||||
{passed{TLS: false, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "/etc/restic/cacert", false}},
|
||||
{passed{TLS: true, MTLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}},
|
||||
{passed{Path: "/tmp", TLS: true, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}},
|
||||
{passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert", MTLS: true}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}},
|
||||
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
|
||||
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}},
|
||||
{passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
@ -57,7 +57,7 @@ func TestTLSSettings(t *testing.T) {
|
|||
server.TLSKey = test.passed.TLSKey
|
||||
server.TLSCert = test.passed.TLSCert
|
||||
|
||||
gotTLS, gotKey, gotCert, err := tlsSettings()
|
||||
gotTLS, gotMTLS, gotCAcert, gotKey, gotCert, err := tlsSettings()
|
||||
if err != nil && !test.expected.Error {
|
||||
t.Fatalf("tls_settings returned err (%v)", err)
|
||||
}
|
||||
|
@ -71,6 +71,7 @@ func TestTLSSettings(t *testing.T) {
|
|||
if gotTLS != test.passed.TLS {
|
||||
t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS)
|
||||
}
|
||||
|
||||
wantKey := test.expected.TLSKey
|
||||
if gotKey != wantKey {
|
||||
t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey)
|
||||
|
@ -81,6 +82,14 @@ func TestTLSSettings(t *testing.T) {
|
|||
t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert)
|
||||
}
|
||||
|
||||
if gotMTLS != test.passed.MTLS {
|
||||
t.Errorf("mTLS enabled, want (%v), got (%v)", test.passed.MTLS, gotMTLS)
|
||||
}
|
||||
|
||||
wantCAcert := test.expected.CAcert
|
||||
if gotCAcert != wantCAcert {
|
||||
t.Errorf("wrong CACertPath path, want (%v), got (%v)", wantCAcert, gotCAcert)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,9 +18,11 @@ type Server struct {
|
|||
Listen string
|
||||
Log string
|
||||
CPUProfile string
|
||||
CACert string
|
||||
TLSKey string
|
||||
TLSCert string
|
||||
TLS bool
|
||||
MTLS bool
|
||||
NoAuth bool
|
||||
AppendOnly bool
|
||||
PrivateRepos bool
|
||||
|
|
36
mux.go
36
mux.go
|
@ -1,7 +1,6 @@
|
|||
package restserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -33,12 +32,35 @@ func (s *Server) checkAuth(r *http.Request) (username string, ok bool) {
|
|||
if s.NoAuth {
|
||||
return username, true
|
||||
}
|
||||
var password string
|
||||
username, password, ok = r.BasicAuth()
|
||||
if !ok || !s.htpasswdFile.Validate(username, password) {
|
||||
return "", false
|
||||
|
||||
username, ok = s.validateBasicAuth(r)
|
||||
if ok {
|
||||
return username, true
|
||||
}
|
||||
return username, true
|
||||
|
||||
username, ok = validateClientCert(r)
|
||||
if ok {
|
||||
return username, true
|
||||
}
|
||||
|
||||
return username, false
|
||||
}
|
||||
|
||||
func (s *Server) validateBasicAuth(r *http.Request) (string, bool) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
return username, ok && s.htpasswdFile.Validate(username, password)
|
||||
}
|
||||
|
||||
func validateClientCert(r *http.Request) (string, bool) {
|
||||
if r.TLS != nil {
|
||||
for _, cert := range r.TLS.PeerCertificates {
|
||||
username := cert.Subject.CommonName
|
||||
if username != "" {
|
||||
return username, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (s *Server) wrapMetricsAuth(f http.HandlerFunc) http.HandlerFunc {
|
||||
|
@ -62,7 +84,7 @@ func NewHandler(server *Server) (http.Handler, error) {
|
|||
var err error
|
||||
server.htpasswdFile, err = NewHtpasswdFromFile(filepath.Join(server.Path, ".htpasswd"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err)
|
||||
// return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue