Fix directory traversal

This commit introduces the strict checks from net/http.Dir, which fixes
a directory traversal issue.

Closes #22
This commit is contained in:
Alexander Neumann 2017-07-30 14:30:18 +02:00 committed by Zlatko Čalušić
parent 9a6bb5eebe
commit a628c4e01a
2 changed files with 172 additions and 34 deletions

View file

@ -2,12 +2,14 @@ package restserver
import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"time"
@ -20,12 +22,83 @@ func isHashed(dir string) bool {
return dir == "data"
}
func getRepo(r *http.Request) string {
if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") {
return filepath.Join(Config.Path, pat.Param(r, "repo"))
func valid(name string) bool {
// taken from net/http.Dir
if strings.Contains(name, "\x00") {
return false
}
return Config.Path
if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) {
return false
}
return true
}
var validTypes = []string{"data", "index", "keys", "locks", "snapshots", "config"}
func isValidType(name string) bool {
for _, tpe := range validTypes {
if name == tpe {
return true
}
}
return false
}
// join takes a number of path names, sanitizes them, and returns them joined
// with base for the current operating system to use (dirs separated by
// filepath.Separator). The returned path is always either equal to base or a
// subdir of base.
func join(base string, names ...string) (string, error) {
clean := make([]string, 0, len(names)+1)
clean = append(clean, base)
// taken from net/http.Dir
for _, name := range names {
if !valid(name) {
return "", errors.New("invalid character in path")
}
clean = append(clean, filepath.FromSlash(path.Clean("/"+name)))
}
return filepath.Join(clean...), nil
}
// getRepo returns the repository location, relative to Config.Path.
func getRepo(r *http.Request) string {
if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") {
return pat.Param(r, "repo")
}
return "."
}
// getPath returns the path for a file type in the repo.
func getPath(r *http.Request, fileType string) (string, error) {
if !isValidType(fileType) {
return "", errors.New("invalid file type")
}
return join(Config.Path, getRepo(r), fileType)
}
// getFilePath returns the path for a file in the repo.
func getFilePath(r *http.Request, fileType, name string) (string, error) {
if !isValidType(fileType) {
return "", errors.New("invalid file type")
}
if isHashed(fileType) {
if len(name) < 2 {
return "", errors.New("file name is too short")
}
return join(Config.Path, getRepo(r), fileType, name[:2], name)
}
return join(Config.Path, getRepo(r), fileType, name)
}
// AuthHandler wraps h with a http.HandlerFunc that performs basic authentication against the user/passwords pairs
@ -46,7 +119,12 @@ func CheckConfig(w http.ResponseWriter, r *http.Request) {
if Config.Debug {
log.Println("CheckConfig()")
}
cfg := filepath.Join(getRepo(r), "config")
cfg, err := getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
st, err := os.Stat(cfg)
if err != nil {
if Config.Debug {
@ -64,7 +142,11 @@ func GetConfig(w http.ResponseWriter, r *http.Request) {
if Config.Debug {
log.Println("GetConfig()")
}
cfg := filepath.Join(getRepo(r), "config")
cfg, err := getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
bytes, err := ioutil.ReadFile(cfg)
if err != nil {
@ -83,7 +165,11 @@ func SaveConfig(w http.ResponseWriter, r *http.Request) {
if Config.Debug {
log.Println("SaveConfig()")
}
cfg := filepath.Join(getRepo(r), "config")
cfg, err := getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
bytes, err := ioutil.ReadAll(r.Body)
if err != nil {
@ -108,8 +194,13 @@ func DeleteConfig(w http.ResponseWriter, r *http.Request) {
if Config.Debug {
log.Println("DeleteConfig()")
}
cfg, err := getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if err := os.Remove(filepath.Join(getRepo(r), "config")); err != nil {
if err := os.Remove(cfg); err != nil {
if Config.Debug {
log.Print(err)
}
@ -127,8 +218,12 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) {
if Config.Debug {
log.Println("ListBlobs()")
}
dir := pat.Param(r, "type")
path := filepath.Join(getRepo(r), dir)
fileType := pat.Param(r, "type")
path, err := getPath(r, fileType)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
items, err := ioutil.ReadDir(path)
if err != nil {
@ -141,7 +236,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) {
var names []string
for _, i := range items {
if isHashed(dir) {
if isHashed(fileType) {
subpath := filepath.Join(path, i.Name())
subitems, err := ioutil.ReadDir(subpath)
if err != nil {
@ -176,13 +271,12 @@ func CheckBlob(w http.ResponseWriter, r *http.Request) {
if Config.Debug {
log.Println("CheckBlob()")
}
dir := pat.Param(r, "type")
name := pat.Param(r, "name")
if isHashed(dir) {
name = filepath.Join(name[:2], name)
path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
path := filepath.Join(getRepo(r), dir, name)
st, err := os.Stat(path)
if err != nil {
@ -201,13 +295,12 @@ func GetBlob(w http.ResponseWriter, r *http.Request) {
if Config.Debug {
log.Println("GetBlob()")
}
dir := pat.Param(r, "type")
name := pat.Param(r, "name")
if isHashed(dir) {
name = filepath.Join(name[:2], name)
path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
path := filepath.Join(getRepo(r), dir, name)
file, err := os.Open(path)
if err != nil {
@ -227,14 +320,12 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) {
if Config.Debug {
log.Println("SaveBlob()")
}
repo := getRepo(r)
dir := pat.Param(r, "type")
name := pat.Param(r, "name")
if isHashed(dir) {
name = filepath.Join(name[:2], name)
path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
path := filepath.Join(repo, dir, name)
tf, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600)
if err != nil {
@ -280,13 +371,12 @@ func DeleteBlob(w http.ResponseWriter, r *http.Request) {
if Config.Debug {
log.Println("DeleteBlob()")
}
dir := pat.Param(r, "type")
name := pat.Param(r, "name")
if isHashed(dir) {
name = filepath.Join(name[:2], name)
path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
path := filepath.Join(getRepo(r), dir, name)
if err := os.Remove(path); err != nil {
if Config.Debug {
@ -306,7 +396,12 @@ func CreateRepo(w http.ResponseWriter, r *http.Request) {
if Config.Debug {
log.Println("CreateRepo()")
}
repo := getRepo(r)
repo, err := join(Config.Path, getRepo(r))
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if r.URL.Query().Get("create") != "true" {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
@ -321,7 +416,11 @@ func CreateRepo(w http.ResponseWriter, r *http.Request) {
return
}
for _, d := range []string{"data", "index", "keys", "locks", "snapshots", "tmp"} {
for _, d := range validTypes {
if d == "config" {
continue
}
if err := os.MkdirAll(filepath.Join(repo, d), 0700); err != nil {
log.Print(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)