Replace flag with cobra

This commit is contained in:
Alexander Neumann 2016-12-30 20:57:48 +01:00 committed by Zlatko Čalušić
parent 6bee34700d
commit 0f4373ed7f
2 changed files with 60 additions and 32 deletions

View file

@ -22,10 +22,10 @@ func isHashed(dir string) bool {
func getRepo(r *http.Request) string { func getRepo(r *http.Request) string {
if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo/") { if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo/") {
return filepath.Join(*path, pat.Param(r, "repo")) return filepath.Join(config.path, pat.Param(r, "repo"))
} }
return *path return config.path
} }
func createDirectories(path string) { func createDirectories(path string) {
@ -72,7 +72,7 @@ func AuthHandler(f *HtpasswdFile, h http.Handler) http.HandlerFunc {
// CheckConfig checks whether a configuration exists. // CheckConfig checks whether a configuration exists.
func CheckConfig(w http.ResponseWriter, r *http.Request) { func CheckConfig(w http.ResponseWriter, r *http.Request) {
if *debug { if config.debug {
log.Println("CheckConfig()") log.Println("CheckConfig()")
} }
config := filepath.Join(getRepo(r), "config") config := filepath.Join(getRepo(r), "config")
@ -87,7 +87,7 @@ func CheckConfig(w http.ResponseWriter, r *http.Request) {
// GetConfig allows for a config to be retrieved. // GetConfig allows for a config to be retrieved.
func GetConfig(w http.ResponseWriter, r *http.Request) { func GetConfig(w http.ResponseWriter, r *http.Request) {
if *debug { if config.debug {
log.Println("GetConfig()") log.Println("GetConfig()")
} }
config := filepath.Join(getRepo(r), "config") config := filepath.Join(getRepo(r), "config")
@ -102,7 +102,7 @@ func GetConfig(w http.ResponseWriter, r *http.Request) {
// SaveConfig allows for a config to be saved. // SaveConfig allows for a config to be saved.
func SaveConfig(w http.ResponseWriter, r *http.Request) { func SaveConfig(w http.ResponseWriter, r *http.Request) {
if *debug { if config.debug {
log.Println("SaveConfig()") log.Println("SaveConfig()")
} }
config := filepath.Join(getRepo(r), "config") config := filepath.Join(getRepo(r), "config")
@ -121,7 +121,7 @@ func SaveConfig(w http.ResponseWriter, r *http.Request) {
// ListBlobs lists all blobs of a given type in an arbitrary order. // ListBlobs lists all blobs of a given type in an arbitrary order.
func ListBlobs(w http.ResponseWriter, r *http.Request) { func ListBlobs(w http.ResponseWriter, r *http.Request) {
if *debug { if config.debug {
log.Println("ListBlobs()") log.Println("ListBlobs()")
} }
dir := pat.Param(r, "type") dir := pat.Param(r, "type")
@ -161,7 +161,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) {
// CheckBlob tests whether a blob exists. // CheckBlob tests whether a blob exists.
func CheckBlob(w http.ResponseWriter, r *http.Request) { func CheckBlob(w http.ResponseWriter, r *http.Request) {
if *debug { if config.debug {
log.Println("CheckBlob()") log.Println("CheckBlob()")
} }
dir := pat.Param(r, "type") dir := pat.Param(r, "type")
@ -183,7 +183,7 @@ func CheckBlob(w http.ResponseWriter, r *http.Request) {
// GetBlob retrieves a blob from the repository. // GetBlob retrieves a blob from the repository.
func GetBlob(w http.ResponseWriter, r *http.Request) { func GetBlob(w http.ResponseWriter, r *http.Request) {
if *debug { if config.debug {
log.Println("GetBlob()") log.Println("GetBlob()")
} }
dir := pat.Param(r, "type") dir := pat.Param(r, "type")
@ -206,7 +206,7 @@ func GetBlob(w http.ResponseWriter, r *http.Request) {
// SaveBlob saves a blob to the repository. // SaveBlob saves a blob to the repository.
func SaveBlob(w http.ResponseWriter, r *http.Request) { func SaveBlob(w http.ResponseWriter, r *http.Request) {
if *debug { if config.debug {
log.Println("SaveBlob()") log.Println("SaveBlob()")
} }
repo := getRepo(r) repo := getRepo(r)
@ -262,7 +262,7 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) {
// DeleteBlob deletes a blob from the repository. // DeleteBlob deletes a blob from the repository.
func DeleteBlob(w http.ResponseWriter, r *http.Request) { func DeleteBlob(w http.ResponseWriter, r *http.Request) {
if *debug { if config.debug {
log.Println("DeleteBlob()") log.Println("DeleteBlob()")
} }
dir := pat.Param(r, "type") dir := pat.Param(r, "type")

72
main.go
View file

@ -1,24 +1,43 @@
package main package main
import ( import (
"flag"
"log" "log"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"runtime/pprof" "runtime/pprof"
"github.com/spf13/cobra"
"goji.io" "goji.io"
"goji.io/pat" "goji.io/pat"
) )
var ( // cmdRoot is the base command when no other command has been specified.
path = flag.String("path", "/tmp/restic", "data directory") var cmdRoot = &cobra.Command{
listen = flag.String("listen", ":8000", "listen address") Use: "rest-server",
tls = flag.Bool("tls", false, "turn on TLS support") Short: "Run a REST server for use with restic",
cpuprofile = flag.String("cpuprofile", "", "write CPU profile to file") SilenceErrors: true,
debug = flag.Bool("debug", false, "output debug messages") SilenceUsage: true,
) RunE: runRoot,
}
var config = struct {
path string
listen string
tls bool
cpuprofile string
debug bool
}{}
func init() {
flags := cmdRoot.Flags()
flags.StringVar(&config.path, "path", "/tmp/restic", "data directory")
flags.StringVar(&config.listen, "listen", ":8000", "listen address")
flags.BoolVar(&config.tls, "tls", false, "turn on TLS support")
flags.StringVar(&config.cpuprofile, "cpuprofile", "", "write CPU profile to file")
flags.BoolVar(&config.debug, "debug", false, "output debug messages")
}
func debugHandler(next http.Handler) http.Handler { func debugHandler(next http.Handler) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
@ -31,7 +50,7 @@ func debugHandler(next http.Handler) http.Handler {
func setupMux() *goji.Mux { func setupMux() *goji.Mux {
mux := goji.NewMux() mux := goji.NewMux()
if *debug { if config.debug {
mux.Use(debugHandler) mux.Use(debugHandler)
} }
@ -55,13 +74,11 @@ func setupMux() *goji.Mux {
return mux return mux
} }
func main() { func runRoot(cmd *cobra.Command, args []string) error {
log.SetFlags(0) log.SetFlags(0)
flag.Parse() if config.cpuprofile != "" {
f, err := os.Create(config.cpuprofile)
if *cpuprofile != "" {
f, err := os.Create(*cpuprofile)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -75,7 +92,7 @@ func main() {
mux := setupMux() mux := setupMux()
var handler http.Handler var handler http.Handler
htpasswdFile, err := NewHtpasswdFromFile(filepath.Join(*path, ".htpasswd")) htpasswdFile, err := NewHtpasswdFromFile(filepath.Join(config.path, ".htpasswd"))
if err != nil { if err != nil {
handler = mux handler = mux
log.Println("Authentication disabled") log.Println("Authentication disabled")
@ -84,19 +101,30 @@ func main() {
log.Println("Authentication enabled") log.Println("Authentication enabled")
} }
if !*tls { if !config.tls {
log.Printf("Starting server on %s\n", *listen) log.Printf("Starting server on %s\n", config.listen)
err = http.ListenAndServe(*listen, handler) err = http.ListenAndServe(config.listen, handler)
} else { } else {
privateKey := filepath.Join(*path, "private_key") privateKey := filepath.Join(config.path, "private_key")
publicKey := filepath.Join(*path, "public_key") publicKey := filepath.Join(config.path, "public_key")
log.Println("TLS enabled") log.Println("TLS enabled")
log.Printf("Private key: %s", privateKey) log.Printf("Private key: %s", privateKey)
log.Printf("Public key: %s", publicKey) log.Printf("Public key: %s", publicKey)
log.Printf("Starting server on %s\n", *listen) log.Printf("Starting server on %s\n", config.listen)
err = http.ListenAndServeTLS(*listen, publicKey, privateKey, handler) err = http.ListenAndServeTLS(config.listen, publicKey, privateKey, handler)
} }
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return nil
}
func main() {
err := cmdRoot.Execute()
if err != nil {
log.Printf("error: %v", err)
os.Exit(1)
}
} }