Move main function into separate package (closes #12)

This commit is contained in:
Matthew Holt 2017-06-21 15:23:33 -06:00 committed by Zlatko Čalušić
parent 07b6d5facf
commit 65152c7bf5
5 changed files with 196 additions and 183 deletions

90
cmd/rest-server/main.go Normal file
View file

@ -0,0 +1,90 @@
package main
import (
"log"
"net/http"
"os"
"path/filepath"
"runtime"
"runtime/pprof"
restserver "github.com/restic/rest-server"
"github.com/spf13/cobra"
)
// cmdRoot is the base command when no other command has been specified.
var cmdRoot = &cobra.Command{
Use: "rest-server",
Short: "Run a REST server for use with restic",
SilenceErrors: true,
SilenceUsage: true,
RunE: runRoot,
}
func init() {
flags := cmdRoot.Flags()
flags.StringVar(&restserver.Config.CPUProfile, "cpuprofile", restserver.Config.CPUProfile, "write CPU profile to file")
flags.BoolVar(&restserver.Config.Debug, "debug", restserver.Config.Debug, "output debug messages")
flags.StringVar(&restserver.Config.Listen, "listen", restserver.Config.Listen, "listen address")
flags.StringVar(&restserver.Config.Log, "log", restserver.Config.Log, "log HTTP requests in the combined log format")
flags.StringVar(&restserver.Config.Path, "path", restserver.Config.Path, "data directory")
flags.BoolVar(&restserver.Config.TLS, "tls", restserver.Config.TLS, "turn on TLS support")
}
var version = "manually"
func runRoot(cmd *cobra.Command, args []string) error {
log.SetFlags(0)
log.Printf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH)
log.Printf("Data directory: %s", restserver.Config.Path)
if restserver.Config.CPUProfile != "" {
f, err := os.Create(restserver.Config.CPUProfile)
if err != nil {
log.Fatal(err)
}
if err := pprof.StartCPUProfile(f); err != nil {
log.Fatal(err)
}
log.Println("CPU profiling enabled")
defer pprof.StopCPUProfile()
}
mux := restserver.NewMux()
var handler http.Handler
htpasswdFile, err := restserver.NewHtpasswdFromFile(filepath.Join(restserver.Config.Path, ".htpasswd"))
if err != nil {
handler = mux
log.Println("Authentication disabled")
} else {
handler = restserver.AuthHandler(htpasswdFile, mux)
log.Println("Authentication enabled")
}
if !restserver.Config.TLS {
log.Printf("Starting server on %s\n", restserver.Config.Listen)
err = http.ListenAndServe(restserver.Config.Listen, handler)
} else {
privateKey := filepath.Join(restserver.Config.Path, "private_key")
publicKey := filepath.Join(restserver.Config.Path, "public_key")
log.Println("TLS enabled")
log.Printf("Private key: %s", privateKey)
log.Printf("Public key: %s", publicKey)
log.Printf("Starting server on %s\n", restserver.Config.Listen)
err = http.ListenAndServeTLS(restserver.Config.Listen, publicKey, privateKey, handler)
}
if err != nil {
log.Fatal(err)
}
return nil
}
func main() {
if err := cmdRoot.Execute(); err != nil {
log.Fatalf("error: %v", err)
}
}

View file

@ -1,4 +1,4 @@
package main package restserver
import ( import (
"encoding/json" "encoding/json"
@ -22,10 +22,10 @@ func isHashed(dir string) bool {
func getRepo(r *http.Request) string { func getRepo(r *http.Request) string {
if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") { if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") {
return filepath.Join(config.path, pat.Param(r, "repo")) return filepath.Join(Config.Path, pat.Param(r, "repo"))
} }
return config.path return Config.Path
} }
// AuthHandler wraps h with a http.HandlerFunc that performs basic authentication against the user/passwords pairs // AuthHandler wraps h with a http.HandlerFunc that performs basic authentication against the user/passwords pairs
@ -43,13 +43,13 @@ func AuthHandler(f *HtpasswdFile, h http.Handler) http.HandlerFunc {
// CheckConfig checks whether a configuration exists. // CheckConfig checks whether a configuration exists.
func CheckConfig(w http.ResponseWriter, r *http.Request) { func CheckConfig(w http.ResponseWriter, r *http.Request) {
if config.debug { if Config.Debug {
log.Println("CheckConfig()") log.Println("CheckConfig()")
} }
cfg := filepath.Join(getRepo(r), "config") cfg := filepath.Join(getRepo(r), "config")
st, err := os.Stat(cfg) st, err := os.Stat(cfg)
if err != nil { if err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
@ -61,14 +61,14 @@ func CheckConfig(w http.ResponseWriter, r *http.Request) {
// GetConfig allows for a config to be retrieved. // GetConfig allows for a config to be retrieved.
func GetConfig(w http.ResponseWriter, r *http.Request) { func GetConfig(w http.ResponseWriter, r *http.Request) {
if config.debug { if Config.Debug {
log.Println("GetConfig()") log.Println("GetConfig()")
} }
cfg := filepath.Join(getRepo(r), "config") cfg := filepath.Join(getRepo(r), "config")
bytes, err := ioutil.ReadFile(cfg) bytes, err := ioutil.ReadFile(cfg)
if err != nil { if err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
@ -80,14 +80,14 @@ func GetConfig(w http.ResponseWriter, r *http.Request) {
// SaveConfig allows for a config to be saved. // SaveConfig allows for a config to be saved.
func SaveConfig(w http.ResponseWriter, r *http.Request) { func SaveConfig(w http.ResponseWriter, r *http.Request) {
if config.debug { if Config.Debug {
log.Println("SaveConfig()") log.Println("SaveConfig()")
} }
cfg := filepath.Join(getRepo(r), "config") cfg := filepath.Join(getRepo(r), "config")
bytes, err := ioutil.ReadAll(r.Body) bytes, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
@ -95,7 +95,7 @@ func SaveConfig(w http.ResponseWriter, r *http.Request) {
} }
if err := ioutil.WriteFile(cfg, bytes, 0600); err != nil { if err := ioutil.WriteFile(cfg, bytes, 0600); err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
@ -105,12 +105,12 @@ func SaveConfig(w http.ResponseWriter, r *http.Request) {
// DeleteConfig removes a config. // DeleteConfig removes a config.
func DeleteConfig(w http.ResponseWriter, r *http.Request) { func DeleteConfig(w http.ResponseWriter, r *http.Request) {
if config.debug { if Config.Debug {
log.Println("DeleteConfig()") log.Println("DeleteConfig()")
} }
if err := os.Remove(filepath.Join(getRepo(r), "config")); err != nil { if err := os.Remove(filepath.Join(getRepo(r), "config")); err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -124,7 +124,7 @@ func DeleteConfig(w http.ResponseWriter, r *http.Request) {
// ListBlobs lists all blobs of a given type in an arbitrary order. // ListBlobs lists all blobs of a given type in an arbitrary order.
func ListBlobs(w http.ResponseWriter, r *http.Request) { func ListBlobs(w http.ResponseWriter, r *http.Request) {
if config.debug { if Config.Debug {
log.Println("ListBlobs()") log.Println("ListBlobs()")
} }
dir := pat.Param(r, "type") dir := pat.Param(r, "type")
@ -132,7 +132,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) {
items, err := ioutil.ReadDir(path) items, err := ioutil.ReadDir(path)
if err != nil { if err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
@ -145,7 +145,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) {
subpath := filepath.Join(path, i.Name()) subpath := filepath.Join(path, i.Name())
subitems, err := ioutil.ReadDir(subpath) subitems, err := ioutil.ReadDir(subpath)
if err != nil { if err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
@ -161,7 +161,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) {
data, err := json.Marshal(names) data, err := json.Marshal(names)
if err != nil { if err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
@ -173,7 +173,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) {
// CheckBlob tests whether a blob exists. // CheckBlob tests whether a blob exists.
func CheckBlob(w http.ResponseWriter, r *http.Request) { func CheckBlob(w http.ResponseWriter, r *http.Request) {
if config.debug { if Config.Debug {
log.Println("CheckBlob()") log.Println("CheckBlob()")
} }
dir := pat.Param(r, "type") dir := pat.Param(r, "type")
@ -186,7 +186,7 @@ func CheckBlob(w http.ResponseWriter, r *http.Request) {
st, err := os.Stat(path) st, err := os.Stat(path)
if err != nil { if err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
@ -198,7 +198,7 @@ func CheckBlob(w http.ResponseWriter, r *http.Request) {
// GetBlob retrieves a blob from the repository. // GetBlob retrieves a blob from the repository.
func GetBlob(w http.ResponseWriter, r *http.Request) { func GetBlob(w http.ResponseWriter, r *http.Request) {
if config.debug { if Config.Debug {
log.Println("GetBlob()") log.Println("GetBlob()")
} }
dir := pat.Param(r, "type") dir := pat.Param(r, "type")
@ -211,7 +211,7 @@ func GetBlob(w http.ResponseWriter, r *http.Request) {
file, err := os.Open(path) file, err := os.Open(path)
if err != nil { if err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
@ -224,7 +224,7 @@ func GetBlob(w http.ResponseWriter, r *http.Request) {
// SaveBlob saves a blob to the repository. // SaveBlob saves a blob to the repository.
func SaveBlob(w http.ResponseWriter, r *http.Request) { func SaveBlob(w http.ResponseWriter, r *http.Request) {
if config.debug { if Config.Debug {
log.Println("SaveBlob()") log.Println("SaveBlob()")
} }
repo := getRepo(r) repo := getRepo(r)
@ -238,7 +238,7 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) {
tf, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) tf, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600)
if err != nil { if err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
@ -248,7 +248,7 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) {
if _, err := io.Copy(tf, r.Body); err != nil { if _, err := io.Copy(tf, r.Body); err != nil {
tf.Close() tf.Close()
os.Remove(path) os.Remove(path)
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
@ -258,7 +258,7 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) {
if err := tf.Sync(); err != nil { if err := tf.Sync(); err != nil {
tf.Close() tf.Close()
os.Remove(path) os.Remove(path)
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
@ -267,7 +267,7 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) {
if err := tf.Close(); err != nil { if err := tf.Close(); err != nil {
os.Remove(path) os.Remove(path)
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
@ -277,7 +277,7 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) {
// DeleteBlob deletes a blob from the repository. // DeleteBlob deletes a blob from the repository.
func DeleteBlob(w http.ResponseWriter, r *http.Request) { func DeleteBlob(w http.ResponseWriter, r *http.Request) {
if config.debug { if Config.Debug {
log.Println("DeleteBlob()") log.Println("DeleteBlob()")
} }
dir := pat.Param(r, "type") dir := pat.Param(r, "type")
@ -289,7 +289,7 @@ func DeleteBlob(w http.ResponseWriter, r *http.Request) {
path := filepath.Join(getRepo(r), dir, name) path := filepath.Join(getRepo(r), dir, name)
if err := os.Remove(path); err != nil { if err := os.Remove(path); err != nil {
if config.debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -303,7 +303,7 @@ func DeleteBlob(w http.ResponseWriter, r *http.Request) {
// CreateRepo creates repository directories. // CreateRepo creates repository directories.
func CreateRepo(w http.ResponseWriter, r *http.Request) { func CreateRepo(w http.ResponseWriter, r *http.Request) {
if config.debug { if Config.Debug {
log.Println("CreateRepo()") log.Println("CreateRepo()")
} }
repo := getRepo(r) repo := getRepo(r)

View file

@ -1,4 +1,4 @@
package main package restserver
/* /*
Copied from: github.com/bitly/oauth2_proxy Copied from: github.com/bitly/oauth2_proxy

154
main.go
View file

@ -1,154 +0,0 @@
package main
import (
"log"
"net/http"
"os"
"path/filepath"
"runtime"
"runtime/pprof"
"github.com/gorilla/handlers"
"github.com/spf13/cobra"
"goji.io"
"goji.io/pat"
)
// cmdRoot is the base command when no other command has been specified.
var cmdRoot = &cobra.Command{
Use: "rest-server",
Short: "Run a REST server for use with restic",
SilenceErrors: true,
SilenceUsage: true,
RunE: runRoot,
}
var config = struct {
path string
listen string
tls bool
log string
cpuprofile string
debug bool
}{}
func init() {
flags := cmdRoot.Flags()
flags.StringVar(&config.cpuprofile, "cpuprofile", "", "write CPU profile to file")
flags.BoolVar(&config.debug, "debug", false, "output debug messages")
flags.StringVar(&config.listen, "listen", ":8000", "listen address")
flags.StringVar(&config.log, "log", "", "log HTTP requests in the combined log format")
flags.StringVar(&config.path, "path", "/tmp/restic", "data directory")
flags.BoolVar(&config.tls, "tls", false, "turn on TLS support")
}
func debugHandler(next http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
log.Printf("%s %s", r.Method, r.URL)
next.ServeHTTP(w, r)
})
}
func logHandler(next http.Handler) http.Handler {
accessLog, err := os.OpenFile(config.log, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
log.Fatal(err)
}
return handlers.CombinedLoggingHandler(accessLog, next)
}
func setupMux() *goji.Mux {
mux := goji.NewMux()
if config.debug {
mux.Use(debugHandler)
}
if config.log != "" {
mux.Use(logHandler)
}
mux.HandleFunc(pat.Head("/config"), CheckConfig)
mux.HandleFunc(pat.Head("/:repo/config"), CheckConfig)
mux.HandleFunc(pat.Get("/config"), GetConfig)
mux.HandleFunc(pat.Get("/:repo/config"), GetConfig)
mux.HandleFunc(pat.Post("/config"), SaveConfig)
mux.HandleFunc(pat.Post("/:repo/config"), SaveConfig)
mux.HandleFunc(pat.Delete("/config"), DeleteConfig)
mux.HandleFunc(pat.Delete("/:repo/config"), DeleteConfig)
mux.HandleFunc(pat.Get("/:type/"), ListBlobs)
mux.HandleFunc(pat.Get("/:repo/:type/"), ListBlobs)
mux.HandleFunc(pat.Head("/:type/:name"), CheckBlob)
mux.HandleFunc(pat.Head("/:repo/:type/:name"), CheckBlob)
mux.HandleFunc(pat.Get("/:type/:name"), GetBlob)
mux.HandleFunc(pat.Get("/:repo/:type/:name"), GetBlob)
mux.HandleFunc(pat.Post("/:type/:name"), SaveBlob)
mux.HandleFunc(pat.Post("/:repo/:type/:name"), SaveBlob)
mux.HandleFunc(pat.Delete("/:type/:name"), DeleteBlob)
mux.HandleFunc(pat.Delete("/:repo/:type/:name"), DeleteBlob)
mux.HandleFunc(pat.Post("/"), CreateRepo)
mux.HandleFunc(pat.Post("/:repo"), CreateRepo)
mux.HandleFunc(pat.Post("/:repo/"), CreateRepo)
return mux
}
var version = "manually"
func runRoot(cmd *cobra.Command, args []string) error {
log.SetFlags(0)
log.Printf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH)
log.Printf("Data directory: %s", config.path)
if config.cpuprofile != "" {
f, err := os.Create(config.cpuprofile)
if err != nil {
log.Fatal(err)
}
if err := pprof.StartCPUProfile(f); err != nil {
log.Fatal(err)
}
log.Println("CPU profiling enabled")
defer pprof.StopCPUProfile()
}
mux := setupMux()
var handler http.Handler
htpasswdFile, err := NewHtpasswdFromFile(filepath.Join(config.path, ".htpasswd"))
if err != nil {
handler = mux
log.Println("Authentication disabled")
} else {
handler = AuthHandler(htpasswdFile, mux)
log.Println("Authentication enabled")
}
if !config.tls {
log.Printf("Starting server on %s\n", config.listen)
err = http.ListenAndServe(config.listen, handler)
} else {
privateKey := filepath.Join(config.path, "private_key")
publicKey := filepath.Join(config.path, "public_key")
log.Println("TLS enabled")
log.Printf("Private key: %s", privateKey)
log.Printf("Public key: %s", publicKey)
log.Printf("Starting server on %s\n", config.listen)
err = http.ListenAndServeTLS(config.listen, publicKey, privateKey, handler)
}
if err != nil {
log.Fatal(err)
}
return nil
}
func main() {
if err := cmdRoot.Execute(); err != nil {
log.Fatalf("error: %v", err)
}
}

77
mux.go Normal file
View file

@ -0,0 +1,77 @@
package restserver
import (
"log"
"net/http"
"os"
goji "goji.io"
"github.com/gorilla/handlers"
"goji.io/pat"
)
var Config = struct {
Path string
Listen string
TLS bool
Log string
CPUProfile string
Debug bool
}{
Path: "/tmp/restic",
Listen: ":8000",
}
func debugHandler(next http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
log.Printf("%s %s", r.Method, r.URL)
next.ServeHTTP(w, r)
})
}
func logHandler(next http.Handler) http.Handler {
accessLog, err := os.OpenFile(Config.Log, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
log.Fatal(err)
}
return handlers.CombinedLoggingHandler(accessLog, next)
}
func NewMux() *goji.Mux {
mux := goji.NewMux()
if Config.Debug {
mux.Use(debugHandler)
}
if Config.Log != "" {
mux.Use(logHandler)
}
mux.HandleFunc(pat.Head("/config"), CheckConfig)
mux.HandleFunc(pat.Head("/:repo/config"), CheckConfig)
mux.HandleFunc(pat.Get("/config"), GetConfig)
mux.HandleFunc(pat.Get("/:repo/config"), GetConfig)
mux.HandleFunc(pat.Post("/config"), SaveConfig)
mux.HandleFunc(pat.Post("/:repo/config"), SaveConfig)
mux.HandleFunc(pat.Delete("/config"), DeleteConfig)
mux.HandleFunc(pat.Delete("/:repo/config"), DeleteConfig)
mux.HandleFunc(pat.Get("/:type/"), ListBlobs)
mux.HandleFunc(pat.Get("/:repo/:type/"), ListBlobs)
mux.HandleFunc(pat.Head("/:type/:name"), CheckBlob)
mux.HandleFunc(pat.Head("/:repo/:type/:name"), CheckBlob)
mux.HandleFunc(pat.Get("/:type/:name"), GetBlob)
mux.HandleFunc(pat.Get("/:repo/:type/:name"), GetBlob)
mux.HandleFunc(pat.Post("/:type/:name"), SaveBlob)
mux.HandleFunc(pat.Post("/:repo/:type/:name"), SaveBlob)
mux.HandleFunc(pat.Delete("/:type/:name"), DeleteBlob)
mux.HandleFunc(pat.Delete("/:repo/:type/:name"), DeleteBlob)
mux.HandleFunc(pat.Post("/"), CreateRepo)
mux.HandleFunc(pat.Post("/:repo"), CreateRepo)
mux.HandleFunc(pat.Post("/:repo/"), CreateRepo)
return mux
}