mirror of
https://github.com/restic/rest-server.git
synced 2025-10-19 07:33:21 +00:00
Add TLS Hot Reload. Fixes #94
This is a very simple implementation. It optionally polls the disk for changes to the key / cert files, and attempts to reload them if it detects the modification time or the size has changed.
This commit is contained in:
parent
eee73d3bc1
commit
3444ebc215
5 changed files with 277 additions and 6 deletions
|
@ -41,7 +41,7 @@ Flags:
|
|||
--listen string listen address (default ":8000")
|
||||
--log filename write HTTP requests in the combined log format to the specified filename (use "-" for logging to stdout)
|
||||
--max-size int the maximum size of the repository in bytes
|
||||
--no-auth disable .htpasswd authentication
|
||||
--no-auth disable authentication
|
||||
--no-verify-upload do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device
|
||||
--path string data directory (default "/tmp/restic")
|
||||
--private-repos users can only access their private repo
|
||||
|
@ -51,6 +51,8 @@ Flags:
|
|||
--tls turn on TLS support
|
||||
--tls-cert string TLS certificate path
|
||||
--tls-key string TLS key path
|
||||
--tls-load-dyn dynamically reload TLS key and cert file from disk if they change
|
||||
--tls-load-dyn-poll duration poll at most once per interval when tls-load-dyn is enabled (default 1m0s)
|
||||
--tls-min-ver string TLS min version, one of (1.2|1.3) (default "1.2")
|
||||
-v, --version version for rest-server
|
||||
```
|
||||
|
|
92
cmd/rest-server/dynamicchecker.go
Normal file
92
cmd/rest-server/dynamicchecker.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type dynamicChecker struct {
|
||||
certificate atomic.Pointer[tls.Certificate]
|
||||
keyFile, certFile string
|
||||
keyFileInfo, certFileInfo fs.FileInfo
|
||||
}
|
||||
|
||||
// newDynamicChecker creates a struct that holds the data we need to do
|
||||
// dynamic certificate reloads from disk. If it cannot load the files
|
||||
// or they are invalid, an error is returned. Following a successful
|
||||
// instantiation, the getCertificate method will always return a valid
|
||||
// certificate, and we should call the poll method to check for changes.
|
||||
func newDynamicChecker(certFile, keyFile string) (*dynamicChecker, error) {
|
||||
keyFileInfo, err := os.Stat(keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certFileInfo, err := os.Stat(certFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
crt, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dc := &dynamicChecker{
|
||||
keyFile: keyFile,
|
||||
certFile: certFile,
|
||||
keyFileInfo: keyFileInfo,
|
||||
certFileInfo: certFileInfo,
|
||||
}
|
||||
dc.certificate.Store(&crt)
|
||||
return dc, nil
|
||||
}
|
||||
|
||||
// getCertificate - always returns a valid tls.Certificate and nil error.
|
||||
func (dc *dynamicChecker) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return dc.certificate.Load(), nil
|
||||
}
|
||||
|
||||
// poll runs in a goroutine and periodically polls the key and cert for
|
||||
// updates.
|
||||
func (dc *dynamicChecker) poll(ctx context.Context, interval time.Duration) {
|
||||
go func() {
|
||||
t := time.NewTimer(interval)
|
||||
defer t.Stop() // go >= 1.23 means we don't have to check the return
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
keyFileInfo, err := os.Stat(dc.keyFile)
|
||||
if err != nil {
|
||||
log.Printf("could not stat keyFile %s, using previous cert: %s", dc.keyFile, err)
|
||||
break // select
|
||||
}
|
||||
certFileInfo, err := os.Stat(dc.certFile)
|
||||
if err != nil {
|
||||
log.Printf("could not stat certFile %s, using previous cert: %s", dc.certFile, err)
|
||||
break // select
|
||||
}
|
||||
if !keyFileInfo.ModTime().Equal(dc.keyFileInfo.ModTime()) ||
|
||||
keyFileInfo.Size() != dc.keyFileInfo.Size() ||
|
||||
!certFileInfo.ModTime().Equal(dc.certFileInfo.ModTime()) ||
|
||||
certFileInfo.Size() != dc.certFileInfo.Size() {
|
||||
// they changed on disk, reload
|
||||
crt, err := tls.LoadX509KeyPair(dc.certFile, dc.keyFile)
|
||||
if err != nil {
|
||||
log.Printf("could not load cert and key files, using previous cert: %s", err)
|
||||
break // select
|
||||
}
|
||||
dc.certificate.Store(&crt)
|
||||
dc.certFileInfo = certFileInfo
|
||||
dc.keyFileInfo = keyFileInfo
|
||||
log.Printf("successfully reloaded certificate from disk")
|
||||
}
|
||||
} // end select
|
||||
t.Reset(interval)
|
||||
} // end for
|
||||
}()
|
||||
}
|
150
cmd/rest-server/dynamicchecker_test.go
Normal file
150
cmd/rest-server/dynamicchecker_test.go
Normal file
|
@ -0,0 +1,150 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDynamicReload(t *testing.T) {
|
||||
cert, key, err := generateCertFiles()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("created %s and %s files", cert, key)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove(cert)
|
||||
_ = os.Remove(key)
|
||||
})
|
||||
err = generateSelfSigned(cert, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dc, err := newDynamicChecker(cert, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
dc.poll(ctx, time.Second)
|
||||
crt1Raw, err := dc.getCertificate(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
crt1, err := x509.ParseCertificate(crt1Raw.Certificate[0])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = generateSelfSigned(cert, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(time.Second * 2)
|
||||
crt2Raw, err := dc.getCertificate(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
crt2, err := x509.ParseCertificate(crt2Raw.Certificate[0])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if crt1.SerialNumber.Cmp(crt2.SerialNumber) == 0 {
|
||||
t.Fatal("expected certificate to be different")
|
||||
}
|
||||
t.Logf("cert 1 serial: %s cert 2 serial: %s", crt1.SerialNumber, crt2.SerialNumber)
|
||||
// force a certificate
|
||||
_ = os.Remove(cert)
|
||||
time.Sleep(time.Second * 2)
|
||||
crt3Raw, err := dc.getCertificate(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
crt3, err := x509.ParseCertificate(crt3Raw.Certificate[0])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if crt2.SerialNumber.Cmp(crt3.SerialNumber) != 0 {
|
||||
t.Fatal("expected certificate to be certificate")
|
||||
}
|
||||
}
|
||||
|
||||
func generateCertFiles() (cert, key string, err error) {
|
||||
certFile, err := os.CreateTemp("", "cert")
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
cert = certFile.Name()
|
||||
_ = certFile.Close()
|
||||
keyFile, err := os.CreateTemp("", "key")
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
key = keyFile.Name()
|
||||
_ = keyFile.Close()
|
||||
return cert, key, nil
|
||||
}
|
||||
|
||||
var serial = int64(9000)
|
||||
|
||||
func NextSerial() *big.Int {
|
||||
serial++
|
||||
return big.NewInt(serial)
|
||||
}
|
||||
|
||||
func generateSelfSigned(certFile, keyFile string) error {
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: NextSerial(),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Widgets Inc"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour * 24),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
certDer, err := x509.CreateCertificate(rand.Reader, template, template, pk.Public(), pk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyDer, err := x509.MarshalECPrivateKey(pk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyFh, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = keyFh.Close()
|
||||
}()
|
||||
certFh, err := os.OpenFile(certFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = certFh.Close()
|
||||
}()
|
||||
err = pem.Encode(certFh, &pem.Block{Type: "CERTIFICATE", Bytes: certDer})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = pem.Encode(keyFh, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDer})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -15,6 +15,7 @@ import (
|
|||
"runtime/pprof"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
restserver "github.com/restic/rest-server"
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -46,9 +47,10 @@ func newRestServerApp() *restServerApp {
|
|||
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
|
||||
},
|
||||
Server: restserver.Server{
|
||||
Path: filepath.Join(os.TempDir(), "restic"),
|
||||
Listen: ":8000",
|
||||
TLSMinVer: "1.2",
|
||||
Path: filepath.Join(os.TempDir(), "restic"),
|
||||
Listen: ":8000",
|
||||
TLSMinVer: "1.2",
|
||||
TLSReloadTime: time.Minute,
|
||||
},
|
||||
}
|
||||
rv.CmdRoot.RunE = rv.runRoot
|
||||
|
@ -63,6 +65,9 @@ func newRestServerApp() *restServerApp {
|
|||
flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support")
|
||||
flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path")
|
||||
flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path")
|
||||
flags.BoolVar(&rv.Server.TLSDynamicReload, "tls-load-dyn", rv.Server.TLSDynamicReload, "dynamically reload TLS key and cert file from disk if they change")
|
||||
flags.DurationVar(&rv.Server.TLSReloadTime, "tls-load-dyn-poll", rv.Server.TLSReloadTime, "poll at most once per interval when tls-load-dyn is enabled")
|
||||
|
||||
flags.StringVar(&rv.Server.TLSMinVer, "tls-min-ver", rv.Server.TLSMinVer, "TLS min version, one of (1.2|1.3)")
|
||||
flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable authentication")
|
||||
flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
|
||||
|
@ -198,19 +203,38 @@ func (app *restServerApp) runRoot(_ *cobra.Command, _ []string) error {
|
|||
default:
|
||||
return fmt.Errorf("Unsupported TLS min version: %s. Allowed versions are 1.2 or 1.3", app.Server.TLSMinVer)
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
Handler: handler,
|
||||
TLSConfig: tlscfg,
|
||||
}
|
||||
|
||||
if enabledTLS {
|
||||
if app.Server.TLSDynamicReload {
|
||||
dc, err := newDynamicChecker(publicKey, privateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to load key pair: %w", err)
|
||||
}
|
||||
dc.poll(app.CmdRoot.Context(), app.Server.TLSReloadTime)
|
||||
tlscfg.GetCertificate = dc.getCertificate
|
||||
} else {
|
||||
crt, err := tls.LoadX509KeyPair(publicKey, privateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to load key pair: %w", err)
|
||||
}
|
||||
tlscfg.Certificates = []tls.Certificate{crt}
|
||||
}
|
||||
}
|
||||
|
||||
// run server in background
|
||||
go func() {
|
||||
if !enabledTLS {
|
||||
err = srv.Serve(listener)
|
||||
} else {
|
||||
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
|
||||
err = srv.ServeTLS(listener, publicKey, privateKey)
|
||||
if app.Server.TLSDynamicReload {
|
||||
log.Printf("TLS dynamic reloading enabled, will poll up to once every %s for changes", app.Server.TLSReloadTime)
|
||||
}
|
||||
err = srv.ServeTLS(listener, "", "")
|
||||
}
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("listen and serve returned err: %v", err)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/restic/rest-server/quota"
|
||||
"github.com/restic/rest-server/repo"
|
||||
|
@ -24,6 +25,8 @@ type Server struct {
|
|||
TLSCert string
|
||||
TLSMinVer string
|
||||
TLS bool
|
||||
TLSDynamicReload bool
|
||||
TLSReloadTime time.Duration
|
||||
NoAuth bool
|
||||
ProxyAuthUsername string
|
||||
AppendOnly bool
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue