mirror of
https://github.com/restic/rest-server.git
synced 2025-10-19 07:33:21 +00:00
Fix directory traversal
This commit introduces the strict checks from net/http.Dir, which fixes a directory traversal issue. Closes #22
This commit is contained in:
parent
9a6bb5eebe
commit
a628c4e01a
2 changed files with 172 additions and 34 deletions
167
handlers.go
167
handlers.go
|
@ -2,12 +2,14 @@ package restserver
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -20,12 +22,83 @@ func isHashed(dir string) bool {
|
|||
return dir == "data"
|
||||
}
|
||||
|
||||
func getRepo(r *http.Request) string {
|
||||
if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") {
|
||||
return filepath.Join(Config.Path, pat.Param(r, "repo"))
|
||||
func valid(name string) bool {
|
||||
// taken from net/http.Dir
|
||||
if strings.Contains(name, "\x00") {
|
||||
return false
|
||||
}
|
||||
|
||||
return Config.Path
|
||||
if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
var validTypes = []string{"data", "index", "keys", "locks", "snapshots", "config"}
|
||||
|
||||
func isValidType(name string) bool {
|
||||
for _, tpe := range validTypes {
|
||||
if name == tpe {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// join takes a number of path names, sanitizes them, and returns them joined
|
||||
// with base for the current operating system to use (dirs separated by
|
||||
// filepath.Separator). The returned path is always either equal to base or a
|
||||
// subdir of base.
|
||||
func join(base string, names ...string) (string, error) {
|
||||
clean := make([]string, 0, len(names)+1)
|
||||
clean = append(clean, base)
|
||||
|
||||
// taken from net/http.Dir
|
||||
for _, name := range names {
|
||||
if !valid(name) {
|
||||
return "", errors.New("invalid character in path")
|
||||
}
|
||||
|
||||
clean = append(clean, filepath.FromSlash(path.Clean("/"+name)))
|
||||
}
|
||||
|
||||
return filepath.Join(clean...), nil
|
||||
}
|
||||
|
||||
// getRepo returns the repository location, relative to Config.Path.
|
||||
func getRepo(r *http.Request) string {
|
||||
if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") {
|
||||
return pat.Param(r, "repo")
|
||||
}
|
||||
|
||||
return "."
|
||||
}
|
||||
|
||||
// getPath returns the path for a file type in the repo.
|
||||
func getPath(r *http.Request, fileType string) (string, error) {
|
||||
if !isValidType(fileType) {
|
||||
return "", errors.New("invalid file type")
|
||||
}
|
||||
return join(Config.Path, getRepo(r), fileType)
|
||||
}
|
||||
|
||||
// getFilePath returns the path for a file in the repo.
|
||||
func getFilePath(r *http.Request, fileType, name string) (string, error) {
|
||||
if !isValidType(fileType) {
|
||||
return "", errors.New("invalid file type")
|
||||
}
|
||||
|
||||
if isHashed(fileType) {
|
||||
if len(name) < 2 {
|
||||
return "", errors.New("file name is too short")
|
||||
}
|
||||
|
||||
return join(Config.Path, getRepo(r), fileType, name[:2], name)
|
||||
}
|
||||
|
||||
return join(Config.Path, getRepo(r), fileType, name)
|
||||
}
|
||||
|
||||
// AuthHandler wraps h with a http.HandlerFunc that performs basic authentication against the user/passwords pairs
|
||||
|
@ -46,7 +119,12 @@ func CheckConfig(w http.ResponseWriter, r *http.Request) {
|
|||
if Config.Debug {
|
||||
log.Println("CheckConfig()")
|
||||
}
|
||||
cfg := filepath.Join(getRepo(r), "config")
|
||||
cfg, err := getPath(r, "config")
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
st, err := os.Stat(cfg)
|
||||
if err != nil {
|
||||
if Config.Debug {
|
||||
|
@ -64,7 +142,11 @@ func GetConfig(w http.ResponseWriter, r *http.Request) {
|
|||
if Config.Debug {
|
||||
log.Println("GetConfig()")
|
||||
}
|
||||
cfg := filepath.Join(getRepo(r), "config")
|
||||
cfg, err := getPath(r, "config")
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
bytes, err := ioutil.ReadFile(cfg)
|
||||
if err != nil {
|
||||
|
@ -83,7 +165,11 @@ func SaveConfig(w http.ResponseWriter, r *http.Request) {
|
|||
if Config.Debug {
|
||||
log.Println("SaveConfig()")
|
||||
}
|
||||
cfg := filepath.Join(getRepo(r), "config")
|
||||
cfg, err := getPath(r, "config")
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
bytes, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
|
@ -108,8 +194,13 @@ func DeleteConfig(w http.ResponseWriter, r *http.Request) {
|
|||
if Config.Debug {
|
||||
log.Println("DeleteConfig()")
|
||||
}
|
||||
cfg, err := getPath(r, "config")
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.Remove(filepath.Join(getRepo(r), "config")); err != nil {
|
||||
if err := os.Remove(cfg); err != nil {
|
||||
if Config.Debug {
|
||||
log.Print(err)
|
||||
}
|
||||
|
@ -127,8 +218,12 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) {
|
|||
if Config.Debug {
|
||||
log.Println("ListBlobs()")
|
||||
}
|
||||
dir := pat.Param(r, "type")
|
||||
path := filepath.Join(getRepo(r), dir)
|
||||
fileType := pat.Param(r, "type")
|
||||
path, err := getPath(r, fileType)
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
items, err := ioutil.ReadDir(path)
|
||||
if err != nil {
|
||||
|
@ -141,7 +236,7 @@ func ListBlobs(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var names []string
|
||||
for _, i := range items {
|
||||
if isHashed(dir) {
|
||||
if isHashed(fileType) {
|
||||
subpath := filepath.Join(path, i.Name())
|
||||
subitems, err := ioutil.ReadDir(subpath)
|
||||
if err != nil {
|
||||
|
@ -176,13 +271,12 @@ func CheckBlob(w http.ResponseWriter, r *http.Request) {
|
|||
if Config.Debug {
|
||||
log.Println("CheckBlob()")
|
||||
}
|
||||
dir := pat.Param(r, "type")
|
||||
name := pat.Param(r, "name")
|
||||
|
||||
if isHashed(dir) {
|
||||
name = filepath.Join(name[:2], name)
|
||||
path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
path := filepath.Join(getRepo(r), dir, name)
|
||||
|
||||
st, err := os.Stat(path)
|
||||
if err != nil {
|
||||
|
@ -201,13 +295,12 @@ func GetBlob(w http.ResponseWriter, r *http.Request) {
|
|||
if Config.Debug {
|
||||
log.Println("GetBlob()")
|
||||
}
|
||||
dir := pat.Param(r, "type")
|
||||
name := pat.Param(r, "name")
|
||||
|
||||
if isHashed(dir) {
|
||||
name = filepath.Join(name[:2], name)
|
||||
path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
path := filepath.Join(getRepo(r), dir, name)
|
||||
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
|
@ -227,14 +320,12 @@ func SaveBlob(w http.ResponseWriter, r *http.Request) {
|
|||
if Config.Debug {
|
||||
log.Println("SaveBlob()")
|
||||
}
|
||||
repo := getRepo(r)
|
||||
dir := pat.Param(r, "type")
|
||||
name := pat.Param(r, "name")
|
||||
|
||||
if isHashed(dir) {
|
||||
name = filepath.Join(name[:2], name)
|
||||
path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
path := filepath.Join(repo, dir, name)
|
||||
|
||||
tf, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600)
|
||||
if err != nil {
|
||||
|
@ -280,13 +371,12 @@ func DeleteBlob(w http.ResponseWriter, r *http.Request) {
|
|||
if Config.Debug {
|
||||
log.Println("DeleteBlob()")
|
||||
}
|
||||
dir := pat.Param(r, "type")
|
||||
name := pat.Param(r, "name")
|
||||
|
||||
if isHashed(dir) {
|
||||
name = filepath.Join(name[:2], name)
|
||||
path, err := getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
path := filepath.Join(getRepo(r), dir, name)
|
||||
|
||||
if err := os.Remove(path); err != nil {
|
||||
if Config.Debug {
|
||||
|
@ -306,7 +396,12 @@ func CreateRepo(w http.ResponseWriter, r *http.Request) {
|
|||
if Config.Debug {
|
||||
log.Println("CreateRepo()")
|
||||
}
|
||||
repo := getRepo(r)
|
||||
|
||||
repo, err := join(Config.Path, getRepo(r))
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Query().Get("create") != "true" {
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
|
@ -321,7 +416,11 @@ func CreateRepo(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
for _, d := range []string{"data", "index", "keys", "locks", "snapshots", "tmp"} {
|
||||
for _, d := range validTypes {
|
||||
if d == "config" {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(repo, d), 0700); err != nil {
|
||||
log.Print(err)
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue