diff --git a/README.md b/README.md index cc8dd63..fa69edd 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,9 @@ Flags: --path string data directory (default "/tmp/restic") --prometheus enable Prometheus metrics --tls turn on TLS support + --tls-cert string TLS certificate path + --tls-key string TLS key path + ``` By default the server persists backup data in `/tmp/restic`. Start the server with a custom persistence directory: @@ -88,7 +91,7 @@ The server uses an `.htpasswd` file to specify users. You can create such a fil htpasswd -s -c .htpasswd username ``` -By default the server uses HTTP protocol. This is not very secure since with Basic Authentication, username and passwords will travel in cleartext in every request. In order to enable TLS support just add the `-tls` argument and add a private and public key at the root of your persistence directory. +By default the server uses HTTP protocol. This is not very secure since with Basic Authentication, username and passwords will travel in cleartext in every request. In order to enable TLS support just add the `-tls` argument and add a private and public key at the root of your persistence directory. You may also specify private and public keys by --tls-cert and --tls-key Signed certificate is required by the restic backend, but if you just want to test the feature you can generate unsigned keys with the following commands: diff --git a/cmd/rest-server/main.go b/cmd/rest-server/main.go index 38acced..4696403 100644 --- a/cmd/rest-server/main.go +++ b/cmd/rest-server/main.go @@ -1,15 +1,15 @@ package main import ( + "errors" + restserver "github.com/restic/rest-server" + "github.com/spf13/cobra" "log" "net/http" "os" "path/filepath" "runtime" "runtime/pprof" - - restserver "github.com/restic/rest-server" - "github.com/spf13/cobra" ) // cmdRoot is the base command when no other command has been specified. @@ -29,12 +29,35 @@ func init() { flags.StringVar(&restserver.Config.Log, "log", restserver.Config.Log, "log HTTP requests in the combined log format") flags.StringVar(&restserver.Config.Path, "path", restserver.Config.Path, "data directory") flags.BoolVar(&restserver.Config.TLS, "tls", restserver.Config.TLS, "turn on TLS support") + flags.StringVar(&restserver.Config.TLSCert, "tls-cert", restserver.Config.TLSCert, "TLS certificate path") + flags.StringVar(&restserver.Config.TLSKey, "tls-key", restserver.Config.TLSKey, "TLS key path") flags.BoolVar(&restserver.Config.AppendOnly, "append-only", restserver.Config.AppendOnly, "enable append only mode") flags.BoolVar(&restserver.Config.Prometheus, "prometheus", restserver.Config.Prometheus, "enable Prometheus metrics") } var version = "manually" +func tlsSettings() (bool, string, string, error) { + var key, cert string + enabledTLS := restserver.Config.TLS + if !enabledTLS && (restserver.Config.TLSKey != "" || restserver.Config.TLSCert != "") { + return false, "", "", errors.New("requires enabled TLS") + } else if !enabledTLS { + return false, "", "", nil + } + if restserver.Config.TLSKey != "" { + key = restserver.Config.TLSKey + } else { + key = filepath.Join(restserver.Config.Path, "private_key") + } + if restserver.Config.TLSCert != "" { + cert = restserver.Config.TLSCert + } else { + cert = filepath.Join(restserver.Config.Path, "public_key") + } + return enabledTLS, key, cert, nil +} + func runRoot(cmd *cobra.Command, args []string) error { log.SetFlags(0) @@ -65,22 +88,24 @@ func runRoot(cmd *cobra.Command, args []string) error { log.Println("Authentication enabled") } - if !restserver.Config.TLS { + enabledTLS, privateKey, publicKey, err := tlsSettings() + if err != nil { + return err + } + if !enabledTLS { log.Printf("Starting server on %s\n", restserver.Config.Listen) err = http.ListenAndServe(restserver.Config.Listen, handler) } else { - privateKey := filepath.Join(restserver.Config.Path, "private_key") - publicKey := filepath.Join(restserver.Config.Path, "public_key") + log.Println("TLS enabled") log.Printf("Private key: %s", privateKey) - log.Printf("Public key: %s", publicKey) + log.Printf("Public key(certificate): %s", publicKey) log.Printf("Starting server on %s\n", restserver.Config.Listen) err = http.ListenAndServeTLS(restserver.Config.Listen, publicKey, privateKey, handler) } return err } - func main() { if err := cmdRoot.Execute(); err != nil { log.Fatalf("error: %v", err) diff --git a/cmd/rest-server/main_test.go b/cmd/rest-server/main_test.go new file mode 100644 index 0000000..5f41e92 --- /dev/null +++ b/cmd/rest-server/main_test.go @@ -0,0 +1,72 @@ +package main + +import ( + restserver "github.com/restic/rest-server" + "testing" +) + +func TestTLSSettings(t *testing.T) { + type expected struct { + TLSKey string + TLSCert string + Error bool + } + type passed struct { + Path string + TLS bool + TLSKey string + TLSCert string + } + + var tests = []struct { + passed passed + expected expected + }{ + {passed{TLS: false}, expected{"", "", false}}, + {passed{TLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", false}}, + {passed{Path: "/tmp", TLS: true}, expected{"/tmp/private_key", "/tmp/public_key", false}}, + {passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}}, + {passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}}, + {passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}}, + {passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}}, + } + + defaultConfig := restserver.Config + for _, test := range tests { + + t.Run("", func(t *testing.T) { + defer func() { restserver.Config = defaultConfig }() + if test.passed.Path != "" { + restserver.Config.Path = test.passed.Path + } + restserver.Config.TLS = test.passed.TLS + restserver.Config.TLSKey = test.passed.TLSKey + restserver.Config.TLSCert = test.passed.TLSCert + + gotTLS, gotKey, gotCert, err := tlsSettings() + if err != nil && !test.expected.Error { + t.Fatalf("tls_settings returned err (%v)", err) + } + if test.expected.Error { + if err == nil { + t.Fatalf("Error not returned properly (%v)", test) + } else { + return + } + } + if gotTLS != test.passed.TLS { + t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS) + } + wantKey := test.expected.TLSKey + if gotKey != wantKey { + t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey) + } + + wantCert := test.expected.TLSCert + if gotCert != wantCert { + t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert) + } + + }) + } +} diff --git a/mux.go b/mux.go index 6d853f3..2b5a572 100644 --- a/mux.go +++ b/mux.go @@ -19,6 +19,8 @@ var Config = struct { Log string CPUProfile string TLS bool + TLSKey string + TLSCert string AppendOnly bool Prometheus bool Debug bool