From fa0b53efe7a8d9a76ebdb6086f94461f5df71588 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zlatko=20=C4=8Calu=C5=A1i=C4=87?= Date: Tue, 27 Dec 2016 13:42:43 +0100 Subject: [PATCH] Use goji.io mux Remove old router implementation. --- handlers.go | 417 +++++++++++++++++++++++-------------------------- main.go | 45 +++--- router.go | 137 ---------------- router_test.go | 72 --------- 4 files changed, 222 insertions(+), 449 deletions(-) delete mode 100644 router.go delete mode 100644 router_test.go diff --git a/handlers.go b/handlers.go index 5d8846a..3924563 100644 --- a/handlers.go +++ b/handlers.go @@ -13,16 +13,11 @@ import ( "time" ) -// Context contains repository metadata. -type Context struct { - path string -} - func isHashed(dir string) bool { return dir == "data" } -func createDirectories(c *Context) { +func createDirectories(path string) { log.Println("Creating repository directories") dirs := []string{ @@ -35,13 +30,13 @@ func createDirectories(c *Context) { } for _, d := range dirs { - if err := os.MkdirAll(filepath.Join(c.path, d), 0700); err != nil { + if err := os.MkdirAll(filepath.Join(path, d), 0700); err != nil { log.Fatal(err) } } for i := 0; i < 256; i++ { - if err := os.MkdirAll(filepath.Join(c.path, "data", fmt.Sprintf("%02x", i)), 0700); err != nil { + if err := os.MkdirAll(filepath.Join(path, "data", fmt.Sprintf("%02x", i)), 0700); err != nil { log.Fatal(err) } } @@ -60,233 +55,217 @@ func AuthHandler(f *HtpasswdFile, h http.Handler) http.HandlerFunc { } } -// CheckConfig returns a http.HandlerFunc that checks whether a configuration exists. -func CheckConfig(c *Context) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if *debug { - log.Println("CheckConfig()") - } - config := filepath.Join(c.path, "config") - st, err := os.Stat(config) - if err != nil { - http.Error(w, "404 not found", 404) - return - } - - w.Header().Add("Content-Length", fmt.Sprint(st.Size())) +// CheckConfig checks whether a configuration exists. +func CheckConfig(w http.ResponseWriter, r *http.Request) { + if *debug { + log.Println("CheckConfig()") } + config := filepath.Join(*path, "config") + st, err := os.Stat(config) + if err != nil { + http.Error(w, "404 not found", 404) + return + } + + w.Header().Add("Content-Length", fmt.Sprint(st.Size())) } -// GetConfig returns a http.HandlerFunc that allows for a config to be retrieved. -func GetConfig(c *Context) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if *debug { - log.Println("GetConfig()") - } - config := filepath.Join(c.path, "config") - bytes, err := ioutil.ReadFile(config) - if err != nil { - http.Error(w, "404 not found", 404) - return - } - - w.Write(bytes) +// GetConfig allows for a config to be retrieved. +func GetConfig(w http.ResponseWriter, r *http.Request) { + if *debug { + log.Println("GetConfig()") } + config := filepath.Join(*path, "config") + bytes, err := ioutil.ReadFile(config) + if err != nil { + http.Error(w, "404 not found", 404) + return + } + + w.Write(bytes) } -// SaveConfig returns a http.HandlerFunc that allows for a config to be saved. -func SaveConfig(c *Context) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if *debug { - log.Println("SaveConfig()") - } - config := filepath.Join(c.path, "config") - bytes, err := ioutil.ReadAll(r.Body) - if err != nil { - http.Error(w, "400 bad request", 400) - return - } - if err := ioutil.WriteFile(config, bytes, 0600); err != nil { - http.Error(w, "500 internal server error", 500) - return - } - - w.Write([]byte("200 ok")) +// SaveConfig allows for a config to be saved. +func SaveConfig(w http.ResponseWriter, r *http.Request) { + if *debug { + log.Println("SaveConfig()") } + config := filepath.Join(*path, "config") + bytes, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, "400 bad request", 400) + return + } + if err := ioutil.WriteFile(config, bytes, 0600); err != nil { + http.Error(w, "500 internal server error", 500) + return + } + + w.Write([]byte("200 ok")) } -// ListBlobs returns a http.HandlerFunc that lists all blobs of a given type in an arbitrary order. -func ListBlobs(c *Context) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if *debug { - log.Println("ListBlobs()") - } - vars := strings.Split(r.RequestURI, "/") - dir := vars[1] - path := filepath.Join(c.path, dir) +// ListBlobs lists all blobs of a given type in an arbitrary order. +func ListBlobs(w http.ResponseWriter, r *http.Request) { + if *debug { + log.Println("ListBlobs()") + } + vars := strings.Split(r.RequestURI, "/") + dir := vars[1] + path := filepath.Join(*path, dir) - items, err := ioutil.ReadDir(path) - if err != nil { - http.Error(w, "404 not found", 404) - return - } + items, err := ioutil.ReadDir(path) + if err != nil { + http.Error(w, "404 not found", 404) + return + } - var names []string - for _, i := range items { - if isHashed(dir) { - subpath := filepath.Join(path, i.Name()) - subitems, err := ioutil.ReadDir(subpath) - if err != nil { - http.Error(w, "404 not found", 404) - return - } - for _, f := range subitems { - names = append(names, f.Name()) - } - } else { - names = append(names, i.Name()) + var names []string + for _, i := range items { + if isHashed(dir) { + subpath := filepath.Join(path, i.Name()) + subitems, err := ioutil.ReadDir(subpath) + if err != nil { + http.Error(w, "404 not found", 404) + return } - } - - data, err := json.Marshal(names) - if err != nil { - http.Error(w, "500 internal server error", 500) - return - } - - w.Write(data) - } -} - -// CheckBlob returns a http.HandlerFunc that tests whether a blob exists and returns 200, if it does, or 404 otherwise. -func CheckBlob(c *Context) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if *debug { - log.Println("CheckBlob()") - } - vars := strings.Split(r.RequestURI, "/") - dir := vars[1] - name := vars[2] - - if isHashed(dir) { - name = filepath.Join(name[:2], name) - } - path := filepath.Join(c.path, dir, name) - - st, err := os.Stat(path) - if err != nil { - http.Error(w, "404 not found", 404) - return - } - - w.Header().Add("Content-Length", fmt.Sprint(st.Size())) - } -} - -// GetBlob returns a http.HandlerFunc that retrieves a blob from the repository. -func GetBlob(c *Context) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if *debug { - log.Println("GetBlob()") - } - vars := strings.Split(r.RequestURI, "/") - dir := vars[1] - name := vars[2] - - if isHashed(dir) { - name = filepath.Join(name[:2], name) - } - path := filepath.Join(c.path, dir, name) - - file, err := os.Open(path) - if err != nil { - http.Error(w, "404 not found", 404) - return - } - - http.ServeContent(w, r, "", time.Unix(0, 0), file) - file.Close() - } -} - -// SaveBlob returns a http.HandlerFunc that saves a blob to the repository. -func SaveBlob(c *Context) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if *debug { - log.Println("SaveBlob()") - } - vars := strings.Split(r.RequestURI, "/") - dir := vars[1] - name := vars[2] - - if dir == "keys" { - if _, err := os.Stat("keys"); err != nil && os.IsNotExist(err) { - createDirectories(c) + for _, f := range subitems { + names = append(names, f.Name()) } + } else { + names = append(names, i.Name()) } - - tmp := filepath.Join(c.path, "tmp", name) - - tf, err := os.OpenFile(tmp, os.O_CREATE|os.O_WRONLY, 0600) - if err != nil { - http.Error(w, "500 internal server error", 500) - return - } - - if _, err := io.Copy(tf, r.Body); err != nil { - tf.Close() - os.Remove(tmp) - http.Error(w, "400 bad request", 400) - return - } - if err := tf.Sync(); err != nil { - tf.Close() - os.Remove(tmp) - http.Error(w, "500 internal server error", 500) - return - } - if err := tf.Close(); err != nil { - os.Remove(tmp) - http.Error(w, "500 internal server error", 500) - return - } - - if isHashed(dir) { - name = filepath.Join(name[:2], name) - } - path := filepath.Join(c.path, dir, name) - - if err := os.Rename(tmp, path); err != nil { - os.Remove(tmp) - os.Remove(path) - http.Error(w, "500 internal server error", 500) - return - } - - w.Write([]byte("200 ok")) } + + data, err := json.Marshal(names) + if err != nil { + http.Error(w, "500 internal server error", 500) + return + } + + w.Write(data) } -// DeleteBlob returns a http.HandlerFunc that deletes a blob from the repository. -func DeleteBlob(c *Context) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if *debug { - log.Println("DeleteBlob()") - } - vars := strings.Split(r.RequestURI, "/") - dir := vars[1] - name := vars[2] - - if isHashed(dir) { - name = filepath.Join(name[:2], name) - } - path := filepath.Join(c.path, dir, name) - - if err := os.Remove(path); err != nil { - http.Error(w, "500 internal server error", 500) - return - } - - w.Write([]byte("200 ok")) +// CheckBlob tests whether a blob exists and returns 200, if it does, or 404 otherwise. +func CheckBlob(w http.ResponseWriter, r *http.Request) { + if *debug { + log.Println("CheckBlob()") } + vars := strings.Split(r.RequestURI, "/") + dir := vars[1] + name := vars[2] + + if isHashed(dir) { + name = filepath.Join(name[:2], name) + } + path := filepath.Join(*path, dir, name) + + st, err := os.Stat(path) + if err != nil { + http.Error(w, "404 not found", 404) + return + } + + w.Header().Add("Content-Length", fmt.Sprint(st.Size())) +} + +// GetBlob retrieves a blob from the repository. +func GetBlob(w http.ResponseWriter, r *http.Request) { + if *debug { + log.Println("GetBlob()") + } + vars := strings.Split(r.RequestURI, "/") + dir := vars[1] + name := vars[2] + + if isHashed(dir) { + name = filepath.Join(name[:2], name) + } + path := filepath.Join(*path, dir, name) + + file, err := os.Open(path) + if err != nil { + http.Error(w, "404 not found", 404) + return + } + + http.ServeContent(w, r, "", time.Unix(0, 0), file) + file.Close() +} + +// SaveBlob saves a blob to the repository. +func SaveBlob(w http.ResponseWriter, r *http.Request) { + if *debug { + log.Println("SaveBlob()") + } + vars := strings.Split(r.RequestURI, "/") + dir := vars[1] + name := vars[2] + + if dir == "keys" { + if _, err := os.Stat("keys"); err != nil && os.IsNotExist(err) { + createDirectories(*path) + } + } + + tmp := filepath.Join(*path, "tmp", name) + + tf, err := os.OpenFile(tmp, os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + http.Error(w, "500 internal server error", 500) + return + } + + if _, err := io.Copy(tf, r.Body); err != nil { + tf.Close() + os.Remove(tmp) + http.Error(w, "400 bad request", 400) + return + } + if err := tf.Sync(); err != nil { + tf.Close() + os.Remove(tmp) + http.Error(w, "500 internal server error", 500) + return + } + if err := tf.Close(); err != nil { + os.Remove(tmp) + http.Error(w, "500 internal server error", 500) + return + } + + if isHashed(dir) { + name = filepath.Join(name[:2], name) + } + path := filepath.Join(*path, dir, name) + + if err := os.Rename(tmp, path); err != nil { + os.Remove(tmp) + os.Remove(path) + http.Error(w, "500 internal server error", 500) + return + } + + w.Write([]byte("200 ok")) +} + +// DeleteBlob deletes a blob from the repository. +func DeleteBlob(w http.ResponseWriter, r *http.Request) { + if *debug { + log.Println("DeleteBlob()") + } + vars := strings.Split(r.RequestURI, "/") + dir := vars[1] + name := vars[2] + + if isHashed(dir) { + name = filepath.Join(name[:2], name) + } + path := filepath.Join(*path, dir, name) + + if err := os.Remove(path); err != nil { + http.Error(w, "500 internal server error", 500) + return + } + + w.Write([]byte("200 ok")) } diff --git a/main.go b/main.go index 108bb9f..43bc3cf 100644 --- a/main.go +++ b/main.go @@ -7,33 +7,36 @@ import ( "os" "path/filepath" "runtime/pprof" + + "goji.io" + "goji.io/pat" ) -var debug = flag.Bool("debug", false, "output debug messages") +var ( + path = flag.String("path", "/tmp/restic", "data directory") + listen = flag.String("listen", ":8000", "listen address") + tls = flag.Bool("tls", false, "turn on TLS support") + cpuprofile = flag.String("cpuprofile", "", "write CPU profile to file") + debug = flag.Bool("debug", false, "output debug messages") +) -func setupRoutes(path string) *Router { - context := &Context{path} +func setupMux() *goji.Mux { + mux := goji.NewMux() + mux.HandleFunc(pat.Head("/config"), CheckConfig) + mux.HandleFunc(pat.Get("/config"), GetConfig) + mux.HandleFunc(pat.Post("/config"), SaveConfig) + mux.HandleFunc(pat.Get("/:dir/"), ListBlobs) + mux.HandleFunc(pat.Head("/:dir/:name"), CheckBlob) + mux.HandleFunc(pat.Get("/:type/:name"), GetBlob) + mux.HandleFunc(pat.Post("/:type/:name"), SaveBlob) + mux.HandleFunc(pat.Delete("/:type/:name"), DeleteBlob) - router := NewRouter() - router.HeadFunc("/config", CheckConfig(context)) - router.GetFunc("/config", GetConfig(context)) - router.PostFunc("/config", SaveConfig(context)) - router.GetFunc("/:dir/", ListBlobs(context)) - router.HeadFunc("/:dir/:name", CheckBlob(context)) - router.GetFunc("/:type/:name", GetBlob(context)) - router.PostFunc("/:type/:name", SaveBlob(context)) - router.DeleteFunc("/:type/:name", DeleteBlob(context)) - - return router + return mux } func main() { log.SetFlags(0) - var cpuprofile = flag.String("cpuprofile", "", "write CPU profile to file") - var listen = flag.String("listen", ":8000", "listen address") - var path = flag.String("path", "/tmp/restic", "data directory") - var tls = flag.Bool("tls", false, "turn on TLS support") flag.Parse() if *cpuprofile != "" { @@ -48,15 +51,15 @@ func main() { defer pprof.StopCPUProfile() } - router := setupRoutes(*path) + mux := setupMux() var handler http.Handler htpasswdFile, err := NewHtpasswdFromFile(filepath.Join(*path, ".htpasswd")) if err != nil { - handler = router + handler = mux log.Println("Authentication disabled") } else { - handler = AuthHandler(htpasswdFile, router) + handler = AuthHandler(htpasswdFile, mux) log.Println("Authentication enabled") } diff --git a/router.go b/router.go deleted file mode 100644 index 07f0edf..0000000 --- a/router.go +++ /dev/null @@ -1,137 +0,0 @@ -package main - -import ( - "log" - "net/http" - "strings" -) - -// Route is a handler for a path that was already split. -type Route struct { - path []string - handler http.Handler -} - -// Router maps HTTP methods to a slice of Route handlers. -type Router struct { - routes map[string][]Route -} - -// NewRouter creates a new Router and returns a pointer to it. -func NewRouter() *Router { - return &Router{make(map[string][]Route)} -} - -// Options registers handler for path with method "OPTIONS". -func (router *Router) Options(path string, handler http.Handler) { - router.Handle("OPTIONS", path, handler) -} - -// OptionsFunc registers handler for path with method "OPTIONS". -func (router *Router) OptionsFunc(path string, handler http.HandlerFunc) { - router.Handle("OPTIONS", path, handler) -} - -// Get registers handler for path with method "GET". -func (router *Router) Get(path string, handler http.Handler) { - router.Handle("GET", path, handler) -} - -// GetFunc registers handler for path with method "GET". -func (router *Router) GetFunc(path string, handler http.HandlerFunc) { - router.Handle("GET", path, handler) -} - -// Head registers handler for path with method "HEAD". -func (router *Router) Head(path string, handler http.Handler) { - router.Handle("HEAD", path, handler) -} - -// HeadFunc registers handler for path with method "HEAD". -func (router *Router) HeadFunc(path string, handler http.HandlerFunc) { - router.Handle("HEAD", path, handler) -} - -// Post registers handler for path with method "POST". -func (router *Router) Post(path string, handler http.Handler) { - router.Handle("POST", path, handler) -} - -// PostFunc registers handler for path with method "POST". -func (router *Router) PostFunc(path string, handler http.HandlerFunc) { - router.Handle("POST", path, handler) -} - -// Put registers handler for path with method "PUT". -func (router *Router) Put(path string, handler http.Handler) { - router.Handle("PUT", path, handler) -} - -// PutFunc registers handler for path with method "PUT". -func (router *Router) PutFunc(path string, handler http.HandlerFunc) { - router.Handle("PUT", path, handler) -} - -// Delete registers handler for path with method "DELETE". -func (router *Router) Delete(path string, handler http.Handler) { - router.Handle("DELETE", path, handler) -} - -// DeleteFunc registers handler for path with method "DELETE". -func (router *Router) DeleteFunc(path string, handler http.HandlerFunc) { - router.Handle("DELETE", path, handler) -} - -// Trace registers handler for path with method "TRACE". -func (router *Router) Trace(path string, handler http.Handler) { - router.Handle("TRACE", path, handler) -} - -// TraceFunc registers handler for path with method "TRACE". -func (router *Router) TraceFunc(path string, handler http.HandlerFunc) { - router.Handle("TRACE", path, handler) -} - -// Connect registers handler for path with method "Connect". -func (router *Router) Connect(path string, handler http.Handler) { - router.Handle("Connect", path, handler) -} - -// ConnectFunc registers handler for path with method "Connect". -func (router *Router) ConnectFunc(path string, handler http.HandlerFunc) { - router.Handle("Connect", path, handler) -} - -// Handle registers a http.Handler for method and uri. -func (router *Router) Handle(method string, uri string, handler http.Handler) { - routes := router.routes[method] - path := strings.Split(uri, "/") - routes = append(routes, Route{path, handler}) - router.routes[method] = routes -} - -func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { - method := r.Method - uri := r.RequestURI - path := strings.Split(uri, "/") - - if *debug { - log.Printf("%s %s", method, uri) - } - -ROUTE: - for _, route := range router.routes[method] { - if len(route.path) != len(path) { - continue - } - for i := 0; i < len(route.path); i++ { - if !strings.HasPrefix(route.path[i], ":") && route.path[i] != path[i] { - continue ROUTE - } - } - route.handler.ServeHTTP(w, r) - return - } - - http.Error(w, "404 not found", 404) -} diff --git a/router_test.go b/router_test.go deleted file mode 100644 index b776a9f..0000000 --- a/router_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package main - -import ( - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestRouter(t *testing.T) { - router := NewRouter() - - getConfig := []byte("GET /config") - router.GetFunc("/config", func(w http.ResponseWriter, r *http.Request) { - w.Write(getConfig) - }) - - postConfig := []byte("POST /config") - router.PostFunc("/config", func(w http.ResponseWriter, r *http.Request) { - w.Write(postConfig) - }) - - getBlobs := []byte("GET /blobs/") - router.GetFunc("/blobs/", func(w http.ResponseWriter, r *http.Request) { - w.Write(getBlobs) - }) - - getBlob := []byte("GET /blobs/:sha") - router.GetFunc("/blobs/:sha", func(w http.ResponseWriter, r *http.Request) { - w.Write(getBlob) - }) - - server := httptest.NewServer(router) - defer server.Close() - - getConfigResp, _ := http.Get(server.URL + "/config") - getConfigBody, _ := ioutil.ReadAll(getConfigResp.Body) - if getConfigResp.StatusCode != 200 { - t.Fatalf("Wanted HTTP Status 200, got %d", getConfigResp.StatusCode) - } - if string(getConfig) != string(getConfigBody) { - t.Fatalf("Config wrong:\nWanted '%s'\nGot: '%s'", string(getConfig), string(getConfigBody)) - } - - postConfigResp, _ := http.Post(server.URL+"/config", "binary/octet-stream", strings.NewReader("post test")) - postConfigBody, _ := ioutil.ReadAll(postConfigResp.Body) - if postConfigResp.StatusCode != 200 { - t.Fatalf("Wanted HTTP Status 200, got %d", postConfigResp.StatusCode) - } - if string(postConfig) != string(postConfigBody) { - t.Fatalf("Config wrong:\nWanted '%s'\nGot: '%s'", string(postConfig), string(postConfigBody)) - } - - getBlobsResp, _ := http.Get(server.URL + "/blobs/") - getBlobsBody, _ := ioutil.ReadAll(getBlobsResp.Body) - if getBlobsResp.StatusCode != 200 { - t.Fatalf("Wanted HTTP Status 200, got %d", getBlobsResp.StatusCode) - } - if string(getBlobs) != string(getBlobsBody) { - t.Fatalf("Config wrong:\nWanted '%s'\nGot: '%s'", string(getBlobs), string(getBlobsBody)) - } - - getBlobResp, _ := http.Get(server.URL + "/blobs/test") - getBlobBody, _ := ioutil.ReadAll(getBlobResp.Body) - if getBlobResp.StatusCode != 200 { - t.Fatalf("Wanted HTTP Status 200, got %d", getBlobResp.StatusCode) - } - if string(getBlob) != string(getBlobBody) { - t.Fatalf("Config wrong:\nWanted '%s'\nGot: '%s'", string(getBlob), string(getBlobBody)) - } -}