Add TLS Hot Reload. Fixes #94

This is a very simple implementation.  It optionally polls the disk
for changes to the key / cert files, and attempts to reload them
if it detects the modification time or the size has changed.
This commit is contained in:
Nathan Johnson 2025-04-14 18:25:27 -05:00
parent eee73d3bc1
commit 3444ebc215
5 changed files with 277 additions and 6 deletions

View file

@ -41,7 +41,7 @@ Flags:
--listen string listen address (default ":8000") --listen string listen address (default ":8000")
--log filename write HTTP requests in the combined log format to the specified filename (use "-" for logging to stdout) --log filename write HTTP requests in the combined log format to the specified filename (use "-" for logging to stdout)
--max-size int the maximum size of the repository in bytes --max-size int the maximum size of the repository in bytes
--no-auth disable .htpasswd authentication --no-auth disable authentication
--no-verify-upload do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device --no-verify-upload do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device
--path string data directory (default "/tmp/restic") --path string data directory (default "/tmp/restic")
--private-repos users can only access their private repo --private-repos users can only access their private repo
@ -51,6 +51,8 @@ Flags:
--tls turn on TLS support --tls turn on TLS support
--tls-cert string TLS certificate path --tls-cert string TLS certificate path
--tls-key string TLS key path --tls-key string TLS key path
--tls-load-dyn dynamically reload TLS key and cert file from disk if they change
--tls-load-dyn-poll duration poll at most once per interval when tls-load-dyn is enabled (default 1m0s)
--tls-min-ver string TLS min version, one of (1.2|1.3) (default "1.2") --tls-min-ver string TLS min version, one of (1.2|1.3) (default "1.2")
-v, --version version for rest-server -v, --version version for rest-server
``` ```

View file

@ -0,0 +1,92 @@
package main
import (
"context"
"crypto/tls"
"io/fs"
"log"
"os"
"sync/atomic"
"time"
)
type dynamicChecker struct {
certificate atomic.Pointer[tls.Certificate]
keyFile, certFile string
keyFileInfo, certFileInfo fs.FileInfo
}
// newDynamicChecker creates a struct that holds the data we need to do
// dynamic certificate reloads from disk. If it cannot load the files
// or they are invalid, an error is returned. Following a successful
// instantiation, the getCertificate method will always return a valid
// certificate, and we should call the poll method to check for changes.
func newDynamicChecker(certFile, keyFile string) (*dynamicChecker, error) {
keyFileInfo, err := os.Stat(keyFile)
if err != nil {
return nil, err
}
certFileInfo, err := os.Stat(certFile)
if err != nil {
return nil, err
}
crt, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
dc := &dynamicChecker{
keyFile: keyFile,
certFile: certFile,
keyFileInfo: keyFileInfo,
certFileInfo: certFileInfo,
}
dc.certificate.Store(&crt)
return dc, nil
}
// getCertificate - always returns a valid tls.Certificate and nil error.
func (dc *dynamicChecker) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
return dc.certificate.Load(), nil
}
// poll runs in a goroutine and periodically polls the key and cert for
// updates.
func (dc *dynamicChecker) poll(ctx context.Context, interval time.Duration) {
go func() {
t := time.NewTimer(interval)
defer t.Stop() // go >= 1.23 means we don't have to check the return
for {
select {
case <-ctx.Done():
return
case <-t.C:
keyFileInfo, err := os.Stat(dc.keyFile)
if err != nil {
log.Printf("could not stat keyFile %s, using previous cert: %s", dc.keyFile, err)
break // select
}
certFileInfo, err := os.Stat(dc.certFile)
if err != nil {
log.Printf("could not stat certFile %s, using previous cert: %s", dc.certFile, err)
break // select
}
if !keyFileInfo.ModTime().Equal(dc.keyFileInfo.ModTime()) ||
keyFileInfo.Size() != dc.keyFileInfo.Size() ||
!certFileInfo.ModTime().Equal(dc.certFileInfo.ModTime()) ||
certFileInfo.Size() != dc.certFileInfo.Size() {
// they changed on disk, reload
crt, err := tls.LoadX509KeyPair(dc.certFile, dc.keyFile)
if err != nil {
log.Printf("could not load cert and key files, using previous cert: %s", err)
break // select
}
dc.certificate.Store(&crt)
dc.certFileInfo = certFileInfo
dc.keyFileInfo = keyFileInfo
log.Printf("successfully reloaded certificate from disk")
}
} // end select
t.Reset(interval)
} // end for
}()
}

View file

@ -0,0 +1,150 @@
package main
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"testing"
"time"
)
func TestDynamicReload(t *testing.T) {
cert, key, err := generateCertFiles()
if err != nil {
t.Fatal(err)
}
t.Logf("created %s and %s files", cert, key)
t.Cleanup(func() {
_ = os.Remove(cert)
_ = os.Remove(key)
})
err = generateSelfSigned(cert, key)
if err != nil {
t.Fatal(err)
}
dc, err := newDynamicChecker(cert, key)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
dc.poll(ctx, time.Second)
crt1Raw, err := dc.getCertificate(nil)
if err != nil {
t.Fatal(err)
}
crt1, err := x509.ParseCertificate(crt1Raw.Certificate[0])
if err != nil {
t.Fatal(err)
}
err = generateSelfSigned(cert, key)
if err != nil {
t.Fatal(err)
}
time.Sleep(time.Second * 2)
crt2Raw, err := dc.getCertificate(nil)
if err != nil {
t.Fatal(err)
}
crt2, err := x509.ParseCertificate(crt2Raw.Certificate[0])
if err != nil {
t.Fatal(err)
}
if crt1.SerialNumber.Cmp(crt2.SerialNumber) == 0 {
t.Fatal("expected certificate to be different")
}
t.Logf("cert 1 serial: %s cert 2 serial: %s", crt1.SerialNumber, crt2.SerialNumber)
// force a certificate
_ = os.Remove(cert)
time.Sleep(time.Second * 2)
crt3Raw, err := dc.getCertificate(nil)
if err != nil {
t.Fatal(err)
}
crt3, err := x509.ParseCertificate(crt3Raw.Certificate[0])
if err != nil {
t.Fatal(err)
}
if crt2.SerialNumber.Cmp(crt3.SerialNumber) != 0 {
t.Fatal("expected certificate to be certificate")
}
}
func generateCertFiles() (cert, key string, err error) {
certFile, err := os.CreateTemp("", "cert")
if err != nil {
return "", "", err
}
cert = certFile.Name()
_ = certFile.Close()
keyFile, err := os.CreateTemp("", "key")
if err != nil {
return "", "", err
}
key = keyFile.Name()
_ = keyFile.Close()
return cert, key, nil
}
var serial = int64(9000)
func NextSerial() *big.Int {
serial++
return big.NewInt(serial)
}
func generateSelfSigned(certFile, keyFile string) error {
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return err
}
template := &x509.Certificate{
SerialNumber: NextSerial(),
Subject: pkix.Name{
Organization: []string{"Widgets Inc"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
}
certDer, err := x509.CreateCertificate(rand.Reader, template, template, pk.Public(), pk)
if err != nil {
return err
}
keyDer, err := x509.MarshalECPrivateKey(pk)
if err != nil {
return err
}
keyFh, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer func() {
_ = keyFh.Close()
}()
certFh, err := os.OpenFile(certFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer func() {
_ = certFh.Close()
}()
err = pem.Encode(certFh, &pem.Block{Type: "CERTIFICATE", Bytes: certDer})
if err != nil {
return err
}
err = pem.Encode(keyFh, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDer})
if err != nil {
return err
}
return nil
}

View file

@ -15,6 +15,7 @@ import (
"runtime/pprof" "runtime/pprof"
"sync" "sync"
"syscall" "syscall"
"time"
restserver "github.com/restic/rest-server" restserver "github.com/restic/rest-server"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -46,9 +47,10 @@ func newRestServerApp() *restServerApp {
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH), Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
}, },
Server: restserver.Server{ Server: restserver.Server{
Path: filepath.Join(os.TempDir(), "restic"), Path: filepath.Join(os.TempDir(), "restic"),
Listen: ":8000", Listen: ":8000",
TLSMinVer: "1.2", TLSMinVer: "1.2",
TLSReloadTime: time.Minute,
}, },
} }
rv.CmdRoot.RunE = rv.runRoot rv.CmdRoot.RunE = rv.runRoot
@ -63,6 +65,9 @@ func newRestServerApp() *restServerApp {
flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support") flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support")
flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path") flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path")
flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path") flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path")
flags.BoolVar(&rv.Server.TLSDynamicReload, "tls-load-dyn", rv.Server.TLSDynamicReload, "dynamically reload TLS key and cert file from disk if they change")
flags.DurationVar(&rv.Server.TLSReloadTime, "tls-load-dyn-poll", rv.Server.TLSReloadTime, "poll at most once per interval when tls-load-dyn is enabled")
flags.StringVar(&rv.Server.TLSMinVer, "tls-min-ver", rv.Server.TLSMinVer, "TLS min version, one of (1.2|1.3)") flags.StringVar(&rv.Server.TLSMinVer, "tls-min-ver", rv.Server.TLSMinVer, "TLS min version, one of (1.2|1.3)")
flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable authentication") flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable authentication")
flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"") flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
@ -198,19 +203,38 @@ func (app *restServerApp) runRoot(_ *cobra.Command, _ []string) error {
default: default:
return fmt.Errorf("Unsupported TLS min version: %s. Allowed versions are 1.2 or 1.3", app.Server.TLSMinVer) return fmt.Errorf("Unsupported TLS min version: %s. Allowed versions are 1.2 or 1.3", app.Server.TLSMinVer)
} }
srv := &http.Server{ srv := &http.Server{
Handler: handler, Handler: handler,
TLSConfig: tlscfg, TLSConfig: tlscfg,
} }
if enabledTLS {
if app.Server.TLSDynamicReload {
dc, err := newDynamicChecker(publicKey, privateKey)
if err != nil {
return fmt.Errorf("unable to load key pair: %w", err)
}
dc.poll(app.CmdRoot.Context(), app.Server.TLSReloadTime)
tlscfg.GetCertificate = dc.getCertificate
} else {
crt, err := tls.LoadX509KeyPair(publicKey, privateKey)
if err != nil {
return fmt.Errorf("unable to load key pair: %w", err)
}
tlscfg.Certificates = []tls.Certificate{crt}
}
}
// run server in background // run server in background
go func() { go func() {
if !enabledTLS { if !enabledTLS {
err = srv.Serve(listener) err = srv.Serve(listener)
} else { } else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey) log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = srv.ServeTLS(listener, publicKey, privateKey) if app.Server.TLSDynamicReload {
log.Printf("TLS dynamic reloading enabled, will poll up to once every %s for changes", app.Server.TLSReloadTime)
}
err = srv.ServeTLS(listener, "", "")
} }
if err != nil && !errors.Is(err, http.ErrServerClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("listen and serve returned err: %v", err) log.Fatalf("listen and serve returned err: %v", err)

View file

@ -8,6 +8,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"time"
"github.com/restic/rest-server/quota" "github.com/restic/rest-server/quota"
"github.com/restic/rest-server/repo" "github.com/restic/rest-server/repo"
@ -24,6 +25,8 @@ type Server struct {
TLSCert string TLSCert string
TLSMinVer string TLSMinVer string
TLS bool TLS bool
TLSDynamicReload bool
TLSReloadTime time.Duration
NoAuth bool NoAuth bool
ProxyAuthUsername string ProxyAuthUsername string
AppendOnly bool AppendOnly bool