This commit is contained in:
textaligncenter 2022-04-24 20:16:06 +00:00 committed by GitHub
commit 713259af6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 100 additions and 34 deletions

View file

@ -14,3 +14,4 @@ Wayne Scott <wsc9tt@gmail.com>
Zlatko Čalušić <zcalusic@bitsync.net>
cgonzalez <chgonzalezg@gmail.com>
n0npax <marcin@niemira.net>
textaligncenter <67056612+textaligncenter@users.noreply.github.com>

View file

@ -1,8 +1,11 @@
package main
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
@ -46,6 +49,8 @@ func init() {
flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support")
flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path")
flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path")
flags.BoolVar(&server.MTLS, "mtls", server.MTLS, "turn on client certificate support")
flags.StringVar(&server.CACert, "cacert", server.CACert, "mTLS CA certificate path")
flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication")
flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload,
"do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device")
@ -57,12 +62,12 @@ func init() {
var version = "0.11.0"
func tlsSettings() (bool, string, string, error) {
func tlsSettings() (bool, bool, string, string, string, error) {
var key, cert string
if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") {
return false, "", "", errors.New("requires enabled TLS")
if (!server.TLS && !server.MTLS) && (server.TLSKey != "" || server.TLSCert != "") {
return false, false, "", "", "", errors.New("requires enabled TLS or mTLS")
} else if !server.TLS {
return false, "", "", nil
return false, false, "", "", "", nil
}
if server.TLSKey != "" {
key = server.TLSKey
@ -74,7 +79,11 @@ func tlsSettings() (bool, string, string, error) {
} else {
cert = filepath.Join(server.Path, "public_key")
}
return server.TLS, key, cert, nil
if server.MTLS && server.CACert == "" {
return false, false, "", "", "", errors.New("missing cacert")
}
return server.TLS, server.MTLS, server.CACert, key, cert, nil
}
func runRoot(cmd *cobra.Command, args []string) error {
@ -125,7 +134,7 @@ func runRoot(cmd *cobra.Command, args []string) error {
log.Println("Private repositories disabled")
}
enabledTLS, privateKey, publicKey, err := tlsSettings()
enabledTLS, enabledMTLS, caCert, privateKey, publicKey, err := tlsSettings()
if err != nil {
return err
}
@ -137,10 +146,33 @@ func runRoot(cmd *cobra.Command, args []string) error {
if !enabledTLS {
err = http.Serve(listener, handler)
} else {
if enabledMTLS {
log.Printf("mTLS enabled, private key %s, pubkey %s, cacert %s", privateKey, publicKey, caCert)
caCertPool := x509.NewCertPool()
caCertPem, err := ioutil.ReadFile(caCert)
if err != nil {
return errors.New("unable to read cacert")
}
caCertPool.AppendCertsFromPEM(caCertPem)
tlsConfig := &tls.Config{
ClientCAs: caCertPool,
ClientAuth: tls.VerifyClientCertIfGiven,
}
tlsConfig.BuildNameToCertificate()
server := &http.Server{
Addr: server.Listen,
TLSConfig: tlsConfig,
}
server.Handler = handler
err = server.ServeTLS(listener, publicKey, privateKey)
} else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = http.ServeTLS(listener, handler, publicKey, privateKey)
}
}
return err
}

View file

@ -13,6 +13,7 @@ func TestTLSSettings(t *testing.T) {
type expected struct {
TLSKey string
TLSCert string
CAcert string
Error bool
}
type passed struct {
@ -20,30 +21,29 @@ func TestTLSSettings(t *testing.T) {
TLS bool
TLSKey string
TLSCert string
MTLS bool
CACert string
}
var tests = []struct {
passed passed
expected expected
}{
{passed{TLS: false}, expected{"", "", false}},
{passed{TLS: true}, expected{
filepath.Join(os.TempDir(), "restic/private_key"),
filepath.Join(os.TempDir(), "restic/public_key"),
false,
}},
{passed{
Path: os.TempDir(),
TLS: true,
}, expected{
filepath.Join(os.TempDir(), "private_key"),
filepath.Join(os.TempDir(), "public_key"),
false,
}},
{passed{Path: os.TempDir(), TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}},
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}},
{passed{Path: os.TempDir(), TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
{passed{TLS: false}, expected{"", "", "", false}},
{passed{TLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}},
{passed{Path: "/tmp", TLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}},
{passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}},
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}},
{passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
{passed{TLS: false, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "/etc/restic/cacert", false}},
{passed{TLS: true, MTLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}},
{passed{Path: "/tmp", TLS: true, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}},
{passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert", MTLS: true}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}},
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}},
{passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
}
for _, test := range tests {
@ -57,7 +57,7 @@ func TestTLSSettings(t *testing.T) {
server.TLSKey = test.passed.TLSKey
server.TLSCert = test.passed.TLSCert
gotTLS, gotKey, gotCert, err := tlsSettings()
gotTLS, gotMTLS, gotCAcert, gotKey, gotCert, err := tlsSettings()
if err != nil && !test.expected.Error {
t.Fatalf("tls_settings returned err (%v)", err)
}
@ -71,6 +71,7 @@ func TestTLSSettings(t *testing.T) {
if gotTLS != test.passed.TLS {
t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS)
}
wantKey := test.expected.TLSKey
if gotKey != wantKey {
t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey)
@ -81,6 +82,14 @@ func TestTLSSettings(t *testing.T) {
t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert)
}
if gotMTLS != test.passed.MTLS {
t.Errorf("mTLS enabled, want (%v), got (%v)", test.passed.MTLS, gotMTLS)
}
wantCAcert := test.expected.CAcert
if gotCAcert != wantCAcert {
t.Errorf("wrong CACertPath path, want (%v), got (%v)", wantCAcert, gotCAcert)
}
})
}
}

View file

@ -18,9 +18,11 @@ type Server struct {
Listen string
Log string
CPUProfile string
CACert string
TLSKey string
TLSCert string
TLS bool
MTLS bool
NoAuth bool
AppendOnly bool
PrivateRepos bool

36
mux.go
View file

@ -1,7 +1,6 @@
package restserver
import (
"fmt"
"log"
"net/http"
"os"
@ -33,12 +32,35 @@ func (s *Server) checkAuth(r *http.Request) (username string, ok bool) {
if s.NoAuth {
return username, true
}
var password string
username, password, ok = r.BasicAuth()
if !ok || !s.htpasswdFile.Validate(username, password) {
return "", false
}
username, ok = s.validateBasicAuth(r)
if ok {
return username, true
}
username, ok = validateClientCert(r)
if ok {
return username, true
}
return username, false
}
func (s *Server) validateBasicAuth(r *http.Request) (string, bool) {
username, password, ok := r.BasicAuth()
return username, ok && s.htpasswdFile.Validate(username, password)
}
func validateClientCert(r *http.Request) (string, bool) {
if r.TLS != nil {
for _, cert := range r.TLS.PeerCertificates {
username := cert.Subject.CommonName
if username != "" {
return username, true
}
}
}
return "", false
}
func (s *Server) wrapMetricsAuth(f http.HandlerFunc) http.HandlerFunc {
@ -62,7 +84,7 @@ func NewHandler(server *Server) (http.Handler, error) {
var err error
server.htpasswdFile, err = NewHtpasswdFromFile(filepath.Join(server.Path, ".htpasswd"))
if err != nil {
return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err)
// return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err)
}
}