package restserver import ( "encoding/json" "errors" "fmt" "io" "io/ioutil" "log" "net/http" "os" "path" "path/filepath" "strings" "time" "goji.io/middleware" "goji.io/pat" ) func isHashed(dir string) bool { return dir == "data" } func valid(name string) bool { // taken from net/http.Dir if strings.Contains(name, "\x00") { return false } if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) { return false } return true } var validTypes = []string{"data", "index", "keys", "locks", "snapshots", "config"} func isValidType(name string) bool { for _, tpe := range validTypes { if name == tpe { return true } } return false } // join takes a number of path names, sanitizes them, and returns them joined // with base for the current operating system to use (dirs separated by // filepath.Separator). The returned path is always either equal to base or a // subdir of base. func join(base string, names ...string) (string, error) { clean := make([]string, 0, len(names)+1) clean = append(clean, base) // taken from net/http.Dir for _, name := range names { if !valid(name) { return "", errors.New("invalid character in path") } clean = append(clean, filepath.FromSlash(path.Clean("/"+name))) } return filepath.Join(clean...), nil } // getRepo returns the repository location, relative to Config.Path. func getRepo(r *http.Request) string { if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") { return pat.Param(r, "repo") } return "." } // getPath returns the path for a file type in the repo. func getPath(r *http.Request, fileType string) (string, error) { if !isValidType(fileType) { return "", errors.New("invalid file type") } return join(Config.Path, getRepo(r), fileType) } // getFilePath returns the path for a file in the repo. func getFilePath(r *http.Request, fileType, name string) (string, error) { if !isValidType(fileType) { return "", errors.New("invalid file type") } if isHashed(fileType) { if len(name) < 2 { return "", errors.New("file name is too short") } return join(Config.Path, getRepo(r), fileType, name[:2], name) } return join(Config.Path, getRepo(r), fileType, name) } // AuthHandler wraps h with a http.HandlerFunc that performs basic authentication against the user/passwords pairs // stored in f and returns the http.HandlerFunc. func AuthHandler(f *HtpasswdFile, h http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if username, password, ok := r.BasicAuth(); !ok || !f.Validate(username, password) { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } h.ServeHTTP(w, r) } } // CheckConfig checks whether a configuration exists. func CheckConfig(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("CheckConfig()") } cfg, err := getPath(r, "config") if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } st, err := os.Stat(cfg) if err != nil { if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } w.Header().Add("Content-Length", fmt.Sprint(st.Size())) } // GetConfig allows for a config to be retrieved. func GetConfig(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("GetConfig()") } cfg, err := getPath(r, "config") if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } bytes, err := ioutil.ReadFile(cfg) if err != nil { if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } w.Write(bytes) } // SaveConfig allows for a config to be saved. func SaveConfig(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("SaveConfig()") } cfg, err := getPath(r, "config") if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } bytes, err := ioutil.ReadAll(r.Body) if err != nil { if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } if err := ioutil.WriteFile(cfg, bytes, 0600); err != nil { if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } } // DeleteConfig removes a config. func DeleteConfig(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("DeleteConfig()") } if Config.AppendOnly { http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } cfg, err := getPath(r, "config") if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } if err := os.Remove(cfg); err != nil { if Config.Debug { log.Print(err) } if os.IsNotExist(err) { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) } else { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } return } } // ListBlobs lists all blobs of a given type in an arbitrary order. func ListBlobs(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("ListBlobs()") } fileType := pat.Param(r, "type") path, err := getPath(r, fileType) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } items, err := ioutil.ReadDir(path) if err != nil { if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } var names []string for _, i := range items { if isHashed(fileType) { subpath := filepath.Join(path, i.Name()) subitems, err := ioutil.ReadDir(subpath) if err != nil { if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } for _, f := range subitems { names = append(names, f.Name()) } } else { names = append(names, i.Name()) } } data, err := json.Marshal(names) if err != nil { if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } w.Write(data) } // CheckBlob tests whether a blob exists. func CheckBlob(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("CheckBlob()") } path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name")) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } st, err := os.Stat(path) if err != nil { if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } w.Header().Add("Content-Length", fmt.Sprint(st.Size())) } // GetBlob retrieves a blob from the repository. func GetBlob(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("GetBlob()") } path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name")) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } file, err := os.Open(path) if err != nil { if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } http.ServeContent(w, r, "", time.Unix(0, 0), file) file.Close() } // SaveBlob saves a blob to the repository. func SaveBlob(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("SaveBlob()") } path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name")) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } tf, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) if err != nil { if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } if _, err := io.Copy(tf, r.Body); err != nil { tf.Close() os.Remove(path) if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } if err := tf.Sync(); err != nil { tf.Close() os.Remove(path) if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } if err := tf.Close(); err != nil { os.Remove(path) if Config.Debug { log.Print(err) } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } } // DeleteBlob deletes a blob from the repository. func DeleteBlob(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("DeleteBlob()") } if Config.AppendOnly && pat.Param(r, "type") != "locks" { http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name")) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } if err := os.Remove(path); err != nil { if Config.Debug { log.Print(err) } if os.IsNotExist(err) { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) } else { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } return } } // CreateRepo creates repository directories. func CreateRepo(w http.ResponseWriter, r *http.Request) { if Config.Debug { log.Println("CreateRepo()") } repo, err := join(Config.Path, getRepo(r)) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } if r.URL.Query().Get("create") != "true" { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } log.Printf("Creating repository directories in %s\n", repo) if err := os.MkdirAll(repo, 0700); err != nil { log.Print(err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } for _, d := range validTypes { if d == "config" { continue } if err := os.MkdirAll(filepath.Join(repo, d), 0700); err != nil { log.Print(err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } } for i := 0; i < 256; i++ { if err := os.MkdirAll(filepath.Join(repo, "data", fmt.Sprintf("%02x", i)), 0700); err != nil { log.Print(err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } } }