diff --git a/handlers.go b/handlers.go index e02d63f..d56b966 100644 --- a/handlers.go +++ b/handlers.go @@ -2,12 +2,14 @@ package restserver import ( "encoding/json" + "errors" "fmt" "io" "io/ioutil" "log" "net/http" "os" + "path" "path/filepath" "strings" "time" @@ -20,12 +22,83 @@ func isHashed(dir string) bool { return dir == "data" } -func getRepo(r *http.Request) string { - if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") { - return filepath.Join(Config.Path, pat.Param(r, "repo")) +func valid(name string) bool { + // taken from net/http.Dir + if strings.Contains(name, "\x00") { + return false } - return Config.Path + if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) { + return false + } + + return true +} + +var validTypes = []string{"data", "index", "keys", "locks", "snapshots", "config"} + +func isValidType(name string) bool { + for _, tpe := range validTypes { + if name == tpe { + return true + } + } + + return false +} + +// join takes a number of path names, sanitizes them, and returns them joined +// with base for the current operating system to use (dirs separated by +// filepath.Separator). The returned path is always either equal to base or a +// subdir of base. +func join(base string, names ...string) (string, error) { + clean := make([]string, 0, len(names)+1) + clean = append(clean, base) + + // taken from net/http.Dir + for _, name := range names { + if !valid(name) { + return "", errors.New("invalid character in path") + } + + clean = append(clean, filepath.FromSlash(path.Clean("/"+name))) + } + + return filepath.Join(clean...), nil +} + +// getRepo returns the repository location, relative to Config.Path. +func getRepo(r *http.Request) string { + if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") { + return pat.Param(r, "repo") + } + + return "." +} + +// getPath returns the path for a file type in the repo. +func getPath(r *http.Request, fileType string) (string, error) { + if !isValidType(fileType) { + return "", errors.New("invalid file type") + } + return join(Config.Path, getRepo(r), fileType) +} + +// getFilePath returns the path for a file in the repo. +func getFilePath(r *http.Request, fileType, name string) (string, error) { + if !isValidType(fileType) { + return "", errors.New("invalid file type") + } + + if isHashed(fileType) { + if len(name) < 2 { + return "", errors.New("file name is too short") + } + + return join(Config.Path, getRepo(r), fileType, name[:2], name) + } + + return join(Config.Path, getRepo(r), fileType, name) } // AuthHandler wraps h with a http.HandlerFunc that performs basic authentication against the user/passwords pairs @@ -46,7 +119,12 @@ func CheckConfig(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("CheckConfig()") } - cfg := filepath.Join(getRepo(r), "config") + cfg, err := getPath(r, "config") + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + st, err := os.Stat(cfg) if err != nil { if Config.Debug { @@ -64,7 +142,11 @@ func GetConfig(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("GetConfig()") } - cfg := filepath.Join(getRepo(r), "config") + cfg, err := getPath(r, "config") + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } bytes, err := ioutil.ReadFile(cfg) if err != nil { @@ -83,7 +165,11 @@ func SaveConfig(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("SaveConfig()") } - cfg := filepath.Join(getRepo(r), "config") + cfg, err := getPath(r, "config") + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } bytes, err := ioutil.ReadAll(r.Body) if err != nil { @@ -108,8 +194,13 @@ func DeleteConfig(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("DeleteConfig()") } + cfg, err := getPath(r, "config") + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } - if err := os.Remove(filepath.Join(getRepo(r), "config")); err != nil { + if err := os.Remove(cfg); err != nil { if Config.Debug { log.Print(err) } @@ -127,8 +218,12 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("ListBlobs()") } - dir := pat.Param(r, "type") - path := filepath.Join(getRepo(r), dir) + fileType := pat.Param(r, "type") + path, err := getPath(r, fileType) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } items, err := ioutil.ReadDir(path) if err != nil { @@ -141,7 +236,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) { var names []string for _, i := range items { - if isHashed(dir) { + if isHashed(fileType) { subpath := filepath.Join(path, i.Name()) subitems, err := ioutil.ReadDir(subpath) if err != nil { @@ -176,13 +271,12 @@ func CheckBlob(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("CheckBlob()") } - dir := pat.Param(r, "type") - name := pat.Param(r, "name") - if isHashed(dir) { - name = filepath.Join(name[:2], name) + path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name")) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return } - path := filepath.Join(getRepo(r), dir, name) st, err := os.Stat(path) if err != nil { @@ -201,13 +295,12 @@ func GetBlob(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("GetBlob()") } - dir := pat.Param(r, "type") - name := pat.Param(r, "name") - if isHashed(dir) { - name = filepath.Join(name[:2], name) + path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name")) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return } - path := filepath.Join(getRepo(r), dir, name) file, err := os.Open(path) if err != nil { @@ -227,14 +320,12 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("SaveBlob()") } - repo := getRepo(r) - dir := pat.Param(r, "type") - name := pat.Param(r, "name") - if isHashed(dir) { - name = filepath.Join(name[:2], name) + path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name")) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return } - path := filepath.Join(repo, dir, name) tf, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) if err != nil { @@ -280,13 +371,12 @@ func DeleteBlob(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("DeleteBlob()") } - dir := pat.Param(r, "type") - name := pat.Param(r, "name") - if isHashed(dir) { - name = filepath.Join(name[:2], name) + path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name")) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return } - path := filepath.Join(getRepo(r), dir, name) if err := os.Remove(path); err != nil { if Config.Debug { @@ -306,7 +396,12 @@ func CreateRepo(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("CreateRepo()") } - repo := getRepo(r) + + repo, err := join(Config.Path, getRepo(r)) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } if r.URL.Query().Get("create") != "true" { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) @@ -321,7 +416,11 @@ func CreateRepo(w http.ResponseWriter, r *http.Request) { return } - for _, d := range []string{"data", "index", "keys", "locks", "snapshots", "tmp"} { + for _, d := range validTypes { + if d == "config" { + continue + } + if err := os.MkdirAll(filepath.Join(repo, d), 0700); err != nil { log.Print(err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) diff --git a/handlers_test.go b/handlers_test.go new file mode 100644 index 0000000..6536a26 --- /dev/null +++ b/handlers_test.go @@ -0,0 +1,39 @@ +package restserver + +import ( + "path/filepath" + "testing" +) + +func TestJoin(t *testing.T) { + var tests = []struct { + base, name string + result string + }{ + {"/", "foo/bar", "/foo/bar"}, + {"/srv/server", "foo/bar", "/srv/server/foo/bar"}, + {"/srv/server", "/foo/bar", "/srv/server/foo/bar"}, + {"/srv/server", "foo/../bar", "/srv/server/bar"}, + {"/srv/server", "../bar", "/srv/server/bar"}, + {"/srv/server", "..", "/srv/server"}, + {"/srv/server", "../..", "/srv/server"}, + {"/srv/server", "/repo/data/", "/srv/server/repo/data"}, + {"/srv/server", "/repo/data/../..", "/srv/server"}, + {"/srv/server", "/repo/data/../data/../../..", "/srv/server"}, + {"/srv/server", "/repo/data/../data/../../..", "/srv/server"}, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + got, err := join(filepath.FromSlash(test.base), test.name) + if err != nil { + t.Fatal(err) + } + + want := filepath.FromSlash(test.result) + if got != want { + t.Fatalf("wrong result returned, want %v, got %v", want, got) + } + }) + } +}