Fix directory traversal

This commit introduces the strict checks from net/http.Dir, which fixes
a directory traversal issue.

Closes #22
This commit is contained in:
Alexander Neumann 2017-07-30 14:30:18 +02:00 committed by Zlatko Čalušić
parent 9a6bb5eebe
commit a628c4e01a
2 changed files with 172 additions and 34 deletions

View file

@ -2,12 +2,14 @@ package restserver
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"net/http" "net/http"
"os" "os"
"path"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
@ -20,12 +22,83 @@ func isHashed(dir string) bool {
return dir == "data" return dir == "data"
} }
func getRepo(r *http.Request) string { func valid(name string) bool {
if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") { // taken from net/http.Dir
return filepath.Join(Config.Path, pat.Param(r, "repo")) if strings.Contains(name, "\x00") {
return false
} }
return Config.Path if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) {
return false
}
return true
}
var validTypes = []string{"data", "index", "keys", "locks", "snapshots", "config"}
func isValidType(name string) bool {
for _, tpe := range validTypes {
if name == tpe {
return true
}
}
return false
}
// join takes a number of path names, sanitizes them, and returns them joined
// with base for the current operating system to use (dirs separated by
// filepath.Separator). The returned path is always either equal to base or a
// subdir of base.
func join(base string, names ...string) (string, error) {
clean := make([]string, 0, len(names)+1)
clean = append(clean, base)
// taken from net/http.Dir
for _, name := range names {
if !valid(name) {
return "", errors.New("invalid character in path")
}
clean = append(clean, filepath.FromSlash(path.Clean("/"+name)))
}
return filepath.Join(clean...), nil
}
// getRepo returns the repository location, relative to Config.Path.
func getRepo(r *http.Request) string {
if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") {
return pat.Param(r, "repo")
}
return "."
}
// getPath returns the path for a file type in the repo.
func getPath(r *http.Request, fileType string) (string, error) {
if !isValidType(fileType) {
return "", errors.New("invalid file type")
}
return join(Config.Path, getRepo(r), fileType)
}
// getFilePath returns the path for a file in the repo.
func getFilePath(r *http.Request, fileType, name string) (string, error) {
if !isValidType(fileType) {
return "", errors.New("invalid file type")
}
if isHashed(fileType) {
if len(name) < 2 {
return "", errors.New("file name is too short")
}
return join(Config.Path, getRepo(r), fileType, name[:2], name)
}
return join(Config.Path, getRepo(r), fileType, name)
} }
// AuthHandler wraps h with a http.HandlerFunc that performs basic authentication against the user/passwords pairs // AuthHandler wraps h with a http.HandlerFunc that performs basic authentication against the user/passwords pairs
@ -46,7 +119,12 @@ func CheckConfig(w http.ResponseWriter, r *http.Request) {
if Config.Debug { if Config.Debug {
log.Println("CheckConfig()") log.Println("CheckConfig()")
} }
cfg := filepath.Join(getRepo(r), "config") cfg, err := getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
st, err := os.Stat(cfg) st, err := os.Stat(cfg)
if err != nil { if err != nil {
if Config.Debug { if Config.Debug {
@ -64,7 +142,11 @@ func GetConfig(w http.ResponseWriter, r *http.Request) {
if Config.Debug { if Config.Debug {
log.Println("GetConfig()") log.Println("GetConfig()")
} }
cfg := filepath.Join(getRepo(r), "config") cfg, err := getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
bytes, err := ioutil.ReadFile(cfg) bytes, err := ioutil.ReadFile(cfg)
if err != nil { if err != nil {
@ -83,7 +165,11 @@ func SaveConfig(w http.ResponseWriter, r *http.Request) {
if Config.Debug { if Config.Debug {
log.Println("SaveConfig()") log.Println("SaveConfig()")
} }
cfg := filepath.Join(getRepo(r), "config") cfg, err := getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
bytes, err := ioutil.ReadAll(r.Body) bytes, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
@ -108,8 +194,13 @@ func DeleteConfig(w http.ResponseWriter, r *http.Request) {
if Config.Debug { if Config.Debug {
log.Println("DeleteConfig()") log.Println("DeleteConfig()")
} }
cfg, err := getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if err := os.Remove(filepath.Join(getRepo(r), "config")); err != nil { if err := os.Remove(cfg); err != nil {
if Config.Debug { if Config.Debug {
log.Print(err) log.Print(err)
} }
@ -127,8 +218,12 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) {
if Config.Debug { if Config.Debug {
log.Println("ListBlobs()") log.Println("ListBlobs()")
} }
dir := pat.Param(r, "type") fileType := pat.Param(r, "type")
path := filepath.Join(getRepo(r), dir) path, err := getPath(r, fileType)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
items, err := ioutil.ReadDir(path) items, err := ioutil.ReadDir(path)
if err != nil { if err != nil {
@ -141,7 +236,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) {
var names []string var names []string
for _, i := range items { for _, i := range items {
if isHashed(dir) { if isHashed(fileType) {
subpath := filepath.Join(path, i.Name()) subpath := filepath.Join(path, i.Name())
subitems, err := ioutil.ReadDir(subpath) subitems, err := ioutil.ReadDir(subpath)
if err != nil { if err != nil {
@ -176,13 +271,12 @@ func CheckBlob(w http.ResponseWriter, r *http.Request) {
if Config.Debug { if Config.Debug {
log.Println("CheckBlob()") log.Println("CheckBlob()")
} }
dir := pat.Param(r, "type")
name := pat.Param(r, "name")
if isHashed(dir) { path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
name = filepath.Join(name[:2], name) if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
} }
path := filepath.Join(getRepo(r), dir, name)
st, err := os.Stat(path) st, err := os.Stat(path)
if err != nil { if err != nil {
@ -201,13 +295,12 @@ func GetBlob(w http.ResponseWriter, r *http.Request) {
if Config.Debug { if Config.Debug {
log.Println("GetBlob()") log.Println("GetBlob()")
} }
dir := pat.Param(r, "type")
name := pat.Param(r, "name")
if isHashed(dir) { path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
name = filepath.Join(name[:2], name) if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
} }
path := filepath.Join(getRepo(r), dir, name)
file, err := os.Open(path) file, err := os.Open(path)
if err != nil { if err != nil {
@ -227,14 +320,12 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) {
if Config.Debug { if Config.Debug {
log.Println("SaveBlob()") log.Println("SaveBlob()")
} }
repo := getRepo(r)
dir := pat.Param(r, "type")
name := pat.Param(r, "name")
if isHashed(dir) { path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
name = filepath.Join(name[:2], name) if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
} }
path := filepath.Join(repo, dir, name)
tf, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) tf, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600)
if err != nil { if err != nil {
@ -280,13 +371,12 @@ func DeleteBlob(w http.ResponseWriter, r *http.Request) {
if Config.Debug { if Config.Debug {
log.Println("DeleteBlob()") log.Println("DeleteBlob()")
} }
dir := pat.Param(r, "type")
name := pat.Param(r, "name")
if isHashed(dir) { path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
name = filepath.Join(name[:2], name) if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
} }
path := filepath.Join(getRepo(r), dir, name)
if err := os.Remove(path); err != nil { if err := os.Remove(path); err != nil {
if Config.Debug { if Config.Debug {
@ -306,7 +396,12 @@ func CreateRepo(w http.ResponseWriter, r *http.Request) {
if Config.Debug { if Config.Debug {
log.Println("CreateRepo()") log.Println("CreateRepo()")
} }
repo := getRepo(r)
repo, err := join(Config.Path, getRepo(r))
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if r.URL.Query().Get("create") != "true" { if r.URL.Query().Get("create") != "true" {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
@ -321,7 +416,11 @@ func CreateRepo(w http.ResponseWriter, r *http.Request) {
return return
} }
for _, d := range []string{"data", "index", "keys", "locks", "snapshots", "tmp"} { for _, d := range validTypes {
if d == "config" {
continue
}
if err := os.MkdirAll(filepath.Join(repo, d), 0700); err != nil { if err := os.MkdirAll(filepath.Join(repo, d), 0700); err != nil {
log.Print(err) log.Print(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

39
handlers_test.go Normal file
View file

@ -0,0 +1,39 @@
package restserver
import (
"path/filepath"
"testing"
)
func TestJoin(t *testing.T) {
var tests = []struct {
base, name string
result string
}{
{"/", "foo/bar", "/foo/bar"},
{"/srv/server", "foo/bar", "/srv/server/foo/bar"},
{"/srv/server", "/foo/bar", "/srv/server/foo/bar"},
{"/srv/server", "foo/../bar", "/srv/server/bar"},
{"/srv/server", "../bar", "/srv/server/bar"},
{"/srv/server", "..", "/srv/server"},
{"/srv/server", "../..", "/srv/server"},
{"/srv/server", "/repo/data/", "/srv/server/repo/data"},
{"/srv/server", "/repo/data/../..", "/srv/server"},
{"/srv/server", "/repo/data/../data/../../..", "/srv/server"},
{"/srv/server", "/repo/data/../data/../../..", "/srv/server"},
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
got, err := join(filepath.FromSlash(test.base), test.name)
if err != nil {
t.Fatal(err)
}
want := filepath.FromSlash(test.result)
if got != want {
t.Fatalf("wrong result returned, want %v, got %v", want, got)
}
})
}
}