This commit is contained in:
textaligncenter 2022-04-24 20:16:06 +00:00 committed by GitHub
commit 713259af6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 100 additions and 34 deletions

View file

@ -14,3 +14,4 @@ Wayne Scott <wsc9tt@gmail.com>
Zlatko Čalušić <zcalusic@bitsync.net> Zlatko Čalušić <zcalusic@bitsync.net>
cgonzalez <chgonzalezg@gmail.com> cgonzalez <chgonzalezg@gmail.com>
n0npax <marcin@niemira.net> n0npax <marcin@niemira.net>
textaligncenter <67056612+textaligncenter@users.noreply.github.com>

View file

@ -1,8 +1,11 @@
package main package main
import ( import (
"crypto/tls"
"crypto/x509"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -46,6 +49,8 @@ func init() {
flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support") flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support")
flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path") flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path")
flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path") flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path")
flags.BoolVar(&server.MTLS, "mtls", server.MTLS, "turn on client certificate support")
flags.StringVar(&server.CACert, "cacert", server.CACert, "mTLS CA certificate path")
flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication") flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication")
flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload, flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload,
"do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device") "do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device")
@ -57,12 +62,12 @@ func init() {
var version = "0.11.0" var version = "0.11.0"
func tlsSettings() (bool, string, string, error) { func tlsSettings() (bool, bool, string, string, string, error) {
var key, cert string var key, cert string
if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") { if (!server.TLS && !server.MTLS) && (server.TLSKey != "" || server.TLSCert != "") {
return false, "", "", errors.New("requires enabled TLS") return false, false, "", "", "", errors.New("requires enabled TLS or mTLS")
} else if !server.TLS { } else if !server.TLS {
return false, "", "", nil return false, false, "", "", "", nil
} }
if server.TLSKey != "" { if server.TLSKey != "" {
key = server.TLSKey key = server.TLSKey
@ -74,7 +79,11 @@ func tlsSettings() (bool, string, string, error) {
} else { } else {
cert = filepath.Join(server.Path, "public_key") cert = filepath.Join(server.Path, "public_key")
} }
return server.TLS, key, cert, nil
if server.MTLS && server.CACert == "" {
return false, false, "", "", "", errors.New("missing cacert")
}
return server.TLS, server.MTLS, server.CACert, key, cert, nil
} }
func runRoot(cmd *cobra.Command, args []string) error { func runRoot(cmd *cobra.Command, args []string) error {
@ -125,7 +134,7 @@ func runRoot(cmd *cobra.Command, args []string) error {
log.Println("Private repositories disabled") log.Println("Private repositories disabled")
} }
enabledTLS, privateKey, publicKey, err := tlsSettings() enabledTLS, enabledMTLS, caCert, privateKey, publicKey, err := tlsSettings()
if err != nil { if err != nil {
return err return err
} }
@ -138,8 +147,31 @@ func runRoot(cmd *cobra.Command, args []string) error {
if !enabledTLS { if !enabledTLS {
err = http.Serve(listener, handler) err = http.Serve(listener, handler)
} else { } else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey) if enabledMTLS {
err = http.ServeTLS(listener, handler, publicKey, privateKey) log.Printf("mTLS enabled, private key %s, pubkey %s, cacert %s", privateKey, publicKey, caCert)
caCertPool := x509.NewCertPool()
caCertPem, err := ioutil.ReadFile(caCert)
if err != nil {
return errors.New("unable to read cacert")
}
caCertPool.AppendCertsFromPEM(caCertPem)
tlsConfig := &tls.Config{
ClientCAs: caCertPool,
ClientAuth: tls.VerifyClientCertIfGiven,
}
tlsConfig.BuildNameToCertificate()
server := &http.Server{
Addr: server.Listen,
TLSConfig: tlsConfig,
}
server.Handler = handler
err = server.ServeTLS(listener, publicKey, privateKey)
} else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = http.ServeTLS(listener, handler, publicKey, privateKey)
}
} }
return err return err

View file

@ -13,6 +13,7 @@ func TestTLSSettings(t *testing.T) {
type expected struct { type expected struct {
TLSKey string TLSKey string
TLSCert string TLSCert string
CAcert string
Error bool Error bool
} }
type passed struct { type passed struct {
@ -20,30 +21,29 @@ func TestTLSSettings(t *testing.T) {
TLS bool TLS bool
TLSKey string TLSKey string
TLSCert string TLSCert string
MTLS bool
CACert string
} }
var tests = []struct { var tests = []struct {
passed passed passed passed
expected expected expected expected
}{ }{
{passed{TLS: false}, expected{"", "", false}}, {passed{TLS: false}, expected{"", "", "", false}},
{passed{TLS: true}, expected{ {passed{TLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}},
filepath.Join(os.TempDir(), "restic/private_key"), {passed{Path: "/tmp", TLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}},
filepath.Join(os.TempDir(), "restic/public_key"), {passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}},
false, {passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
}}, {passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}},
{passed{ {passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
Path: os.TempDir(),
TLS: true, {passed{TLS: false, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "/etc/restic/cacert", false}},
}, expected{ {passed{TLS: true, MTLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}},
filepath.Join(os.TempDir(), "private_key"), {passed{Path: "/tmp", TLS: true, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}},
filepath.Join(os.TempDir(), "public_key"), {passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert", MTLS: true}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}},
false, {passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
}}, {passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}},
{passed{Path: os.TempDir(), TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}}, {passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}},
{passed{Path: os.TempDir(), TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
} }
for _, test := range tests { for _, test := range tests {
@ -57,7 +57,7 @@ func TestTLSSettings(t *testing.T) {
server.TLSKey = test.passed.TLSKey server.TLSKey = test.passed.TLSKey
server.TLSCert = test.passed.TLSCert server.TLSCert = test.passed.TLSCert
gotTLS, gotKey, gotCert, err := tlsSettings() gotTLS, gotMTLS, gotCAcert, gotKey, gotCert, err := tlsSettings()
if err != nil && !test.expected.Error { if err != nil && !test.expected.Error {
t.Fatalf("tls_settings returned err (%v)", err) t.Fatalf("tls_settings returned err (%v)", err)
} }
@ -71,6 +71,7 @@ func TestTLSSettings(t *testing.T) {
if gotTLS != test.passed.TLS { if gotTLS != test.passed.TLS {
t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS) t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS)
} }
wantKey := test.expected.TLSKey wantKey := test.expected.TLSKey
if gotKey != wantKey { if gotKey != wantKey {
t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey) t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey)
@ -81,6 +82,14 @@ func TestTLSSettings(t *testing.T) {
t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert) t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert)
} }
if gotMTLS != test.passed.MTLS {
t.Errorf("mTLS enabled, want (%v), got (%v)", test.passed.MTLS, gotMTLS)
}
wantCAcert := test.expected.CAcert
if gotCAcert != wantCAcert {
t.Errorf("wrong CACertPath path, want (%v), got (%v)", wantCAcert, gotCAcert)
}
}) })
} }
} }

View file

@ -18,9 +18,11 @@ type Server struct {
Listen string Listen string
Log string Log string
CPUProfile string CPUProfile string
CACert string
TLSKey string TLSKey string
TLSCert string TLSCert string
TLS bool TLS bool
MTLS bool
NoAuth bool NoAuth bool
AppendOnly bool AppendOnly bool
PrivateRepos bool PrivateRepos bool

36
mux.go
View file

@ -1,7 +1,6 @@
package restserver package restserver
import ( import (
"fmt"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -33,12 +32,35 @@ func (s *Server) checkAuth(r *http.Request) (username string, ok bool) {
if s.NoAuth { if s.NoAuth {
return username, true return username, true
} }
var password string
username, password, ok = r.BasicAuth() username, ok = s.validateBasicAuth(r)
if !ok || !s.htpasswdFile.Validate(username, password) { if ok {
return "", false return username, true
} }
return username, true
username, ok = validateClientCert(r)
if ok {
return username, true
}
return username, false
}
func (s *Server) validateBasicAuth(r *http.Request) (string, bool) {
username, password, ok := r.BasicAuth()
return username, ok && s.htpasswdFile.Validate(username, password)
}
func validateClientCert(r *http.Request) (string, bool) {
if r.TLS != nil {
for _, cert := range r.TLS.PeerCertificates {
username := cert.Subject.CommonName
if username != "" {
return username, true
}
}
}
return "", false
} }
func (s *Server) wrapMetricsAuth(f http.HandlerFunc) http.HandlerFunc { func (s *Server) wrapMetricsAuth(f http.HandlerFunc) http.HandlerFunc {
@ -62,7 +84,7 @@ func NewHandler(server *Server) (http.Handler, error) {
var err error var err error
server.htpasswdFile, err = NewHtpasswdFromFile(filepath.Join(server.Path, ".htpasswd")) server.htpasswdFile, err = NewHtpasswdFromFile(filepath.Join(server.Path, ".htpasswd"))
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err) // return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err)
} }
} }