rest-server/handlers.go
Zlatko Čalušić d056b85432 Check errors in many places
Admittedly, in some places just document the fact that we ignore error
return values, 'cause we don't know what to do with it.  At least, the
linter is happy.
2017-10-25 18:31:34 +02:00

470 lines
12 KiB
Go

package restserver
import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/miolini/datacounter"
"github.com/prometheus/client_golang/prometheus"
"goji.io/middleware"
"goji.io/pat"
)
// isHashed reports whether dir names the file type whose files are spread
// over subdirectories. Only the "data" type qualifies.
func isHashed(dir string) bool {
	switch dir {
	case "data":
		return true
	default:
		return false
	}
}
// valid reports whether name may be used as a path element: it must not
// contain a NUL byte and, on systems whose path separator is not '/', it must
// not contain that separator either.
func valid(name string) bool {
	// Checks taken from net/http.Dir.
	if strings.IndexByte(name, 0) >= 0 {
		return false
	}
	if filepath.Separator == '/' {
		return true
	}
	return !strings.ContainsRune(name, filepath.Separator)
}
// validTypes lists the file types accepted by the handlers.
var validTypes = []string{"data", "index", "keys", "locks", "snapshots", "config"}

// isValidType reports whether name is one of the entries of validTypes.
func isValidType(name string) bool {
	for i := range validTypes {
		if validTypes[i] == name {
			return true
		}
	}
	return false
}
// join sanitizes each of names and appends them to base, producing a path
// that uses the separator of the current operating system. A name rejected by
// valid yields an error. The result is always base itself or a location
// below it.
func join(base string, names ...string) (string, error) {
	parts := []string{base}
	for _, n := range names {
		if !valid(n) {
			return "", errors.New("invalid character in path")
		}
		// As in net/http.Dir: rooting the name at "/" before cleaning keeps
		// ".." elements from climbing above base.
		rooted := path.Clean("/" + n)
		parts = append(parts, filepath.FromSlash(rooted))
	}
	return filepath.Join(parts...), nil
}
// getRepo returns the repository location, relative to Config.Path. When the
// pattern that matched the request starts with "/:repo" the value bound to
// the "repo" parameter is returned; otherwise the location is ".".
func getRepo(r *http.Request) string {
	matched := fmt.Sprintf("%s", middleware.Pattern(r.Context()))
	if !strings.HasPrefix(matched, "/:repo") {
		return "."
	}
	return pat.Param(r, "repo")
}
// getPath returns the path for a file type in the repo addressed by r. An
// error is returned when fileType is not a valid type or the path cannot be
// joined.
func getPath(r *http.Request, fileType string) (string, error) {
	if isValidType(fileType) {
		return join(Config.Path, getRepo(r), fileType)
	}
	return "", errors.New("invalid file type")
}
// getFilePath returns the path for the file called name of the given type in
// the repo addressed by r. For hashed types (see isHashed) the file lives in
// a subdirectory named after the first two bytes of name, so name must be at
// least two bytes long. An error is returned for an invalid file type, a too
// short name, or a path that cannot be joined.
func getFilePath(r *http.Request, fileType, name string) (string, error) {
	switch {
	case !isValidType(fileType):
		return "", errors.New("invalid file type")
	case !isHashed(fileType):
		return join(Config.Path, getRepo(r), fileType, name)
	case len(name) < 2:
		return "", errors.New("file name is too short")
	}
	prefix := name[:2]
	return join(Config.Path, getRepo(r), fileType, prefix, name)
}
// AuthHandler returns an http.HandlerFunc that guards h with basic
// authentication. A request is passed on to h only when it carries basic auth
// credentials that f validates; any other request is answered with
// 401 Unauthorized.
func AuthHandler(f *HtpasswdFile, h http.Handler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		username, password, ok := r.BasicAuth()
		if ok && f.Validate(username, password) {
			h.ServeHTTP(w, r)
			return
		}
		http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
	}
}
// CheckConfig checks whether a configuration exists.
//
// On success the size of the config file is reported in the Content-Length
// header and no body is written. The handler replies 404 Not Found when the
// config file does not exist and 500 Internal Server Error when its path
// cannot be built or it cannot be examined for any other reason.
func CheckConfig(w http.ResponseWriter, r *http.Request) {
	if Config.Debug {
		log.Println("CheckConfig()")
	}
	cfg, err := getPath(r, "config")
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	st, err := os.Stat(cfg)
	if err != nil {
		if Config.Debug {
			log.Print(err)
		}
		// Only a missing file means "no config"; permission or I/O problems
		// are server-side failures and must not be disguised as 404.
		status := http.StatusInternalServerError
		if os.IsNotExist(err) {
			status = http.StatusNotFound
		}
		http.Error(w, http.StatusText(status), status)
		return
	}
	w.Header().Add("Content-Length", fmt.Sprint(st.Size()))
}
// GetConfig allows for a config to be retrieved.
//
// The whole config file is read and written to the response. The handler
// replies 404 Not Found when the config file does not exist and
// 500 Internal Server Error when its path cannot be built or it cannot be
// read for any other reason.
func GetConfig(w http.ResponseWriter, r *http.Request) {
	if Config.Debug {
		log.Println("GetConfig()")
	}
	cfg, err := getPath(r, "config")
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	data, err := ioutil.ReadFile(cfg)
	if err != nil {
		if Config.Debug {
			log.Print(err)
		}
		// Only a missing file means "no config"; permission or I/O problems
		// are server-side failures and must not be disguised as 404.
		status := http.StatusInternalServerError
		if os.IsNotExist(err) {
			status = http.StatusNotFound
		}
		http.Error(w, http.StatusText(status), status)
		return
	}
	// A failed write means the client went away; there is nobody left to
	// report the error to.
	_, _ = w.Write(data)
}
// SaveConfig allows for a config to be saved.
//
// The request body is read completely and written to the config file with
// mode 0600. The handler replies 400 Bad Request when the body cannot be
// read and 500 Internal Server Error when the config path cannot be built or
// the file cannot be written.
func SaveConfig(w http.ResponseWriter, r *http.Request) {
	if Config.Debug {
		log.Println("SaveConfig()")
	}

	target, pathErr := getPath(r, "config")
	if pathErr != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	body, readErr := ioutil.ReadAll(r.Body)
	if readErr != nil {
		if Config.Debug {
			log.Print(readErr)
		}
		http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
		return
	}

	writeErr := ioutil.WriteFile(target, body, 0600)
	if writeErr == nil {
		return
	}
	if Config.Debug {
		log.Print(writeErr)
	}
	http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
// DeleteConfig removes a config.
//
// In append-only mode the request is refused with 403 Forbidden. Otherwise
// the config file is removed; the handler replies 404 Not Found when it does
// not exist and 500 Internal Server Error when its path cannot be built or
// the removal fails for another reason.
func DeleteConfig(w http.ResponseWriter, r *http.Request) {
	if Config.Debug {
		log.Println("DeleteConfig()")
	}

	if Config.AppendOnly {
		http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
		return
	}

	cfg, err := getPath(r, "config")
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	err = os.Remove(cfg)
	if err == nil {
		return
	}
	if Config.Debug {
		log.Print(err)
	}
	status := http.StatusInternalServerError
	if os.IsNotExist(err) {
		status = http.StatusNotFound
	}
	http.Error(w, http.StatusText(status), status)
}
// ListBlobs lists all blobs of a given type in an arbitrary order.
//
// The type is taken from the "type" URL parameter. For hashed types (see
// isHashed) the names are collected from every entry one level below the
// type directory; for all other types the entries of the type directory
// itself are listed. The names are written to the response as a JSON array
// of strings; an empty listing is encoded as [].
//
// The handler replies 404 Not Found when a directory to be listed does not
// exist and 500 Internal Server Error when the path cannot be built, a
// directory cannot be read for another reason, or the names cannot be
// encoded.
func ListBlobs(w http.ResponseWriter, r *http.Request) {
	if Config.Debug {
		log.Println("ListBlobs()")
	}
	fileType := pat.Param(r, "type")
	dir, err := getPath(r, fileType)
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	// readFailed reports a failed directory read: only a missing directory
	// is a 404, anything else (permissions, I/O) is a server-side failure.
	readFailed := func(err error) {
		if Config.Debug {
			log.Print(err)
		}
		status := http.StatusInternalServerError
		if os.IsNotExist(err) {
			status = http.StatusNotFound
		}
		http.Error(w, http.StatusText(status), status)
	}

	items, err := ioutil.ReadDir(dir)
	if err != nil {
		readFailed(err)
		return
	}

	// Start from a non-nil slice: json.Marshal encodes a nil slice as null,
	// but an empty listing must still be a JSON array.
	names := []string{}
	hashed := isHashed(fileType)
	for _, item := range items {
		if !hashed {
			names = append(names, item.Name())
			continue
		}
		var subitems []os.FileInfo
		subitems, err = ioutil.ReadDir(filepath.Join(dir, item.Name()))
		if err != nil {
			readFailed(err)
			return
		}
		for _, f := range subitems {
			names = append(names, f.Name())
		}
	}

	data, err := json.Marshal(names)
	if err != nil {
		if Config.Debug {
			log.Print(err)
		}
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	// A failed write means the client went away; there is nobody left to
	// report the error to.
	_, _ = w.Write(data)
}
// CheckBlob tests whether a blob exists.
//
// The blob is identified by the "type" and "name" URL parameters. On success
// its size is reported in the Content-Length header and no body is written.
// The handler replies 404 Not Found when the blob does not exist and
// 500 Internal Server Error when its path cannot be built or it cannot be
// examined for any other reason.
func CheckBlob(w http.ResponseWriter, r *http.Request) {
	if Config.Debug {
		log.Println("CheckBlob()")
	}
	blob, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	st, err := os.Stat(blob)
	if err != nil {
		if Config.Debug {
			log.Print(err)
		}
		// Only a missing file means "no such blob"; permission or I/O
		// problems are server-side failures and must not be disguised as 404.
		status := http.StatusInternalServerError
		if os.IsNotExist(err) {
			status = http.StatusNotFound
		}
		http.Error(w, http.StatusText(status), status)
		return
	}
	w.Header().Add("Content-Length", fmt.Sprint(st.Size()))
}
// GetBlob retrieves a blob from the repository.
//
// The blob is identified by the "type" and "name" URL parameters and is sent
// with http.ServeContent. The handler replies 404 Not Found when the blob
// does not exist and 500 Internal Server Error when its path cannot be built
// or it cannot be opened for any other reason. When Config.Prometheus is set
// and the file was closed cleanly, the blob read counters are updated with
// the number of bytes written to the response.
func GetBlob(w http.ResponseWriter, r *http.Request) {
	if Config.Debug {
		log.Println("GetBlob()")
	}
	blob, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	file, err := os.Open(blob)
	if err != nil {
		if Config.Debug {
			log.Print(err)
		}
		// Only a missing file means "no such blob"; permission or I/O
		// problems are server-side failures and must not be disguised as 404.
		status := http.StatusInternalServerError
		if os.IsNotExist(err) {
			status = http.StatusNotFound
		}
		http.Error(w, http.StatusText(status), status)
		return
	}
	wc := datacounter.NewResponseWriterCounter(w)
	// The Unix epoch as modtime makes ServeContent omit Last-Modified.
	http.ServeContent(wc, r, "", time.Unix(0, 0), file)
	if err = file.Close(); err != nil {
		// ServeContent has already written the response, so an error status
		// can no longer be sent; calling http.Error here would only append
		// stray bytes to the body. Record the failure instead.
		if Config.Debug {
			log.Print(err)
		}
		return
	}
	if Config.Prometheus {
		labels := prometheus.Labels{"repo": getRepo(r), "type": pat.Param(r, "type")}
		metricBlobReadTotal.With(labels).Inc()
		metricBlobReadBytesTotal.With(labels).Add(float64(wc.Count()))
	}
}
// SaveBlob saves a blob to the repository.
//
// The blob is identified by the "type" and "name" URL parameters. The target
// file is created exclusively with mode 0600, filled from the request body,
// synced and closed. If copying, syncing or closing fails the file is removed
// again. The handler replies 400 Bad Request when copying the body fails and
// 500 Internal Server Error for every other failure. When Config.Prometheus
// is set, the blob write counters are updated with the number of bytes
// copied.
func SaveBlob(w http.ResponseWriter, r *http.Request) {
	if Config.Debug {
		log.Println("SaveBlob()")
	}

	target, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	// fail logs cause in debug mode and answers the request with status.
	fail := func(status int, cause error) {
		if Config.Debug {
			log.Print(cause)
		}
		http.Error(w, http.StatusText(status), status)
	}

	tf, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600)
	if err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}

	written, err := io.Copy(tf, r.Body)
	if err != nil {
		// Best-effort cleanup of the partial file; the copy error is the one
		// that gets reported.
		_ = tf.Close()
		_ = os.Remove(target)
		fail(http.StatusBadRequest, err)
		return
	}

	if err = tf.Sync(); err != nil {
		// Best-effort cleanup; the sync error is the one that gets reported.
		_ = tf.Close()
		_ = os.Remove(target)
		fail(http.StatusInternalServerError, err)
		return
	}

	if err = tf.Close(); err != nil {
		// Best-effort cleanup; the close error is the one that gets reported.
		_ = os.Remove(target)
		fail(http.StatusInternalServerError, err)
		return
	}

	if !Config.Prometheus {
		return
	}
	labels := prometheus.Labels{"repo": getRepo(r), "type": pat.Param(r, "type")}
	metricBlobWriteTotal.With(labels).Inc()
	metricBlobWriteBytesTotal.With(labels).Add(float64(written))
}
// DeleteBlob deletes a blob from the repository.
//
// The blob is identified by the "type" and "name" URL parameters. In
// append-only mode only blobs of type "locks" may be deleted; any other
// request is refused with 403 Forbidden. The handler replies 404 Not Found
// when the blob does not exist and 500 Internal Server Error when its path
// cannot be built or the removal fails for another reason.
func DeleteBlob(w http.ResponseWriter, r *http.Request) {
	if Config.Debug {
		log.Println("DeleteBlob()")
	}

	if Config.AppendOnly && pat.Param(r, "type") != "locks" {
		http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
		return
	}

	blob, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	err = os.Remove(blob)
	if err == nil {
		return
	}
	if Config.Debug {
		log.Print(err)
	}
	status := http.StatusInternalServerError
	if os.IsNotExist(err) {
		status = http.StatusNotFound
	}
	http.Error(w, http.StatusText(status), status)
}
// CreateRepo creates repository directories.
//
// The request must carry the query parameter create=true, otherwise it is
// answered with 400 Bad Request. The repository directory is created first,
// then one subdirectory per entry of validTypes except "config", and finally
// the 256 subdirectories "00" through "ff" below "data". All directories are
// created with mode 0700. Any failure is logged and answered with
// 500 Internal Server Error.
func CreateRepo(w http.ResponseWriter, r *http.Request) {
	if Config.Debug {
		log.Println("CreateRepo()")
	}

	repo, err := join(Config.Path, getRepo(r))
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	if r.URL.Query().Get("create") != "true" {
		http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
		return
	}

	log.Printf("Creating repository directories in %s\n", repo)

	// mkdir creates dir including missing parents. On failure it logs the
	// error, answers the request with 500 and returns false.
	mkdir := func(dir string) bool {
		mkErr := os.MkdirAll(dir, 0700)
		if mkErr == nil {
			return true
		}
		log.Print(mkErr)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return false
	}

	if !mkdir(repo) {
		return
	}

	for _, d := range validTypes {
		// "config" is a valid type but gets no directory here.
		if d != "config" && !mkdir(filepath.Join(repo, d)) {
			return
		}
	}

	for i := 0; i < 256; i++ {
		sub := fmt.Sprintf("%02x", i)
		if !mkdir(filepath.Join(repo, "data", sub)) {
			return
		}
	}
}