diff --git a/cmd/rest-server/main.go b/cmd/rest-server/main.go new file mode 100644 index 0000000..dc91b71 --- /dev/null +++ b/cmd/rest-server/main.go @@ -0,0 +1,90 @@ +package main + +import ( + "log" + "net/http" + "os" + "path/filepath" + "runtime" + "runtime/pprof" + + restserver "github.com/restic/rest-server" + "github.com/spf13/cobra" +) + +// cmdRoot is the base command when no other command has been specified. +var cmdRoot = &cobra.Command{ + Use: "rest-server", + Short: "Run a REST server for use with restic", + SilenceErrors: true, + SilenceUsage: true, + RunE: runRoot, +} + +func init() { + flags := cmdRoot.Flags() + flags.StringVar(&restserver.Config.CPUProfile, "cpuprofile", restserver.Config.CPUProfile, "write CPU profile to file") + flags.BoolVar(&restserver.Config.Debug, "debug", restserver.Config.Debug, "output debug messages") + flags.StringVar(&restserver.Config.Listen, "listen", restserver.Config.Listen, "listen address") + flags.StringVar(&restserver.Config.Log, "log", restserver.Config.Log, "log HTTP requests in the combined log format") + flags.StringVar(&restserver.Config.Path, "path", restserver.Config.Path, "data directory") + flags.BoolVar(&restserver.Config.TLS, "tls", restserver.Config.TLS, "turn on TLS support") +} + +var version = "manually" + +func runRoot(cmd *cobra.Command, args []string) error { + log.SetFlags(0) + + log.Printf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH) + log.Printf("Data directory: %s", restserver.Config.Path) + + if restserver.Config.CPUProfile != "" { + f, err := os.Create(restserver.Config.CPUProfile) + if err != nil { + log.Fatal(err) + } + if err := pprof.StartCPUProfile(f); err != nil { + log.Fatal(err) + } + log.Println("CPU profiling enabled") + defer pprof.StopCPUProfile() + } + + mux := restserver.NewMux() + + var handler http.Handler + htpasswdFile, err := restserver.NewHtpasswdFromFile(filepath.Join(restserver.Config.Path, ".htpasswd")) + if err != nil { + handler = mux + log.Println("Authentication disabled") + } else { + handler = restserver.AuthHandler(htpasswdFile, mux) + log.Println("Authentication enabled") + } + + if !restserver.Config.TLS { + log.Printf("Starting server on %s\n", restserver.Config.Listen) + err = http.ListenAndServe(restserver.Config.Listen, handler) + } else { + privateKey := filepath.Join(restserver.Config.Path, "private_key") + publicKey := filepath.Join(restserver.Config.Path, "public_key") + log.Println("TLS enabled") + log.Printf("Private key: %s", privateKey) + log.Printf("Public key: %s", publicKey) + log.Printf("Starting server on %s\n", restserver.Config.Listen) + err = http.ListenAndServeTLS(restserver.Config.Listen, publicKey, privateKey, handler) + } + if err != nil { + log.Fatal(err) + } + + return nil + +} + +func main() { + if err := cmdRoot.Execute(); err != nil { + log.Fatalf("error: %v", err) + } +} diff --git a/handlers.go b/handlers.go index 6313f3a..e02d63f 100644 --- a/handlers.go +++ b/handlers.go @@ -1,4 +1,4 @@ -package main +package restserver import ( "encoding/json" @@ -22,10 +22,10 @@ func isHashed(dir string) bool { func getRepo(r *http.Request) string { if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") { - return filepath.Join(config.path, pat.Param(r, "repo")) + return filepath.Join(Config.Path, pat.Param(r, "repo")) } - return config.path + return Config.Path } // AuthHandler wraps h with a http.HandlerFunc that performs basic authentication against the user/passwords pairs @@ -43,13 +43,13 @@ func AuthHandler(f *HtpasswdFile, h http.Handler) http.HandlerFunc { // CheckConfig checks whether a configuration exists. func CheckConfig(w http.ResponseWriter, r *http.Request) { - if config.debug { + if Config.Debug { log.Println("CheckConfig()") } cfg := filepath.Join(getRepo(r), "config") st, err := os.Stat(cfg) if err != nil { - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) @@ -61,14 +61,14 @@ func CheckConfig(w http.ResponseWriter, r *http.Request) { // GetConfig allows for a config to be retrieved. func GetConfig(w http.ResponseWriter, r *http.Request) { - if config.debug { + if Config.Debug { log.Println("GetConfig()") } cfg := filepath.Join(getRepo(r), "config") bytes, err := ioutil.ReadFile(cfg) if err != nil { - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) @@ -80,14 +80,14 @@ func GetConfig(w http.ResponseWriter, r *http.Request) { // SaveConfig allows for a config to be saved. func SaveConfig(w http.ResponseWriter, r *http.Request) { - if config.debug { + if Config.Debug { log.Println("SaveConfig()") } cfg := filepath.Join(getRepo(r), "config") bytes, err := ioutil.ReadAll(r.Body) if err != nil { - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) @@ -95,7 +95,7 @@ func SaveConfig(w http.ResponseWriter, r *http.Request) { } if err := ioutil.WriteFile(cfg, bytes, 0600); err != nil { - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -105,12 +105,12 @@ func SaveConfig(w http.ResponseWriter, r *http.Request) { // DeleteConfig removes a config. func DeleteConfig(w http.ResponseWriter, r *http.Request) { - if config.debug { + if Config.Debug { log.Println("DeleteConfig()") } if err := os.Remove(filepath.Join(getRepo(r), "config")); err != nil { - if config.debug { + if Config.Debug { log.Print(err) } if os.IsNotExist(err) { @@ -124,7 +124,7 @@ func DeleteConfig(w http.ResponseWriter, r *http.Request) { // ListBlobs lists all blobs of a given type in an arbitrary order. func ListBlobs(w http.ResponseWriter, r *http.Request) { - if config.debug { + if Config.Debug { log.Println("ListBlobs()") } dir := pat.Param(r, "type") @@ -132,7 +132,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) { items, err := ioutil.ReadDir(path) if err != nil { - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) @@ -145,7 +145,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) { subpath := filepath.Join(path, i.Name()) subitems, err := ioutil.ReadDir(subpath) if err != nil { - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) @@ -161,7 +161,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) { data, err := json.Marshal(names) if err != nil { - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -173,7 +173,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) { // CheckBlob tests whether a blob exists. func CheckBlob(w http.ResponseWriter, r *http.Request) { - if config.debug { + if Config.Debug { log.Println("CheckBlob()") } dir := pat.Param(r, "type") @@ -186,7 +186,7 @@ func CheckBlob(w http.ResponseWriter, r *http.Request) { st, err := os.Stat(path) if err != nil { - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) @@ -198,7 +198,7 @@ func CheckBlob(w http.ResponseWriter, r *http.Request) { // GetBlob retrieves a blob from the repository. func GetBlob(w http.ResponseWriter, r *http.Request) { - if config.debug { + if Config.Debug { log.Println("GetBlob()") } dir := pat.Param(r, "type") @@ -211,7 +211,7 @@ func GetBlob(w http.ResponseWriter, r *http.Request) { file, err := os.Open(path) if err != nil { - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) @@ -224,7 +224,7 @@ func GetBlob(w http.ResponseWriter, r *http.Request) { // SaveBlob saves a blob to the repository. func SaveBlob(w http.ResponseWriter, r *http.Request) { - if config.debug { + if Config.Debug { log.Println("SaveBlob()") } repo := getRepo(r) @@ -238,7 +238,7 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) { tf, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) if err != nil { - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -248,7 +248,7 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) { if _, err := io.Copy(tf, r.Body); err != nil { tf.Close() os.Remove(path) - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) @@ -258,7 +258,7 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) { if err := tf.Sync(); err != nil { tf.Close() os.Remove(path) - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -267,7 +267,7 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) { if err := tf.Close(); err != nil { os.Remove(path) - if config.debug { + if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -277,7 +277,7 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) { // DeleteBlob deletes a blob from the repository. func DeleteBlob(w http.ResponseWriter, r *http.Request) { - if config.debug { + if Config.Debug { log.Println("DeleteBlob()") } dir := pat.Param(r, "type") @@ -289,7 +289,7 @@ func DeleteBlob(w http.ResponseWriter, r *http.Request) { path := filepath.Join(getRepo(r), dir, name) if err := os.Remove(path); err != nil { - if config.debug { + if Config.Debug { log.Print(err) } if os.IsNotExist(err) { @@ -303,7 +303,7 @@ func DeleteBlob(w http.ResponseWriter, r *http.Request) { // CreateRepo creates repository directories. func CreateRepo(w http.ResponseWriter, r *http.Request) { - if config.debug { + if Config.Debug { log.Println("CreateRepo()") } repo := getRepo(r) diff --git a/htpasswd.go b/htpasswd.go index 3f9c0a0..49fe06e 100644 --- a/htpasswd.go +++ b/htpasswd.go @@ -1,4 +1,4 @@ -package main +package restserver /* Copied from: github.com/bitly/oauth2_proxy diff --git a/main.go b/main.go deleted file mode 100644 index 71e45e0..0000000 --- a/main.go +++ /dev/null @@ -1,154 +0,0 @@ -package main - -import ( - "log" - "net/http" - "os" - "path/filepath" - "runtime" - "runtime/pprof" - - "github.com/gorilla/handlers" - "github.com/spf13/cobra" - "goji.io" - "goji.io/pat" -) - -// cmdRoot is the base command when no other command has been specified. -var cmdRoot = &cobra.Command{ - Use: "rest-server", - Short: "Run a REST server for use with restic", - SilenceErrors: true, - SilenceUsage: true, - RunE: runRoot, -} - -var config = struct { - path string - listen string - tls bool - log string - cpuprofile string - debug bool -}{} - -func init() { - flags := cmdRoot.Flags() - flags.StringVar(&config.cpuprofile, "cpuprofile", "", "write CPU profile to file") - flags.BoolVar(&config.debug, "debug", false, "output debug messages") - flags.StringVar(&config.listen, "listen", ":8000", "listen address") - flags.StringVar(&config.log, "log", "", "log HTTP requests in the combined log format") - flags.StringVar(&config.path, "path", "/tmp/restic", "data directory") - flags.BoolVar(&config.tls, "tls", false, "turn on TLS support") -} - -func debugHandler(next http.Handler) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - log.Printf("%s %s", r.Method, r.URL) - next.ServeHTTP(w, r) - }) -} - -func logHandler(next http.Handler) http.Handler { - accessLog, err := os.OpenFile(config.log, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err != nil { - log.Fatal(err) - } - - return handlers.CombinedLoggingHandler(accessLog, next) -} - -func setupMux() *goji.Mux { - mux := goji.NewMux() - - if config.debug { - mux.Use(debugHandler) - } - - if config.log != "" { - mux.Use(logHandler) - } - - mux.HandleFunc(pat.Head("/config"), CheckConfig) - mux.HandleFunc(pat.Head("/:repo/config"), CheckConfig) - mux.HandleFunc(pat.Get("/config"), GetConfig) - mux.HandleFunc(pat.Get("/:repo/config"), GetConfig) - mux.HandleFunc(pat.Post("/config"), SaveConfig) - mux.HandleFunc(pat.Post("/:repo/config"), SaveConfig) - mux.HandleFunc(pat.Delete("/config"), DeleteConfig) - mux.HandleFunc(pat.Delete("/:repo/config"), DeleteConfig) - mux.HandleFunc(pat.Get("/:type/"), ListBlobs) - mux.HandleFunc(pat.Get("/:repo/:type/"), ListBlobs) - mux.HandleFunc(pat.Head("/:type/:name"), CheckBlob) - mux.HandleFunc(pat.Head("/:repo/:type/:name"), CheckBlob) - mux.HandleFunc(pat.Get("/:type/:name"), GetBlob) - mux.HandleFunc(pat.Get("/:repo/:type/:name"), GetBlob) - mux.HandleFunc(pat.Post("/:type/:name"), SaveBlob) - mux.HandleFunc(pat.Post("/:repo/:type/:name"), SaveBlob) - mux.HandleFunc(pat.Delete("/:type/:name"), DeleteBlob) - mux.HandleFunc(pat.Delete("/:repo/:type/:name"), DeleteBlob) - mux.HandleFunc(pat.Post("/"), CreateRepo) - mux.HandleFunc(pat.Post("/:repo"), CreateRepo) - mux.HandleFunc(pat.Post("/:repo/"), CreateRepo) - - return mux -} - -var version = "manually" - -func runRoot(cmd *cobra.Command, args []string) error { - log.SetFlags(0) - - log.Printf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH) - log.Printf("Data directory: %s", config.path) - - if config.cpuprofile != "" { - f, err := os.Create(config.cpuprofile) - if err != nil { - log.Fatal(err) - } - if err := pprof.StartCPUProfile(f); err != nil { - log.Fatal(err) - } - log.Println("CPU profiling enabled") - defer pprof.StopCPUProfile() - } - - mux := setupMux() - - var handler http.Handler - htpasswdFile, err := NewHtpasswdFromFile(filepath.Join(config.path, ".htpasswd")) - if err != nil { - handler = mux - log.Println("Authentication disabled") - } else { - handler = AuthHandler(htpasswdFile, mux) - log.Println("Authentication enabled") - } - - if !config.tls { - log.Printf("Starting server on %s\n", config.listen) - err = http.ListenAndServe(config.listen, handler) - } else { - privateKey := filepath.Join(config.path, "private_key") - publicKey := filepath.Join(config.path, "public_key") - log.Println("TLS enabled") - log.Printf("Private key: %s", privateKey) - log.Printf("Public key: %s", publicKey) - log.Printf("Starting server on %s\n", config.listen) - err = http.ListenAndServeTLS(config.listen, publicKey, privateKey, handler) - } - if err != nil { - log.Fatal(err) - } - - return nil - -} - -func main() { - if err := cmdRoot.Execute(); err != nil { - log.Fatalf("error: %v", err) - } -} diff --git a/mux.go b/mux.go new file mode 100644 index 0000000..23c5072 --- /dev/null +++ b/mux.go @@ -0,0 +1,77 @@ +package restserver + +import ( + "log" + "net/http" + "os" + + goji "goji.io" + + "github.com/gorilla/handlers" + "goji.io/pat" +) + +var Config = struct { + Path string + Listen string + TLS bool + Log string + CPUProfile string + Debug bool +}{ + Path: "/tmp/restic", + Listen: ":8000", +} + +func debugHandler(next http.Handler) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + log.Printf("%s %s", r.Method, r.URL) + next.ServeHTTP(w, r) + }) +} + +func logHandler(next http.Handler) http.Handler { + accessLog, err := os.OpenFile(Config.Log, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + log.Fatal(err) + } + + return handlers.CombinedLoggingHandler(accessLog, next) +} + +func NewMux() *goji.Mux { + mux := goji.NewMux() + + if Config.Debug { + mux.Use(debugHandler) + } + + if Config.Log != "" { + mux.Use(logHandler) + } + + mux.HandleFunc(pat.Head("/config"), CheckConfig) + mux.HandleFunc(pat.Head("/:repo/config"), CheckConfig) + mux.HandleFunc(pat.Get("/config"), GetConfig) + mux.HandleFunc(pat.Get("/:repo/config"), GetConfig) + mux.HandleFunc(pat.Post("/config"), SaveConfig) + mux.HandleFunc(pat.Post("/:repo/config"), SaveConfig) + mux.HandleFunc(pat.Delete("/config"), DeleteConfig) + mux.HandleFunc(pat.Delete("/:repo/config"), DeleteConfig) + mux.HandleFunc(pat.Get("/:type/"), ListBlobs) + mux.HandleFunc(pat.Get("/:repo/:type/"), ListBlobs) + mux.HandleFunc(pat.Head("/:type/:name"), CheckBlob) + mux.HandleFunc(pat.Head("/:repo/:type/:name"), CheckBlob) + mux.HandleFunc(pat.Get("/:type/:name"), GetBlob) + mux.HandleFunc(pat.Get("/:repo/:type/:name"), GetBlob) + mux.HandleFunc(pat.Post("/:type/:name"), SaveBlob) + mux.HandleFunc(pat.Post("/:repo/:type/:name"), SaveBlob) + mux.HandleFunc(pat.Delete("/:type/:name"), DeleteBlob) + mux.HandleFunc(pat.Delete("/:repo/:type/:name"), DeleteBlob) + mux.HandleFunc(pat.Post("/"), CreateRepo) + mux.HandleFunc(pat.Post("/:repo"), CreateRepo) + mux.HandleFunc(pat.Post("/:repo/"), CreateRepo) + + return mux +}