mirror of
https://github.com/restic/rest-server.git
synced 2025-10-19 07:33:21 +00:00

This is a very simple implementation. It optionally polls the disk for changes to the key / cert files, and attempts to reload them if it detects the modification time or the size has changed.
264 lines
8.5 KiB
Go
264 lines
8.5 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"runtime"
|
|
"runtime/pprof"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
restserver "github.com/restic/rest-server"
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
type restServerApp struct {
|
|
CmdRoot *cobra.Command
|
|
Server restserver.Server
|
|
CPUProfile string
|
|
|
|
listenerAddressMu sync.Mutex
|
|
listenerAddress net.Addr // set after startup
|
|
}
|
|
|
|
// cmdRoot is the base command when no other command has been specified.
|
|
func newRestServerApp() *restServerApp {
|
|
rv := &restServerApp{
|
|
CmdRoot: &cobra.Command{
|
|
Use: "rest-server",
|
|
Short: "Run a REST server for use with restic",
|
|
SilenceErrors: true,
|
|
SilenceUsage: true,
|
|
Args: func(_ *cobra.Command, args []string) error {
|
|
if len(args) != 0 {
|
|
return fmt.Errorf("rest-server expects no arguments - unknown argument: %s", args[0])
|
|
}
|
|
return nil
|
|
},
|
|
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
|
|
},
|
|
Server: restserver.Server{
|
|
Path: filepath.Join(os.TempDir(), "restic"),
|
|
Listen: ":8000",
|
|
TLSMinVer: "1.2",
|
|
TLSReloadTime: time.Minute,
|
|
},
|
|
}
|
|
rv.CmdRoot.RunE = rv.runRoot
|
|
flags := rv.CmdRoot.Flags()
|
|
|
|
flags.StringVar(&rv.CPUProfile, "cpu-profile", rv.CPUProfile, "write CPU profile to file")
|
|
flags.BoolVar(&rv.Server.Debug, "debug", rv.Server.Debug, "output debug messages")
|
|
flags.StringVar(&rv.Server.Listen, "listen", rv.Server.Listen, "listen address")
|
|
flags.StringVar(&rv.Server.Log, "log", rv.Server.Log, "write HTTP requests in the combined log format to the specified `filename` (use \"-\" for logging to stdout)")
|
|
flags.Int64Var(&rv.Server.MaxRepoSize, "max-size", rv.Server.MaxRepoSize, "the maximum size of the repository in bytes")
|
|
flags.StringVar(&rv.Server.Path, "path", rv.Server.Path, "data directory")
|
|
flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support")
|
|
flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path")
|
|
flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path")
|
|
flags.BoolVar(&rv.Server.TLSDynamicReload, "tls-load-dyn", rv.Server.TLSDynamicReload, "dynamically reload TLS key and cert file from disk if they change")
|
|
flags.DurationVar(&rv.Server.TLSReloadTime, "tls-load-dyn-poll", rv.Server.TLSReloadTime, "poll at most once per interval when tls-load-dyn is enabled")
|
|
|
|
flags.StringVar(&rv.Server.TLSMinVer, "tls-min-ver", rv.Server.TLSMinVer, "TLS min version, one of (1.2|1.3)")
|
|
flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable authentication")
|
|
flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
|
|
flags.StringVar(&rv.Server.ProxyAuthUsername, "proxy-auth-username", rv.Server.ProxyAuthUsername, "specifies the HTTP header containing the username for proxy-based authentication")
|
|
flags.BoolVar(&rv.Server.NoVerifyUpload, "no-verify-upload", rv.Server.NoVerifyUpload,
|
|
"do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device")
|
|
flags.BoolVar(&rv.Server.AppendOnly, "append-only", rv.Server.AppendOnly, "enable append only mode")
|
|
flags.BoolVar(&rv.Server.PrivateRepos, "private-repos", rv.Server.PrivateRepos, "users can only access their private repo")
|
|
flags.BoolVar(&rv.Server.Prometheus, "prometheus", rv.Server.Prometheus, "enable Prometheus metrics")
|
|
flags.BoolVar(&rv.Server.PrometheusNoAuth, "prometheus-no-auth", rv.Server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint")
|
|
flags.BoolVar(&rv.Server.GroupAccessibleRepos, "group-accessible-repos", rv.Server.GroupAccessibleRepos, "let filesystem group be able to access repo files")
|
|
|
|
return rv
|
|
}
|
|
|
|
var version = "0.13.0"
|
|
|
|
func (app *restServerApp) tlsSettings() (bool, string, string, error) {
|
|
var key, cert string
|
|
if !app.Server.TLS && (app.Server.TLSKey != "" || app.Server.TLSCert != "") {
|
|
return false, "", "", errors.New("requires enabled TLS")
|
|
} else if !app.Server.TLS {
|
|
return false, "", "", nil
|
|
}
|
|
if app.Server.TLSKey != "" {
|
|
key = app.Server.TLSKey
|
|
} else {
|
|
key = filepath.Join(app.Server.Path, "private_key")
|
|
}
|
|
if app.Server.TLSCert != "" {
|
|
cert = app.Server.TLSCert
|
|
} else {
|
|
cert = filepath.Join(app.Server.Path, "public_key")
|
|
}
|
|
return app.Server.TLS, key, cert, nil
|
|
}
|
|
|
|
// returns the address that the app is listening on.
|
|
// returns nil if the application hasn't finished starting yet
|
|
func (app *restServerApp) ListenerAddress() net.Addr {
|
|
app.listenerAddressMu.Lock()
|
|
defer app.listenerAddressMu.Unlock()
|
|
return app.listenerAddress
|
|
}
|
|
|
|
func (app *restServerApp) runRoot(_ *cobra.Command, _ []string) error {
|
|
log.SetFlags(0)
|
|
|
|
log.Printf("Data directory: %s", app.Server.Path)
|
|
|
|
if app.CPUProfile != "" {
|
|
f, err := os.Create(app.CPUProfile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
_ = f.Close()
|
|
}()
|
|
|
|
if err := pprof.StartCPUProfile(f); err != nil {
|
|
return err
|
|
}
|
|
defer pprof.StopCPUProfile()
|
|
|
|
log.Println("CPU profiling enabled")
|
|
defer log.Println("Stopped CPU profiling")
|
|
}
|
|
|
|
if app.Server.NoAuth {
|
|
log.Println("Authentication disabled")
|
|
} else {
|
|
if app.Server.ProxyAuthUsername == "" {
|
|
log.Println("Authentication enabled")
|
|
} else {
|
|
log.Println("Proxy Authentication enabled.")
|
|
}
|
|
}
|
|
|
|
handler, err := restserver.NewHandler(&app.Server)
|
|
if err != nil {
|
|
log.Fatalf("error: %v", err)
|
|
}
|
|
|
|
if app.Server.AppendOnly {
|
|
log.Println("Append only mode enabled")
|
|
} else {
|
|
log.Println("Append only mode disabled")
|
|
}
|
|
|
|
if app.Server.PrivateRepos {
|
|
log.Println("Private repositories enabled")
|
|
} else {
|
|
log.Println("Private repositories disabled")
|
|
}
|
|
|
|
if app.Server.GroupAccessibleRepos {
|
|
log.Println("Group accessible repos enabled")
|
|
} else {
|
|
log.Println("Group accessible repos disabled")
|
|
}
|
|
|
|
enabledTLS, privateKey, publicKey, err := app.tlsSettings()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
listener, err := findListener(app.Server.Listen)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to listen: %w", err)
|
|
}
|
|
|
|
// set listener address, this is useful for tests
|
|
app.listenerAddressMu.Lock()
|
|
app.listenerAddress = listener.Addr()
|
|
app.listenerAddressMu.Unlock()
|
|
|
|
tlscfg := &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
CipherSuites: []uint16{
|
|
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
|
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
|
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
|
},
|
|
}
|
|
switch app.Server.TLSMinVer {
|
|
case "1.2":
|
|
tlscfg.MinVersion = tls.VersionTLS12
|
|
case "1.3":
|
|
tlscfg.MinVersion = tls.VersionTLS13
|
|
default:
|
|
return fmt.Errorf("Unsupported TLS min version: %s. Allowed versions are 1.2 or 1.3", app.Server.TLSMinVer)
|
|
}
|
|
srv := &http.Server{
|
|
Handler: handler,
|
|
TLSConfig: tlscfg,
|
|
}
|
|
|
|
if enabledTLS {
|
|
if app.Server.TLSDynamicReload {
|
|
dc, err := newDynamicChecker(publicKey, privateKey)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to load key pair: %w", err)
|
|
}
|
|
dc.poll(app.CmdRoot.Context(), app.Server.TLSReloadTime)
|
|
tlscfg.GetCertificate = dc.getCertificate
|
|
} else {
|
|
crt, err := tls.LoadX509KeyPair(publicKey, privateKey)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to load key pair: %w", err)
|
|
}
|
|
tlscfg.Certificates = []tls.Certificate{crt}
|
|
}
|
|
}
|
|
|
|
// run server in background
|
|
go func() {
|
|
if !enabledTLS {
|
|
err = srv.Serve(listener)
|
|
} else {
|
|
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
|
|
if app.Server.TLSDynamicReload {
|
|
log.Printf("TLS dynamic reloading enabled, will poll up to once every %s for changes", app.Server.TLSReloadTime)
|
|
}
|
|
err = srv.ServeTLS(listener, "", "")
|
|
}
|
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
log.Fatalf("listen and serve returned err: %v", err)
|
|
}
|
|
}()
|
|
|
|
// wait until done
|
|
<-app.CmdRoot.Context().Done()
|
|
|
|
// gracefully shutdown server
|
|
if err := srv.Shutdown(context.Background()); err != nil {
|
|
return fmt.Errorf("server shutdown returned an err: %w", err)
|
|
}
|
|
|
|
log.Println("shutdown cleanly")
|
|
return nil
|
|
}
|
|
|
|
func main() {
|
|
// create context to be notified on interrupt or term signal so that we can shutdown cleanly
|
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
|
defer stop()
|
|
|
|
if err := newRestServerApp().CmdRoot.ExecuteContext(ctx); err != nil {
|
|
log.Fatalf("error: %v", err)
|
|
}
|
|
}
|