mirror of
https://github.com/restic/rest-server.git
synced 2025-10-19 07:33:21 +00:00
Add TLS Hot Reload. Fixes #94
This is a very simple implementation. It optionally polls the disk for changes to the key / cert files, and attempts to reload them if it detects the modification time or the size has changed.
This commit is contained in:
parent
eee73d3bc1
commit
3444ebc215
5 changed files with 277 additions and 6 deletions
|
@ -41,7 +41,7 @@ Flags:
|
||||||
--listen string listen address (default ":8000")
|
--listen string listen address (default ":8000")
|
||||||
--log filename write HTTP requests in the combined log format to the specified filename (use "-" for logging to stdout)
|
--log filename write HTTP requests in the combined log format to the specified filename (use "-" for logging to stdout)
|
||||||
--max-size int the maximum size of the repository in bytes
|
--max-size int the maximum size of the repository in bytes
|
||||||
--no-auth disable .htpasswd authentication
|
--no-auth disable authentication
|
||||||
--no-verify-upload do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device
|
--no-verify-upload do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device
|
||||||
--path string data directory (default "/tmp/restic")
|
--path string data directory (default "/tmp/restic")
|
||||||
--private-repos users can only access their private repo
|
--private-repos users can only access their private repo
|
||||||
|
@ -51,6 +51,8 @@ Flags:
|
||||||
--tls turn on TLS support
|
--tls turn on TLS support
|
||||||
--tls-cert string TLS certificate path
|
--tls-cert string TLS certificate path
|
||||||
--tls-key string TLS key path
|
--tls-key string TLS key path
|
||||||
|
--tls-load-dyn dynamically reload TLS key and cert file from disk if they change
|
||||||
|
--tls-load-dyn-poll duration poll at most once per interval when tls-load-dyn is enabled (default 1m0s)
|
||||||
--tls-min-ver string TLS min version, one of (1.2|1.3) (default "1.2")
|
--tls-min-ver string TLS min version, one of (1.2|1.3) (default "1.2")
|
||||||
-v, --version version for rest-server
|
-v, --version version for rest-server
|
||||||
```
|
```
|
||||||
|
|
92
cmd/rest-server/dynamicchecker.go
Normal file
92
cmd/rest-server/dynamicchecker.go
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"io/fs"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dynamicChecker struct {
|
||||||
|
certificate atomic.Pointer[tls.Certificate]
|
||||||
|
keyFile, certFile string
|
||||||
|
keyFileInfo, certFileInfo fs.FileInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDynamicChecker creates a struct that holds the data we need to do
|
||||||
|
// dynamic certificate reloads from disk. If it cannot load the files
|
||||||
|
// or they are invalid, an error is returned. Following a successful
|
||||||
|
// instantiation, the getCertificate method will always return a valid
|
||||||
|
// certificate, and we should call the poll method to check for changes.
|
||||||
|
func newDynamicChecker(certFile, keyFile string) (*dynamicChecker, error) {
|
||||||
|
keyFileInfo, err := os.Stat(keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
certFileInfo, err := os.Stat(certFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
crt, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dc := &dynamicChecker{
|
||||||
|
keyFile: keyFile,
|
||||||
|
certFile: certFile,
|
||||||
|
keyFileInfo: keyFileInfo,
|
||||||
|
certFileInfo: certFileInfo,
|
||||||
|
}
|
||||||
|
dc.certificate.Store(&crt)
|
||||||
|
return dc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCertificate - always returns a valid tls.Certificate and nil error.
|
||||||
|
func (dc *dynamicChecker) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
return dc.certificate.Load(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// poll runs in a goroutine and periodically polls the key and cert for
|
||||||
|
// updates.
|
||||||
|
func (dc *dynamicChecker) poll(ctx context.Context, interval time.Duration) {
|
||||||
|
go func() {
|
||||||
|
t := time.NewTimer(interval)
|
||||||
|
defer t.Stop() // go >= 1.23 means we don't have to check the return
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-t.C:
|
||||||
|
keyFileInfo, err := os.Stat(dc.keyFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("could not stat keyFile %s, using previous cert: %s", dc.keyFile, err)
|
||||||
|
break // select
|
||||||
|
}
|
||||||
|
certFileInfo, err := os.Stat(dc.certFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("could not stat certFile %s, using previous cert: %s", dc.certFile, err)
|
||||||
|
break // select
|
||||||
|
}
|
||||||
|
if !keyFileInfo.ModTime().Equal(dc.keyFileInfo.ModTime()) ||
|
||||||
|
keyFileInfo.Size() != dc.keyFileInfo.Size() ||
|
||||||
|
!certFileInfo.ModTime().Equal(dc.certFileInfo.ModTime()) ||
|
||||||
|
certFileInfo.Size() != dc.certFileInfo.Size() {
|
||||||
|
// they changed on disk, reload
|
||||||
|
crt, err := tls.LoadX509KeyPair(dc.certFile, dc.keyFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("could not load cert and key files, using previous cert: %s", err)
|
||||||
|
break // select
|
||||||
|
}
|
||||||
|
dc.certificate.Store(&crt)
|
||||||
|
dc.certFileInfo = certFileInfo
|
||||||
|
dc.keyFileInfo = keyFileInfo
|
||||||
|
log.Printf("successfully reloaded certificate from disk")
|
||||||
|
}
|
||||||
|
} // end select
|
||||||
|
t.Reset(interval)
|
||||||
|
} // end for
|
||||||
|
}()
|
||||||
|
}
|
150
cmd/rest-server/dynamicchecker_test.go
Normal file
150
cmd/rest-server/dynamicchecker_test.go
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"math/big"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDynamicReload(t *testing.T) {
|
||||||
|
cert, key, err := generateCertFiles()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Logf("created %s and %s files", cert, key)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = os.Remove(cert)
|
||||||
|
_ = os.Remove(key)
|
||||||
|
})
|
||||||
|
err = generateSelfSigned(cert, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
dc, err := newDynamicChecker(cert, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
dc.poll(ctx, time.Second)
|
||||||
|
crt1Raw, err := dc.getCertificate(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
crt1, err := x509.ParseCertificate(crt1Raw.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = generateSelfSigned(cert, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 2)
|
||||||
|
crt2Raw, err := dc.getCertificate(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
crt2, err := x509.ParseCertificate(crt2Raw.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if crt1.SerialNumber.Cmp(crt2.SerialNumber) == 0 {
|
||||||
|
t.Fatal("expected certificate to be different")
|
||||||
|
}
|
||||||
|
t.Logf("cert 1 serial: %s cert 2 serial: %s", crt1.SerialNumber, crt2.SerialNumber)
|
||||||
|
// force a certificate
|
||||||
|
_ = os.Remove(cert)
|
||||||
|
time.Sleep(time.Second * 2)
|
||||||
|
crt3Raw, err := dc.getCertificate(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
crt3, err := x509.ParseCertificate(crt3Raw.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if crt2.SerialNumber.Cmp(crt3.SerialNumber) != 0 {
|
||||||
|
t.Fatal("expected certificate to be certificate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateCertFiles() (cert, key string, err error) {
|
||||||
|
certFile, err := os.CreateTemp("", "cert")
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
cert = certFile.Name()
|
||||||
|
_ = certFile.Close()
|
||||||
|
keyFile, err := os.CreateTemp("", "key")
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
key = keyFile.Name()
|
||||||
|
_ = keyFile.Close()
|
||||||
|
return cert, key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var serial = int64(9000)
|
||||||
|
|
||||||
|
func NextSerial() *big.Int {
|
||||||
|
serial++
|
||||||
|
return big.NewInt(serial)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSelfSigned(certFile, keyFile string) error {
|
||||||
|
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: NextSerial(),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"Widgets Inc"},
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(time.Hour * 24),
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
DNSNames: []string{"localhost"},
|
||||||
|
}
|
||||||
|
certDer, err := x509.CreateCertificate(rand.Reader, template, template, pk.Public(), pk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
keyDer, err := x509.MarshalECPrivateKey(pk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
keyFh, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = keyFh.Close()
|
||||||
|
}()
|
||||||
|
certFh, err := os.OpenFile(certFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = certFh.Close()
|
||||||
|
}()
|
||||||
|
err = pem.Encode(certFh, &pem.Block{Type: "CERTIFICATE", Bytes: certDer})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = pem.Encode(keyFh, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDer})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
restserver "github.com/restic/rest-server"
|
restserver "github.com/restic/rest-server"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
@ -46,9 +47,10 @@ func newRestServerApp() *restServerApp {
|
||||||
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
|
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
|
||||||
},
|
},
|
||||||
Server: restserver.Server{
|
Server: restserver.Server{
|
||||||
Path: filepath.Join(os.TempDir(), "restic"),
|
Path: filepath.Join(os.TempDir(), "restic"),
|
||||||
Listen: ":8000",
|
Listen: ":8000",
|
||||||
TLSMinVer: "1.2",
|
TLSMinVer: "1.2",
|
||||||
|
TLSReloadTime: time.Minute,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
rv.CmdRoot.RunE = rv.runRoot
|
rv.CmdRoot.RunE = rv.runRoot
|
||||||
|
@ -63,6 +65,9 @@ func newRestServerApp() *restServerApp {
|
||||||
flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support")
|
flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support")
|
||||||
flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path")
|
flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path")
|
||||||
flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path")
|
flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path")
|
||||||
|
flags.BoolVar(&rv.Server.TLSDynamicReload, "tls-load-dyn", rv.Server.TLSDynamicReload, "dynamically reload TLS key and cert file from disk if they change")
|
||||||
|
flags.DurationVar(&rv.Server.TLSReloadTime, "tls-load-dyn-poll", rv.Server.TLSReloadTime, "poll at most once per interval when tls-load-dyn is enabled")
|
||||||
|
|
||||||
flags.StringVar(&rv.Server.TLSMinVer, "tls-min-ver", rv.Server.TLSMinVer, "TLS min version, one of (1.2|1.3)")
|
flags.StringVar(&rv.Server.TLSMinVer, "tls-min-ver", rv.Server.TLSMinVer, "TLS min version, one of (1.2|1.3)")
|
||||||
flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable authentication")
|
flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable authentication")
|
||||||
flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
|
flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
|
||||||
|
@ -198,19 +203,38 @@ func (app *restServerApp) runRoot(_ *cobra.Command, _ []string) error {
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("Unsupported TLS min version: %s. Allowed versions are 1.2 or 1.3", app.Server.TLSMinVer)
|
return fmt.Errorf("Unsupported TLS min version: %s. Allowed versions are 1.2 or 1.3", app.Server.TLSMinVer)
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
TLSConfig: tlscfg,
|
TLSConfig: tlscfg,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if enabledTLS {
|
||||||
|
if app.Server.TLSDynamicReload {
|
||||||
|
dc, err := newDynamicChecker(publicKey, privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to load key pair: %w", err)
|
||||||
|
}
|
||||||
|
dc.poll(app.CmdRoot.Context(), app.Server.TLSReloadTime)
|
||||||
|
tlscfg.GetCertificate = dc.getCertificate
|
||||||
|
} else {
|
||||||
|
crt, err := tls.LoadX509KeyPair(publicKey, privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to load key pair: %w", err)
|
||||||
|
}
|
||||||
|
tlscfg.Certificates = []tls.Certificate{crt}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// run server in background
|
// run server in background
|
||||||
go func() {
|
go func() {
|
||||||
if !enabledTLS {
|
if !enabledTLS {
|
||||||
err = srv.Serve(listener)
|
err = srv.Serve(listener)
|
||||||
} else {
|
} else {
|
||||||
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
|
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
|
||||||
err = srv.ServeTLS(listener, publicKey, privateKey)
|
if app.Server.TLSDynamicReload {
|
||||||
|
log.Printf("TLS dynamic reloading enabled, will poll up to once every %s for changes", app.Server.TLSReloadTime)
|
||||||
|
}
|
||||||
|
err = srv.ServeTLS(listener, "", "")
|
||||||
}
|
}
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatalf("listen and serve returned err: %v", err)
|
log.Fatalf("listen and serve returned err: %v", err)
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/restic/rest-server/quota"
|
"github.com/restic/rest-server/quota"
|
||||||
"github.com/restic/rest-server/repo"
|
"github.com/restic/rest-server/repo"
|
||||||
|
@ -24,6 +25,8 @@ type Server struct {
|
||||||
TLSCert string
|
TLSCert string
|
||||||
TLSMinVer string
|
TLSMinVer string
|
||||||
TLS bool
|
TLS bool
|
||||||
|
TLSDynamicReload bool
|
||||||
|
TLSReloadTime time.Duration
|
||||||
NoAuth bool
|
NoAuth bool
|
||||||
ProxyAuthUsername string
|
ProxyAuthUsername string
|
||||||
AppendOnly bool
|
AppendOnly bool
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue