Merge pull request #112 from wojas/split-repo-handler

Split Server component and add support for subrepositories
This commit is contained in:
Alexander Neumann 2021-08-09 10:55:29 +02:00 committed by GitHub
commit d39bc8e6cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 1166 additions and 805 deletions

View file

@ -60,6 +60,7 @@ To learn how to use restic backup client with REST backend, please consult [rest
--path string data directory (default "/tmp/restic")
--private-repos users can only access their private repo
--prometheus enable Prometheus metrics
--prometheus-no-auth disable auth for Prometheus /metrics endpoint
--tls turn on TLS support
--tls-cert string TLS certificate path
--tls-key string TLS key path
@ -86,7 +87,7 @@ Signed certificate is required by the restic backend, but if you just want to te
The `--append-only` mode allows creation of new backups but prevents deletion and modification of existing backups. This can be useful when backing up systems that have a potential of being hacked.
To prevent your users from accessing each others' repositories, you may use the `--private-repos` flag which grants access only when a subdirectory with the same name as the user is specified in the repository URL. For example, user "foo" using the repository URLs `rest:https://foo:pass@host:8000/foo` or `rest:https://foo:pass@host:8000/foo/` would be granted access, but the same user using repository URLs `rest:https://foo:pass@host:8000/` or `rest:https://foo:pass@host:8000/foobar/` would be denied access.
To prevent your users from accessing each others' repositories, you may use the `--private-repos` flag which grants access only when a subdirectory with the same name as the user is specified in the repository URL. For example, user "foo" using the repository URLs `rest:https://foo:pass@host:8000/foo` or `rest:https://foo:pass@host:8000/foo/` would be granted access, but the same user using repository URLs `rest:https://foo:pass@host:8000/` or `rest:https://foo:pass@host:8000/foobar/` would be denied access. Users can also create their own subrepositories, like `/foo/bar/`.
Rest Server uses exactly the same directory structure as local backend, so you should be able to access it both locally and via HTTP, even simultaneously.
@ -125,7 +126,7 @@ or
## Prometheus support and Grafana dashboard
The server can be started with `--prometheus` to expose [Prometheus](https://prometheus.io/) metrics at `/metrics`.
The server can be started with `--prometheus` to expose [Prometheus](https://prometheus.io/) metrics at `/metrics`. If authentication is enabled, this endpoint requires authentication for the 'metrics' user, but this can be overridden with the `--prometheus-no-auth` flag.
This repository contains an example full stack Docker Compose setup with a Grafana dashboard in [examples/compose-with-grafana/](examples/compose-with-grafana/).
@ -136,6 +137,8 @@ Compared to the SFTP backend, the REST backend has better performance, especiall
But, even if you use HTTPS transport, the REST protocol should be faster and more scalable, due to some inefficiencies of the SFTP protocol (everything needs to be transferred in chunks of 32 KiB at most, each packet needs to be acknowledged by the server).
One important safety feature that Rest Server adds is the optional ability to run in append-only mode. This prevents an attacker from wiping your server backups when access is gained to the server being backed up.
Finally, the Rest Server implementation is really simple and as such could be used on the low-end devices, no problem. Also, in some cases, for example behind corporate firewalls, HTTP/S might be the only protocol allowed. Here too REST backend might be the perfect option for your backup needs.
## Contributors

View file

@ -0,0 +1,22 @@
Change: refactor handlers, add subrepo support
We have split out all HTTP handling to a separate `repo` subpackage to cleanly
separate the server code from the code that handles a single repository.
The refactoring makes the code significantly easier to follow and understand,
which in turn makes it easier to add new features, audit for security and debug
issues.
The new RepoHandler also makes it easier to reuse rest-server as a Go component in
any other HTTP server.
As part of the refactoring, support for multi-level repositories has been added, so
now each user can have its own subrepositories. This feature is always enabled.
Authentication for the Prometheus /metrics endpoint can now be disabled with the
new `--prometheus-no-auth` flag.
https://github.com/restic/restic/pull/112
https://github.com/restic/restic/issues/109
https://github.com/restic/restic/issues/107

View file

@ -50,6 +50,7 @@ func init() {
flags.BoolVar(&server.AppendOnly, "append-only", server.AppendOnly, "enable append only mode")
flags.BoolVar(&server.PrivateRepos, "private-repos", server.PrivateRepos, "users can only access their private repo")
flags.BoolVar(&server.Prometheus, "prometheus", server.Prometheus, "enable Prometheus metrics")
flags.BoolVar(&server.Prometheus, "prometheus-no-auth", server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint")
flags.BoolVarP(&showVersion, "version", "V", showVersion, "output version and exit")
}
@ -75,21 +76,6 @@ func tlsSettings() (bool, string, string, error) {
return server.TLS, key, cert, nil
}
func getHandler(server restserver.Server) (http.Handler, error) {
mux := restserver.NewHandler(server)
if server.NoAuth {
log.Println("Authentication disabled")
return mux, nil
}
log.Println("Authentication enabled")
htpasswdFile, err := restserver.NewHtpasswdFromFile(filepath.Join(server.Path, ".htpasswd"))
if err != nil {
return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err)
}
return server.AuthHandler(htpasswdFile, mux), nil
}
func runRoot(cmd *cobra.Command, args []string) error {
if showVersion {
fmt.Printf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH)
@ -112,7 +98,13 @@ func runRoot(cmd *cobra.Command, args []string) error {
defer pprof.StopCPUProfile()
}
handler, err := getHandler(server)
if server.NoAuth {
log.Println("Authentication disabled")
} else {
log.Println("Authentication enabled")
}
handler, err := restserver.NewHandler(&server)
if err != nil {
log.Fatalf("error: %v", err)
}

View file

@ -86,14 +86,16 @@ func TestGetHandler(t *testing.T) {
}
}()
getHandler := restserver.NewHandler
// With NoAuth = false and no .htpasswd
_, err = getHandler(restserver.Server{Path: dir})
_, err = getHandler(&restserver.Server{Path: dir})
if err == nil {
t.Errorf("NoAuth=false: expected error, got nil")
}
// With NoAuth = true and no .htpasswd
_, err = getHandler(restserver.Server{NoAuth: true, Path: dir})
_, err = getHandler(&restserver.Server{NoAuth: true, Path: dir})
if err != nil {
t.Errorf("NoAuth=true: expected no error, got %v", err)
}
@ -112,7 +114,7 @@ func TestGetHandler(t *testing.T) {
}()
// With NoAuth = false and with .htpasswd
_, err = getHandler(restserver.Server{Path: dir})
_, err = getHandler(&restserver.Server{Path: dir})
if err != nil {
t.Errorf("NoAuth=false with .htpasswd: expected no error, got %v", err)
}

1
go.mod
View file

@ -15,7 +15,6 @@ require (
github.com/prometheus/procfs v0.0.0-20180212145926-282c8707aa21 // indirect
github.com/spf13/cobra v0.0.1
github.com/spf13/pflag v1.0.0 // indirect
goji.io v2.0.2+incompatible
golang.org/x/crypto v0.0.0-20180214000028-650f4a345ab4
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a // indirect
)

2
go.sum
View file

@ -22,8 +22,6 @@ github.com/spf13/cobra v0.0.1 h1:zZh3X5aZbdnoj+4XkaBxKfhO4ot82icYdhhREIAXIj8=
github.com/spf13/cobra v0.0.1/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
github.com/spf13/pflag v1.0.0 h1:oaPbdDe/x0UncahuwiPxW1GYJyilRAdsPnq3e1yaPcI=
github.com/spf13/pflag v1.0.0/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
goji.io v2.0.2+incompatible h1:uIssv/elbKRLznFUy3Xj4+2Mz/qKhek/9aZQDUMae7c=
goji.io v2.0.2+incompatible/go.mod h1:sbqFwrtqZACxLBTQcdgVjFh54yGVCvwq8+w49MVMMIk=
golang.org/x/crypto v0.0.0-20180214000028-650f4a345ab4 h1:OfaUle5HH9Y0obNU74mlOZ/Igdtwi3eGOKcljJsTnbw=
golang.org/x/crypto v0.0.0-20180214000028-650f4a345ab4/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a h1:WXEvlFVvvGxCJLG6REjsT03iWnKLEWinaScsxF2Vm2o=

View file

@ -1,26 +1,18 @@
package restserver
import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/miolini/datacounter"
"github.com/prometheus/client_golang/prometheus"
"goji.io/middleware"
"goji.io/pat"
"github.com/restic/rest-server/quota"
"github.com/restic/rest-server/repo"
)
// Server determines how a Mux's handlers behave.
// Server encapsulates the rest-server's settings and repo management logic
type Server struct {
Path string
Listen string
@ -33,29 +25,89 @@ type Server struct {
AppendOnly bool
PrivateRepos bool
Prometheus bool
PrometheusNoAuth bool
Debug bool
MaxRepoSize int64
PanicOnError bool
repoSize int64 // must be accessed using sync/atomic
htpasswdFile *HtpasswdFile
quotaManager *quota.Manager
}
func (s *Server) isHashed(dir string) bool {
return dir == "data"
// MaxFolderDepth is the maxDepth param passed to splitURLPath.
// A max depth of 2 mean that we accept folders like: '/', '/foo' and '/foo/bar'
// TODO: Move to a Server option
const MaxFolderDepth = 2
// httpDefaultError write a HTTP error with the default description
func httpDefaultError(w http.ResponseWriter, code int) {
http.Error(w, http.StatusText(code), code)
}
// ServeHTTP makes this server an http.Handler. It handlers the administrative
// part of the request (figuring out the filesystem location, performing
// authentication, etc) and then passes it on to repo.Handler for actual
// REST API processing.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// First of all, check auth (will always pass if NoAuth is set)
username, ok := s.checkAuth(r)
if !ok {
httpDefaultError(w, http.StatusUnauthorized)
return
}
// Perform the path parsing to determine the repo folder and remainder for the
// repo handler.
folderPath, remainder := splitURLPath(r.URL.Path, MaxFolderDepth)
if !folderPathValid(folderPath) {
log.Printf("Invalid request path: %s", r.URL.Path)
httpDefaultError(w, http.StatusNotFound)
return
}
// Check if the current user is allowed to access this path
if !s.NoAuth && s.PrivateRepos {
if len(folderPath) == 0 || folderPath[0] != username {
httpDefaultError(w, http.StatusUnauthorized)
return
}
}
// Determine filesystem path for this repo
fsPath, err := join(s.Path, folderPath...)
if err != nil {
// We did not expect an error at this stage, because we just checked the path
log.Printf("Unexpected join error for path %q", r.URL.Path)
httpDefaultError(w, http.StatusNotFound)
return
}
// Pass the request to the repo.Handler
opt := repo.Options{
AppendOnly: s.AppendOnly,
Debug: s.Debug,
QuotaManager: s.quotaManager, // may be nil
PanicOnError: s.PanicOnError,
}
if s.Prometheus {
opt.BlobMetricFunc = makeBlobMetricFunc(username, folderPath)
}
repoHandler, err := repo.New(fsPath, opt)
if err != nil {
log.Printf("repo.New error: %v", err)
httpDefaultError(w, http.StatusInternalServerError)
return
}
r.URL.Path = remainder // strip folderPath for next handler
repoHandler.ServeHTTP(w, r)
}
func valid(name string) bool {
// Based on net/http.Dir
// taken from net/http.Dir
if strings.Contains(name, "\x00") {
return false
}
// Path characters that are disallowed or unsafe under some operating systems
// are not allowed here.
// The most important one here is '/', since Goji does not decode '%2F' to '/'
// during routing, so we can end up with a '/' in the name here.
if strings.ContainsAny(name, "/\\:*?\"<>|") {
return false
}
if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) {
return false
}
@ -63,15 +115,17 @@ func valid(name string) bool {
return true
}
var validTypes = []string{"data", "index", "keys", "locks", "snapshots", "config"}
func (s *Server) isValidType(name string) bool {
for _, tpe := range validTypes {
func isValidType(name string) bool {
for _, tpe := range repo.ObjectTypes {
if name == tpe {
return true
}
}
for _, tpe := range repo.FileTypes {
if name == tpe {
return true
}
}
return false
}
@ -95,600 +149,44 @@ func join(base string, names ...string) (string, error) {
return filepath.Join(clean...), nil
}
// getRepo returns the repository location, relative to s.Path.
func (s *Server) getRepo(r *http.Request) string {
if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") {
return pat.Param(r, "repo")
// splitURLPath splits the URL path into a folderPath of the subrepo, and
// a remainder that can be passed to repo.Handler.
// Example: /foo/bar/locks/0123... will be split into:
// ["foo", "bar"] and "/locks/0123..."
func splitURLPath(urlPath string, maxDepth int) (folderPath []string, remainder string) {
if !strings.HasPrefix(urlPath, "/") {
// Really should start with "/"
return nil, urlPath
}
return "."
p := strings.SplitN(urlPath, "/", maxDepth+2)
// Skip the empty first one and the remainder in the last one
for _, name := range p[1 : len(p)-1] {
if isValidType(name) {
// We found a part that is a special repo file or dir
break
}
folderPath = append(folderPath, name)
}
// If the folder path is empty, the whole path is the remainder (do not strip '/')
if len(folderPath) == 0 {
return nil, urlPath
}
// Check that the urlPath starts with the reconstructed path, which should
// always be the case.
fullFolderPath := "/" + strings.Join(folderPath, "/")
if !strings.HasPrefix(urlPath, fullFolderPath) {
return nil, urlPath
}
return folderPath, urlPath[len(fullFolderPath):]
}
// getPath returns the path for a file type in the repo.
func (s *Server) getPath(r *http.Request, fileType string) (string, error) {
if !s.isValidType(fileType) {
return "", errors.New("invalid file type")
}
return join(s.Path, s.getRepo(r), fileType)
}
// getFilePath returns the path for a file in the repo.
func (s *Server) getFilePath(r *http.Request, fileType, name string) (string, error) {
if !s.isValidType(fileType) {
return "", errors.New("invalid file type")
}
if s.isHashed(fileType) {
if len(name) < 2 {
return "", errors.New("file name is too short")
}
return join(s.Path, s.getRepo(r), fileType, name[:2], name)
}
return join(s.Path, s.getRepo(r), fileType, name)
}
// getUser returns the username from the request, or an empty string if none.
func (s *Server) getUser(r *http.Request) string {
username, _, ok := r.BasicAuth()
if !ok {
return ""
}
return username
}
// getMetricLabels returns the prometheus labels from the request.
func (s *Server) getMetricLabels(r *http.Request) prometheus.Labels {
labels := prometheus.Labels{
"user": s.getUser(r),
"repo": s.getRepo(r),
"type": pat.Param(r, "type"),
}
return labels
}
// isUserPath checks if a request path is accessible by the user when using
// private repositories.
func isUserPath(username, path string) bool {
prefix := "/" + username
if !strings.HasPrefix(path, prefix) {
// folderPathValid checks if a folderPath returned by splitURLPath is valid and
// safe.
func folderPathValid(folderPath []string) bool {
for _, name := range folderPath {
if name == "" || name == ".." || name == "." || !valid(name) {
return false
}
return len(path) == len(prefix) || path[len(prefix)] == '/'
}
// AuthHandler wraps h with a http.HandlerFunc that performs basic authentication against the user/passwords pairs
// stored in f and returns the http.HandlerFunc.
func (s *Server) AuthHandler(f *HtpasswdFile, h http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok || !f.Validate(username, password) {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
// resolve all relative elements in the path
urlPath := path.Clean(r.URL.Path)
if s.PrivateRepos && !isUserPath(username, urlPath) && urlPath != "/metrics" {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
h.ServeHTTP(w, r)
}
}
// CheckConfig checks whether a configuration exists.
func (s *Server) CheckConfig(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("CheckConfig()")
}
cfg, err := s.getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
st, err := os.Stat(cfg)
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}
w.Header().Add("Content-Length", fmt.Sprint(st.Size()))
}
// GetConfig allows for a config to be retrieved.
func (s *Server) GetConfig(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("GetConfig()")
}
cfg, err := s.getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
bytes, err := ioutil.ReadFile(cfg)
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}
_, _ = w.Write(bytes)
}
// SaveConfig allows for a config to be saved.
func (s *Server) SaveConfig(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("SaveConfig()")
}
cfg, err := s.getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
f, err := os.OpenFile(cfg, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600)
if err != nil && os.IsExist(err) {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
_, err = io.Copy(f, r.Body)
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
err = f.Close()
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
_ = r.Body.Close()
}
// DeleteConfig removes a config.
func (s *Server) DeleteConfig(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("DeleteConfig()")
}
if s.AppendOnly {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
cfg, err := s.getPath(r, "config")
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if err := os.Remove(cfg); err != nil {
if s.Debug {
log.Print(err)
}
if os.IsNotExist(err) {
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
} else {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
return
}
}
const (
mimeTypeAPIV1 = "application/vnd.x.restic.rest.v1"
mimeTypeAPIV2 = "application/vnd.x.restic.rest.v2"
)
// ListBlobs lists all blobs of a given type in an arbitrary order.
func (s *Server) ListBlobs(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("ListBlobs()")
}
switch r.Header.Get("Accept") {
case mimeTypeAPIV2:
s.ListBlobsV2(w, r)
default:
s.ListBlobsV1(w, r)
}
}
// ListBlobsV1 lists all blobs of a given type in an arbitrary order.
func (s *Server) ListBlobsV1(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("ListBlobsV1()")
}
fileType := pat.Param(r, "type")
path, err := s.getPath(r, fileType)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
items, err := ioutil.ReadDir(path)
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}
var names []string
for _, i := range items {
if s.isHashed(fileType) {
subpath := filepath.Join(path, i.Name())
var subitems []os.FileInfo
subitems, err = ioutil.ReadDir(subpath)
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}
for _, f := range subitems {
names = append(names, f.Name())
}
} else {
names = append(names, i.Name())
}
}
data, err := json.Marshal(names)
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", mimeTypeAPIV1)
_, _ = w.Write(data)
}
// Blob represents a single blob, its name and its size.
type Blob struct {
Name string `json:"name"`
Size int64 `json:"size"`
}
// ListBlobsV2 lists all blobs of a given type, together with their sizes, in an arbitrary order.
func (s *Server) ListBlobsV2(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("ListBlobsV2()")
}
fileType := pat.Param(r, "type")
path, err := s.getPath(r, fileType)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
items, err := ioutil.ReadDir(path)
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}
var blobs []Blob
for _, i := range items {
if s.isHashed(fileType) {
subpath := filepath.Join(path, i.Name())
var subitems []os.FileInfo
subitems, err = ioutil.ReadDir(subpath)
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}
for _, f := range subitems {
blobs = append(blobs, Blob{Name: f.Name(), Size: f.Size()})
}
} else {
blobs = append(blobs, Blob{Name: i.Name(), Size: i.Size()})
}
}
data, err := json.Marshal(blobs)
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", mimeTypeAPIV2)
_, _ = w.Write(data)
}
// CheckBlob tests whether a blob exists.
func (s *Server) CheckBlob(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("CheckBlob()")
}
path, err := s.getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
st, err := os.Stat(path)
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}
w.Header().Add("Content-Length", fmt.Sprint(st.Size()))
}
// GetBlob retrieves a blob from the repository.
func (s *Server) GetBlob(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("GetBlob()")
}
path, err := s.getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
file, err := os.Open(path)
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}
wc := datacounter.NewResponseWriterCounter(w)
http.ServeContent(wc, r, "", time.Unix(0, 0), file)
if err = file.Close(); err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if s.Prometheus {
labels := s.getMetricLabels(r)
metricBlobReadTotal.With(labels).Inc()
metricBlobReadBytesTotal.With(labels).Add(float64(wc.Count()))
}
}
// tallySize counts the size of the contents of path.
func tallySize(path string) (int64, error) {
if path == "" {
path = "."
}
var size int64
err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
size += info.Size()
return nil
})
return size, err
}
// SaveBlob saves a blob to the repository.
func (s *Server) SaveBlob(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("SaveBlob()")
}
path, err := s.getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
tf, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600)
if os.IsNotExist(err) {
// the error is caused by a missing directory, create it and retry
mkdirErr := os.MkdirAll(filepath.Dir(path), 0700)
if mkdirErr != nil {
log.Print(mkdirErr)
} else {
// try again
tf, err = os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600)
}
}
if os.IsExist(err) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
if err != nil {
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
// ensure this blob does not put us over the repo size limit (if there is one)
var outFile io.Writer = tf
if s.MaxRepoSize != 0 {
var errCode int
outFile, errCode, err = s.maxSizeWriter(r, tf)
if err != nil {
if s.Debug {
log.Println(err)
}
if errCode > 0 {
http.Error(w, http.StatusText(errCode), errCode)
}
return
}
}
written, err := io.Copy(outFile, r.Body)
if err != nil {
_ = tf.Close()
_ = os.Remove(path)
if s.MaxRepoSize > 0 {
s.incrementRepoSpaceUsage(-written)
}
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
if err := tf.Sync(); err != nil {
_ = tf.Close()
_ = os.Remove(path)
if s.MaxRepoSize > 0 {
s.incrementRepoSpaceUsage(-written)
}
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if err := tf.Close(); err != nil {
_ = os.Remove(path)
if s.MaxRepoSize > 0 {
s.incrementRepoSpaceUsage(-written)
}
if s.Debug {
log.Print(err)
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if s.Prometheus {
labels := s.getMetricLabels(r)
metricBlobWriteTotal.With(labels).Inc()
metricBlobWriteBytesTotal.With(labels).Add(float64(written))
}
}
// DeleteBlob deletes a blob from the repository.
func (s *Server) DeleteBlob(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("DeleteBlob()")
}
if s.AppendOnly && pat.Param(r, "type") != "locks" {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
path, err := s.getFilePath(r, pat.Param(r, "type"), pat.Param(r, "name"))
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
var size int64
if s.Prometheus || s.MaxRepoSize > 0 {
stat, err := os.Stat(path)
if err == nil {
size = stat.Size()
}
}
if err := os.Remove(path); err != nil {
if s.Debug {
log.Print(err)
}
if os.IsNotExist(err) {
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
} else {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
return
}
if s.MaxRepoSize > 0 {
s.incrementRepoSpaceUsage(-size)
}
if s.Prometheus {
labels := s.getMetricLabels(r)
metricBlobDeleteTotal.With(labels).Inc()
metricBlobDeleteBytesTotal.With(labels).Add(float64(size))
}
}
// CreateRepo creates repository directories.
func (s *Server) CreateRepo(w http.ResponseWriter, r *http.Request) {
if s.Debug {
log.Println("CreateRepo()")
}
repo, err := join(s.Path, s.getRepo(r))
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if r.URL.Query().Get("create") != "true" {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
log.Printf("Creating repository directories in %s\n", repo)
if err := os.MkdirAll(repo, 0700); err != nil {
log.Print(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
for _, d := range validTypes {
if d == "config" {
continue
}
if err := os.MkdirAll(filepath.Join(repo, d), 0700); err != nil {
log.Print(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}
for i := 0; i < 256; i++ {
if err := os.MkdirAll(filepath.Join(repo, "data", fmt.Sprintf("%02x", i)), 0700); err != nil {
log.Print(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}
}
return true
}

View file

@ -4,12 +4,14 @@ import (
"bytes"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
)
@ -46,27 +48,6 @@ func TestJoin(t *testing.T) {
}
}
func TestIsUserPath(t *testing.T) {
var tests = []struct {
username string
path string
result bool
}{
{"foo", "/", false},
{"foo", "/foo", true},
{"foo", "/foo/", true},
{"foo", "/foo/bar", true},
{"foo", "/foobar", false},
}
for _, test := range tests {
result := isUserPath(test.username, test.path)
if result != test.result {
t.Errorf("isUserPath(%q, %q) was incorrect, got: %v, want: %v.", test.username, test.path, result, test.result)
}
}
}
// declare a few helper functions
// wantFunc tests the HTTP response in res and calls t.Error() if something is incorrect.
@ -212,6 +193,14 @@ func TestResticHandler(t *testing.T) {
},
},
},
// Test subrepos
{createOverwriteDeleteSeq(t, "/parent1/sub1/config")},
{createOverwriteDeleteSeq(t, "/parent1/sub1/data/"+randomID)},
{createOverwriteDeleteSeq(t, "/parent1/config")},
{createOverwriteDeleteSeq(t, "/parent1/data/"+randomID)},
{createOverwriteDeleteSeq(t, "/parent2/config")},
{createOverwriteDeleteSeq(t, "/parent2/data/"+randomID)},
}
// setup rclone with a local backend in a temporary directory
@ -229,15 +218,23 @@ func TestResticHandler(t *testing.T) {
}()
// set append-only mode and configure path
mux := NewHandler(Server{
mux, err := NewHandler(&Server{
AppendOnly: true,
Path: tempdir,
NoAuth: true,
Debug: true,
PanicOnError: true,
})
if err != nil {
t.Fatalf("error from NewHandler: %v", err)
}
// create the repo
// create the repos
for _, path := range []string{"/", "/parent1/sub1/", "/parent1/", "/parent2/"} {
checkRequest(t, mux.ServeHTTP,
newRequest(t, "POST", "/?create=true", nil),
newRequest(t, "POST", path+"?create=true", nil),
[]wantFunc{wantCode(http.StatusOK)})
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
@ -248,3 +245,64 @@ func TestResticHandler(t *testing.T) {
})
}
}
func TestSplitURLPath(t *testing.T) {
var tests = []struct {
// Params
urlPath string
maxDepth int
// Expected result
folderPath []string
remainder string
}{
{"/", 0, nil, "/"},
{"/", 2, nil, "/"},
{"/foo/bar/locks/0123", 0, nil, "/foo/bar/locks/0123"},
{"/foo/bar/locks/0123", 1, []string{"foo"}, "/bar/locks/0123"},
{"/foo/bar/locks/0123", 2, []string{"foo", "bar"}, "/locks/0123"},
{"/foo/bar/locks/0123", 3, []string{"foo", "bar"}, "/locks/0123"},
{"/foo/bar/zzz/locks/0123", 2, []string{"foo", "bar"}, "/zzz/locks/0123"},
{"/foo/bar/zzz/locks/0123", 3, []string{"foo", "bar", "zzz"}, "/locks/0123"},
{"/foo/bar/locks/", 2, []string{"foo", "bar"}, "/locks/"},
{"/foo/locks/", 2, []string{"foo"}, "/locks/"},
{"/foo/data/", 2, []string{"foo"}, "/data/"},
{"/foo/index/", 2, []string{"foo"}, "/index/"},
{"/foo/keys/", 2, []string{"foo"}, "/keys/"},
{"/foo/snapshots/", 2, []string{"foo"}, "/snapshots/"},
{"/foo/config", 2, []string{"foo"}, "/config"},
{"/foo/", 2, []string{"foo"}, "/"},
{"/foo/bar/", 2, []string{"foo", "bar"}, "/"},
{"/foo/bar", 2, []string{"foo"}, "/bar"},
{"/locks/", 2, nil, "/locks/"},
// This function only splits, it does not check the path components!
{"/././locks/", 2, []string{".", "."}, "/locks/"},
{"/../../locks/", 2, []string{"..", ".."}, "/locks/"},
{"///locks/", 2, []string{"", ""}, "/locks/"},
{"////locks/", 2, []string{"", ""}, "//locks/"},
// Robustness against broken input
{"/", -42, nil, "/"},
{"foo", 2, nil, "foo"},
{"foo/bar", 2, nil, "foo/bar"},
{"", 2, nil, ""},
}
for i, test := range tests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
folderPath, remainder := splitURLPath(test.urlPath, test.maxDepth)
var fpEqual bool
if len(test.folderPath) == 0 && len(folderPath) == 0 {
fpEqual = true // this check allows for nil vs empty slice
} else {
fpEqual = reflect.DeepEqual(test.folderPath, folderPath)
}
if !fpEqual {
t.Errorf("wrong folderPath: want %v, got %v", test.folderPath, folderPath)
}
if test.remainder != remainder {
t.Errorf("wrong remainder: want %v, got %v", test.remainder, remainder)
}
})
}
}

View file

@ -1,80 +0,0 @@
package restserver
import (
"fmt"
"io"
"net/http"
"strconv"
"sync/atomic"
)
// maxSizeWriter limits the number of bytes written
// to the space that is currently available as given by
// the server's MaxRepoSize. This type is safe for use
// by multiple goroutines sharing the same *Server.
type maxSizeWriter struct {
io.Writer
server *Server
}
func (w maxSizeWriter) Write(p []byte) (n int, err error) {
if int64(len(p)) > w.server.repoSpaceRemaining() {
return 0, fmt.Errorf("repository has reached maximum size (%d bytes)", w.server.MaxRepoSize)
}
n, err = w.Writer.Write(p)
w.server.incrementRepoSpaceUsage(int64(n))
return n, err
}
// maxSizeWriter wraps w in a writer that enforces s.MaxRepoSize.
// If there is an error, a status code and the error are returned.
func (s *Server) maxSizeWriter(req *http.Request, w io.Writer) (io.Writer, int, error) {
// if we haven't yet computed the size of the repo, do so now
currentSize := atomic.LoadInt64(&s.repoSize)
if currentSize == 0 {
initialSize, err := tallySize(s.Path)
if err != nil {
return nil, http.StatusInternalServerError, err
}
atomic.StoreInt64(&s.repoSize, initialSize)
currentSize = initialSize
}
// if content-length is set and is trustworthy, we can save some time
// and issue a polite error if it declares a size that's too big; since
// we expect the vast majority of clients will be honest, so this check
// can only help save time
if contentLenStr := req.Header.Get("Content-Length"); contentLenStr != "" {
contentLen, err := strconv.ParseInt(contentLenStr, 10, 64)
if err != nil {
return nil, http.StatusLengthRequired, err
}
if currentSize+contentLen > s.MaxRepoSize {
err := fmt.Errorf("incoming blob (%d bytes) would exceed maximum size of repository (%d bytes)",
contentLen, s.MaxRepoSize)
return nil, http.StatusRequestEntityTooLarge, err
}
}
// since we can't always trust content-length, we will wrap the writer
// in a custom writer that enforces the size limit during writes
return maxSizeWriter{Writer: w, server: s}, 0, nil
}
// repoSpaceRemaining returns how much space is available in the repo
// according to s.MaxRepoSize. s.repoSize must already be set.
// If there is no limit, -1 is returned.
func (s *Server) repoSpaceRemaining() int64 {
if s.MaxRepoSize == 0 {
return -1
}
maxSize := s.MaxRepoSize
currentSize := atomic.LoadInt64(&s.repoSize)
return maxSize - currentSize
}
// incrementRepoSpaceUsage increments the current repo size (which
// must already be initialized).
func (s *Server) incrementRepoSpaceUsage(by int64) {
atomic.AddInt64(&s.repoSize, by)
}

View file

@ -1,6 +1,11 @@
package restserver
import "github.com/prometheus/client_golang/prometheus"
import (
"strings"
"github.com/prometheus/client_golang/prometheus"
"github.com/restic/rest-server/repo"
)
var metricLabelList = []string{"user", "repo", "type"}
@ -52,6 +57,30 @@ var metricBlobDeleteBytesTotal = prometheus.NewCounterVec(
metricLabelList,
)
// makeBlobMetricFunc creates a metrics callback function that increments the
// Prometheus metrics.
func makeBlobMetricFunc(username string, folderPath []string) repo.BlobMetricFunc {
var f repo.BlobMetricFunc = func(objectType string, operation repo.BlobOperation, nBytes uint64) {
labels := prometheus.Labels{
"user": username,
"repo": strings.Join(folderPath, "/"),
"type": objectType,
}
switch operation {
case repo.BlobRead:
metricBlobReadTotal.With(labels).Inc()
metricBlobReadBytesTotal.With(labels).Add(float64(nBytes))
case repo.BlobWrite:
metricBlobWriteTotal.With(labels).Inc()
metricBlobWriteBytesTotal.With(labels).Add(float64(nBytes))
case repo.BlobDelete:
metricBlobDeleteTotal.With(labels).Inc()
metricBlobDeleteBytesTotal.With(labels).Add(float64(nBytes))
}
}
return f
}
func init() {
// These are always initialized, but only updated if Config.Prometheus is set
prometheus.MustRegister(metricBlobWriteTotal)

108
mux.go
View file

@ -1,15 +1,15 @@
package restserver
import (
"fmt"
"log"
"net/http"
"os"
goji "goji.io"
"path/filepath"
"github.com/gorilla/handlers"
"github.com/prometheus/client_golang/prometheus/promhttp"
"goji.io/pat"
"github.com/restic/rest-server/quota"
)
func (s *Server) debugHandler(next http.Handler) http.Handler {
@ -29,43 +29,71 @@ func (s *Server) logHandler(next http.Handler) http.Handler {
return handlers.CombinedLoggingHandler(accessLog, next)
}
// NewHandler returns the master HTTP multiplexer/router.
func NewHandler(server Server) *goji.Mux {
mux := goji.NewMux()
if server.Debug {
mux.Use(server.debugHandler)
func (s *Server) checkAuth(r *http.Request) (username string, ok bool) {
if s.NoAuth {
return username, true
}
if server.Log != "" {
mux.Use(server.logHandler)
var password string
username, password, ok = r.BasicAuth()
if !ok || !s.htpasswdFile.Validate(username, password) {
return "", false
}
if server.Prometheus {
mux.Handle(pat.Get("/metrics"), promhttp.Handler())
}
mux.HandleFunc(pat.Head("/config"), server.CheckConfig)
mux.HandleFunc(pat.Head("/:repo/config"), server.CheckConfig)
mux.HandleFunc(pat.Get("/config"), server.GetConfig)
mux.HandleFunc(pat.Get("/:repo/config"), server.GetConfig)
mux.HandleFunc(pat.Post("/config"), server.SaveConfig)
mux.HandleFunc(pat.Post("/:repo/config"), server.SaveConfig)
mux.HandleFunc(pat.Delete("/config"), server.DeleteConfig)
mux.HandleFunc(pat.Delete("/:repo/config"), server.DeleteConfig)
mux.HandleFunc(pat.Get("/:type/"), server.ListBlobs)
mux.HandleFunc(pat.Get("/:repo/:type/"), server.ListBlobs)
mux.HandleFunc(pat.Head("/:type/:name"), server.CheckBlob)
mux.HandleFunc(pat.Head("/:repo/:type/:name"), server.CheckBlob)
mux.HandleFunc(pat.Get("/:type/:name"), server.GetBlob)
mux.HandleFunc(pat.Get("/:repo/:type/:name"), server.GetBlob)
mux.HandleFunc(pat.Post("/:type/:name"), server.SaveBlob)
mux.HandleFunc(pat.Post("/:repo/:type/:name"), server.SaveBlob)
mux.HandleFunc(pat.Delete("/:type/:name"), server.DeleteBlob)
mux.HandleFunc(pat.Delete("/:repo/:type/:name"), server.DeleteBlob)
mux.HandleFunc(pat.Post("/"), server.CreateRepo)
mux.HandleFunc(pat.Post("/:repo"), server.CreateRepo)
mux.HandleFunc(pat.Post("/:repo/"), server.CreateRepo)
return mux
return username, true
}
func (s *Server) wrapMetricsAuth(f http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
username, ok := s.checkAuth(r)
if !ok {
httpDefaultError(w, http.StatusUnauthorized)
return
}
if s.PrivateRepos && username != "metrics" {
httpDefaultError(w, http.StatusUnauthorized)
return
}
f(w, r)
}
}
// NewHandler returns the master HTTP multiplexer/router.
func NewHandler(server *Server) (http.Handler, error) {
if !server.NoAuth {
var err error
server.htpasswdFile, err = NewHtpasswdFromFile(filepath.Join(server.Path, ".htpasswd"))
if err != nil {
return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err)
}
}
const GiB = 1024 * 1024 * 1024
if server.MaxRepoSize > 0 {
log.Printf("Initializing quota (can take a while)...")
qm, err := quota.New(server.Path, server.MaxRepoSize)
if err != nil {
return nil, err
}
server.quotaManager = qm
log.Printf("Quota initialized, currently using %.2f GiB", float64(qm.SpaceUsed())/GiB)
}
mux := http.NewServeMux()
if server.Prometheus {
if server.PrometheusNoAuth {
mux.Handle("/metrics", promhttp.Handler())
} else {
mux.HandleFunc("/metrics", server.wrapMetricsAuth(promhttp.Handler().ServeHTTP))
}
}
mux.Handle("/", server)
var handler http.Handler = mux
if server.Debug {
handler = server.debugHandler(handler)
}
if server.Log != "" {
handler = server.logHandler(handler)
}
return handler, nil
}

124
quota/quota.go Normal file
View file

@ -0,0 +1,124 @@
package quota
import (
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"sync/atomic"
)
// New creates a new quota Manager for given path.
// It will tally the current disk usage before returning.
func New(path string, maxSize int64) (*Manager, error) {
m := &Manager{
path: path,
maxRepoSize: maxSize,
}
if err := m.updateSize(); err != nil {
return nil, err
}
return m, nil
}
// Manager manages the repo quota for given filesystem root path, including subrepos
type Manager struct {
path string
maxRepoSize int64
repoSize int64 // must be accessed using sync/atomic
}
// WrapWriter limits the number of bytes written
// to the space that is currently available as given by
// the server's MaxRepoSize. This type is safe for use
// by multiple goroutines sharing the same *Server.
type maxSizeWriter struct {
io.Writer
m *Manager
}
func (w maxSizeWriter) Write(p []byte) (n int, err error) {
if int64(len(p)) > w.m.SpaceRemaining() {
return 0, fmt.Errorf("repository has reached maximum size (%d bytes)", w.m.maxRepoSize)
}
n, err = w.Writer.Write(p)
w.m.IncUsage(int64(n))
return n, err
}
func (m *Manager) updateSize() error {
// if we haven't yet computed the size of the repo, do so now
initialSize, err := tallySize(m.path)
if err != nil {
return err
}
atomic.StoreInt64(&m.repoSize, initialSize)
return nil
}
// WrapWriter wraps w in a writer that enforces s.MaxRepoSize.
// If there is an error, a status code and the error are returned.
func (m *Manager) WrapWriter(req *http.Request, w io.Writer) (io.Writer, int, error) {
currentSize := atomic.LoadInt64(&m.repoSize)
// if content-length is set and is trustworthy, we can save some time
// and issue a polite error if it declares a size that's too big; since
// we expect the vast majority of clients will be honest, so this check
// can only help save time
if contentLenStr := req.Header.Get("Content-Length"); contentLenStr != "" {
contentLen, err := strconv.ParseInt(contentLenStr, 10, 64)
if err != nil {
return nil, http.StatusLengthRequired, err
}
if currentSize+contentLen > m.maxRepoSize {
err := fmt.Errorf("incoming blob (%d bytes) would exceed maximum size of repository (%d bytes)",
contentLen, m.maxRepoSize)
return nil, http.StatusRequestEntityTooLarge, err
}
}
// since we can't always trust content-length, we will wrap the writer
// in a custom writer that enforces the size limit during writes
return maxSizeWriter{Writer: w, m: m}, 0, nil
}
// SpaceRemaining returns how much space is available in the repo
// according to s.MaxRepoSize. s.repoSize must already be set.
// If there is no limit, -1 is returned.
func (m *Manager) SpaceRemaining() int64 {
if m.maxRepoSize == 0 {
return -1
}
maxSize := m.maxRepoSize
currentSize := atomic.LoadInt64(&m.repoSize)
return maxSize - currentSize
}
// SpaceUsed returns how much space is used in the repo.
func (m *Manager) SpaceUsed() int64 {
return atomic.LoadInt64(&m.repoSize)
}
// IncUsage increments the current repo size (which
// must already be initialized).
func (m *Manager) IncUsage(by int64) {
atomic.AddInt64(&m.repoSize, by)
}
// tallySize counts the size of the contents of path.
func tallySize(path string) (int64, error) {
if path == "" {
path = "."
}
var size int64
err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
size += info.Size()
return nil
})
return size, err
}

688
repo/repo.go Normal file
View file

@ -0,0 +1,688 @@
package repo
import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/miolini/datacounter"
"github.com/restic/rest-server/quota"
)
// Options are options for the Handler accepted by New
type Options struct {
AppendOnly bool // if set, delete actions are not allowed
Debug bool
DirMode os.FileMode
FileMode os.FileMode
// If set, we will panic when an internal server error happens. This
// makes it easier to debug such errors.
PanicOnError bool
BlobMetricFunc BlobMetricFunc
QuotaManager *quota.Manager
}
// DefaultDirMode is the file mode used for directory creation if not
// overridden in the Options
const DefaultDirMode os.FileMode = 0700
// DefaultFileMode is the file mode used for file creation if not
// overridden in the Options
const DefaultFileMode os.FileMode = 0600
// New creates a new Handler for a single Restic backup repo.
// path is the full filesystem path to this repo directory.
// opt is a set of options.
func New(path string, opt Options) (*Handler, error) {
if path == "" {
return nil, fmt.Errorf("path is required")
}
if opt.DirMode == 0 {
opt.DirMode = DefaultDirMode
}
if opt.FileMode == 0 {
opt.FileMode = DefaultFileMode
}
h := Handler{
path: path,
opt: opt,
}
return &h, nil
}
// Handler handles all REST API requests for a single Restic backup repo
// Spec: https://restic.readthedocs.io/en/latest/100_references.html#rest-backend
type Handler struct {
path string // filesystem path of repo
opt Options
}
// httpDefaultError write a HTTP error with the default description
func httpDefaultError(w http.ResponseWriter, code int) {
http.Error(w, http.StatusText(code), code)
}
// httpMethodNotAllowed writes a 405 Method Not Allowed HTTP error with
// the required Allow header listing the methods that are allowed.
func httpMethodNotAllowed(w http.ResponseWriter, allowed []string) {
w.Header().Set("Allow", strings.Join(allowed, ", "))
httpDefaultError(w, http.StatusMethodNotAllowed)
}
// BlobPathRE matches valid blob URI paths with optional object IDs
var BlobPathRE = regexp.MustCompile(`^/(data|index|keys|locks|snapshots)/([0-9a-f]{64})?$`)
// ObjectTypes are subdirs that are used for object storage
var ObjectTypes = []string{"data", "index", "keys", "locks", "snapshots"}
// FileTypes are files stored directly under the repo direct that are accessible
// through a request
var FileTypes = []string{"config"}
func isHashed(objectType string) bool {
return objectType == "data"
}
// BlobOperation describe the current blob operation in the BlobMetricFunc callback.
type BlobOperation byte
// Define all valid operations.
const (
BlobRead = 'R' // A blob has been read
BlobWrite = 'W' // A blob has been written
BlobDelete = 'D' // A blob has been deleted
)
// BlobMetricFunc is the callback signature for blob metrics. Such a callback
// can be passed in the Options to keep track of various metrics.
// objectType: one of ObjectTypes
// operation: one of the BlobOperations above
// nBytes: the number of bytes affected, or 0 if not relevant
// TODO: Perhaps add http.Request for the username so that this can be cached?
type BlobMetricFunc func(objectType string, operation BlobOperation, nBytes uint64)
// ServeHTTP performs strict matching on the repo part of the URL path and
// dispatches the request to the appropriate handler.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
urlPath := r.URL.Path
if urlPath == "/" {
// TODO: add HEAD and GET
switch r.Method {
case "POST":
h.createRepo(w, r)
default:
httpMethodNotAllowed(w, []string{"POST"})
}
return
} else if urlPath == "/config" {
switch r.Method {
case "HEAD":
h.checkConfig(w, r)
case "GET":
h.getConfig(w, r)
case "POST":
h.saveConfig(w, r)
case "DELETE":
h.deleteConfig(w, r)
default:
httpMethodNotAllowed(w, []string{"HEAD", "GET", "POST", "DELETE"})
}
return
} else if objectType, objectID := h.getObject(urlPath); objectType != "" {
if objectID == "" {
// TODO: add HEAD
switch r.Method {
case "GET":
h.listBlobs(w, r)
default:
httpMethodNotAllowed(w, []string{"GET"})
}
return
}
switch r.Method {
case "HEAD":
h.checkBlob(w, r)
case "GET":
h.getBlob(w, r)
case "POST":
h.saveBlob(w, r)
case "DELETE":
h.deleteBlob(w, r)
default:
httpMethodNotAllowed(w, []string{"HEAD", "GET", "POST", "DELETE"})
}
return
}
httpDefaultError(w, http.StatusNotFound)
}
// getObject parses the URL path and returns the objectType and objectID,
// if any. The objectID is optional.
func (h *Handler) getObject(urlPath string) (objectType, objectID string) {
m := BlobPathRE.FindStringSubmatch(urlPath)
if len(m) == 0 {
return "", "" // no match
}
if len(m) == 2 || m[2] == "" {
return m[1], "" // no objectID
}
return m[1], m[2]
}
// getSubPath returns the path for a file or subdir in the root of the repo.
func (h *Handler) getSubPath(name string) string {
return filepath.Join(h.path, name)
}
// getObjectPath returns the path for an object file in the repo.
// The passed in objectType and objectID must be valid due to earlier validation
func (h *Handler) getObjectPath(objectType, objectID string) string {
// If we hit an error, this is a programming error, because all of these
// must have been validated before. We still check them here as a safeguard.
if objectType == "" || objectID == "" {
panic("invalid objectType or objectID")
}
if isHashed(objectType) {
if len(objectID) < 2 {
// Should never happen, because BlobPathRE checked this
panic("getObjectPath: objectID shorter than 2 chars")
}
// Added another dir in between with the first two characters of the hash
return filepath.Join(h.path, objectType, objectID[:2], objectID)
}
return filepath.Join(h.path, objectType, objectID)
}
// sendMetric calls op.BlobMetricFunc if set. See its signature for details.
func (h *Handler) sendMetric(objectType string, operation BlobOperation, nBytes uint64) {
if f := h.opt.BlobMetricFunc; f != nil {
f(objectType, operation, nBytes)
}
}
// needSize tells you if we need the file size for metrics of quota accounting
func (h *Handler) needSize() bool {
return h.opt.BlobMetricFunc != nil || h.opt.QuotaManager != nil
}
// incrementRepoSpaceUsage increments the repo space usage if quota are enabled
func (h *Handler) incrementRepoSpaceUsage(by int64) {
if h.opt.QuotaManager != nil {
h.opt.QuotaManager.IncUsage(by)
}
}
// wrapFileWriter wraps the file writer if repo quota are enabled, and returns it
// as is if not.
// If an error occurs, it returns both an error and the appropriate HTTP error code.
func (h *Handler) wrapFileWriter(r *http.Request, w io.Writer) (io.Writer, int, error) {
if h.opt.QuotaManager == nil {
return w, 0, nil // unmodified
}
return h.opt.QuotaManager.WrapWriter(r, w)
}
// checkConfig checks whether a configuration exists.
func (h *Handler) checkConfig(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("checkConfig()")
}
cfg := h.getSubPath("config")
st, err := os.Stat(cfg)
if err != nil {
if h.opt.Debug {
log.Print(err)
}
httpDefaultError(w, http.StatusNotFound)
return
}
w.Header().Add("Content-Length", fmt.Sprint(st.Size()))
}
// getConfig allows for a config to be retrieved.
func (h *Handler) getConfig(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("getConfig()")
}
cfg := h.getSubPath("config")
bytes, err := ioutil.ReadFile(cfg)
if err != nil {
if h.opt.Debug {
log.Print(err)
}
httpDefaultError(w, http.StatusNotFound)
return
}
_, _ = w.Write(bytes)
}
// saveConfig allows for a config to be saved.
func (h *Handler) saveConfig(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("saveConfig()")
}
cfg := h.getSubPath("config")
f, err := os.OpenFile(cfg, os.O_CREATE|os.O_WRONLY|os.O_EXCL, h.opt.FileMode)
if err != nil && os.IsExist(err) {
if h.opt.Debug {
log.Print(err)
}
httpDefaultError(w, http.StatusForbidden)
return
}
_, err = io.Copy(f, r.Body)
if err != nil {
h.internalServerError(w, err)
return
}
err = f.Close()
if err != nil {
h.internalServerError(w, err)
return
}
_ = r.Body.Close()
}
// deleteConfig removes a config.
func (h *Handler) deleteConfig(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("deleteConfig()")
}
if h.opt.AppendOnly {
httpDefaultError(w, http.StatusForbidden)
return
}
cfg := h.getSubPath("config")
if err := os.Remove(cfg); err != nil {
if h.opt.Debug {
log.Print(err)
}
if os.IsNotExist(err) {
httpDefaultError(w, http.StatusNotFound)
} else {
h.internalServerError(w, err)
}
return
}
}
const (
mimeTypeAPIV1 = "application/vnd.x.restic.rest.v1"
mimeTypeAPIV2 = "application/vnd.x.restic.rest.v2"
)
// listBlobs lists all blobs of a given type in an arbitrary order.
func (h *Handler) listBlobs(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("listBlobs()")
}
switch r.Header.Get("Accept") {
case mimeTypeAPIV2:
h.listBlobsV2(w, r)
default:
h.listBlobsV1(w, r)
}
}
// listBlobsV1 lists all blobs of a given type in an arbitrary order.
// TODO: unify listBlobsV1 and listBlobsV2
func (h *Handler) listBlobsV1(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("listBlobsV1()")
}
objectType, _ := h.getObject(r.URL.Path)
if objectType == "" {
h.internalServerError(w, fmt.Errorf(
"cannot determine object type: %s", r.URL.Path))
return
}
path := h.getSubPath(objectType)
items, err := ioutil.ReadDir(path)
if err != nil {
if h.opt.Debug {
log.Print(err)
}
httpDefaultError(w, http.StatusNotFound)
return
}
var names []string
for _, i := range items {
if isHashed(objectType) {
subpath := filepath.Join(path, i.Name())
var subitems []os.FileInfo
subitems, err = ioutil.ReadDir(subpath)
if err != nil {
if h.opt.Debug {
log.Print(err)
}
httpDefaultError(w, http.StatusNotFound)
return
}
for _, f := range subitems {
names = append(names, f.Name())
}
} else {
names = append(names, i.Name())
}
}
data, err := json.Marshal(names)
if err != nil {
h.internalServerError(w, err)
return
}
w.Header().Set("Content-Type", mimeTypeAPIV1)
_, _ = w.Write(data)
}
// Blob represents a single blob, its name and its size.
type Blob struct {
Name string `json:"name"`
Size int64 `json:"size"`
}
// listBlobsV2 lists all blobs of a given type, together with their sizes, in an arbitrary order.
// TODO: unify listBlobsV1 and listBlobsV2
func (h *Handler) listBlobsV2(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("listBlobsV2()")
}
objectType, _ := h.getObject(r.URL.Path)
if objectType == "" {
h.internalServerError(w, fmt.Errorf(
"cannot determine object type: %s", r.URL.Path))
return
}
path := h.getSubPath(objectType)
items, err := ioutil.ReadDir(path)
if err != nil {
if h.opt.Debug {
log.Print(err)
}
httpDefaultError(w, http.StatusNotFound)
return
}
var blobs []Blob
for _, i := range items {
if isHashed(objectType) {
subpath := filepath.Join(path, i.Name())
var subitems []os.FileInfo
subitems, err = ioutil.ReadDir(subpath)
if err != nil {
if h.opt.Debug {
log.Print(err)
}
httpDefaultError(w, http.StatusNotFound)
return
}
for _, f := range subitems {
blobs = append(blobs, Blob{Name: f.Name(), Size: f.Size()})
}
} else {
blobs = append(blobs, Blob{Name: i.Name(), Size: i.Size()})
}
}
data, err := json.Marshal(blobs)
if err != nil {
h.internalServerError(w, err)
return
}
w.Header().Set("Content-Type", mimeTypeAPIV2)
_, _ = w.Write(data)
}
// checkBlob tests whether a blob exists.
func (h *Handler) checkBlob(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("checkBlob()")
}
objectType, objectID := h.getObject(r.URL.Path)
if objectType == "" || objectID == "" {
h.internalServerError(w, fmt.Errorf(
"cannot determine object type or id: %s", r.URL.Path))
return
}
path := h.getObjectPath(objectType, objectID)
st, err := os.Stat(path)
if err != nil {
if h.opt.Debug {
log.Print(err)
}
httpDefaultError(w, http.StatusNotFound)
return
}
w.Header().Add("Content-Length", fmt.Sprint(st.Size()))
}
// getBlob retrieves a blob from the repository.
func (h *Handler) getBlob(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("getBlob()")
}
objectType, objectID := h.getObject(r.URL.Path)
if objectType == "" || objectID == "" {
h.internalServerError(w, fmt.Errorf(
"cannot determine object type or id: %s", r.URL.Path))
return
}
path := h.getObjectPath(objectType, objectID)
file, err := os.Open(path)
if err != nil {
if h.opt.Debug {
log.Print(err)
}
httpDefaultError(w, http.StatusNotFound)
return
}
wc := datacounter.NewResponseWriterCounter(w)
http.ServeContent(wc, r, "", time.Unix(0, 0), file)
if err = file.Close(); err != nil {
h.internalServerError(w, err)
return
}
h.sendMetric(objectType, BlobRead, wc.Count())
}
// saveBlob saves a blob to the repository.
func (h *Handler) saveBlob(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("saveBlob()")
}
objectType, objectID := h.getObject(r.URL.Path)
if objectType == "" || objectID == "" {
h.internalServerError(w, fmt.Errorf(
"cannot determine object type or id: %s", r.URL.Path))
return
}
path := h.getObjectPath(objectType, objectID)
tf, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, h.opt.FileMode)
if os.IsNotExist(err) {
// the error is caused by a missing directory, create it and retry
mkdirErr := os.MkdirAll(filepath.Dir(path), h.opt.DirMode)
if mkdirErr != nil {
log.Print(mkdirErr)
} else {
// try again
tf, err = os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, h.opt.FileMode)
}
}
if os.IsExist(err) {
httpDefaultError(w, http.StatusForbidden)
return
}
if err != nil {
h.internalServerError(w, err)
return
}
// ensure this blob does not put us over the quota size limit (if there is one)
outFile, errCode, err := h.wrapFileWriter(r, tf)
if err != nil {
if h.opt.Debug {
log.Println(err)
}
httpDefaultError(w, errCode)
return
}
written, err := io.Copy(outFile, r.Body)
if err != nil {
_ = tf.Close()
_ = os.Remove(path)
h.incrementRepoSpaceUsage(-written)
if h.opt.Debug {
log.Print(err)
}
httpDefaultError(w, http.StatusBadRequest)
return
}
if err := tf.Sync(); err != nil {
_ = tf.Close()
_ = os.Remove(path)
h.incrementRepoSpaceUsage(-written)
h.internalServerError(w, err)
return
}
if err := tf.Close(); err != nil {
_ = os.Remove(path)
h.incrementRepoSpaceUsage(-written)
h.internalServerError(w, err)
return
}
h.sendMetric(objectType, BlobWrite, uint64(written))
}
// deleteBlob deletes a blob from the repository.
func (h *Handler) deleteBlob(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("deleteBlob()")
}
objectType, objectID := h.getObject(r.URL.Path)
if objectType == "" || objectID == "" {
h.internalServerError(w, fmt.Errorf(
"cannot determine object type or id: %s", r.URL.Path))
return
}
if h.opt.AppendOnly && objectType != "locks" {
httpDefaultError(w, http.StatusForbidden)
return
}
path := h.getObjectPath(objectType, objectID)
var size int64
if h.needSize() {
stat, err := os.Stat(path)
if err == nil {
size = stat.Size()
}
}
if err := os.Remove(path); err != nil {
if h.opt.Debug {
log.Print(err)
}
if os.IsNotExist(err) {
httpDefaultError(w, http.StatusNotFound)
} else {
h.internalServerError(w, err)
}
return
}
h.incrementRepoSpaceUsage(-size)
h.sendMetric(objectType, BlobDelete, uint64(size))
}
// createRepo creates repository directories.
func (h *Handler) createRepo(w http.ResponseWriter, r *http.Request) {
if h.opt.Debug {
log.Println("createRepo()")
}
if r.URL.Query().Get("create") != "true" {
httpDefaultError(w, http.StatusBadRequest)
return
}
log.Printf("Creating repository directories in %s\n", h.path)
if err := os.MkdirAll(h.path, h.opt.DirMode); err != nil {
h.internalServerError(w, err)
return
}
for _, d := range ObjectTypes {
if err := os.Mkdir(filepath.Join(h.path, d), h.opt.DirMode); err != nil && !os.IsExist(err) {
h.internalServerError(w, err)
return
}
}
for i := 0; i < 256; i++ {
dirPath := filepath.Join(h.path, "data", fmt.Sprintf("%02x", i))
if err := os.Mkdir(dirPath, h.opt.DirMode); err != nil && !os.IsExist(err) {
h.internalServerError(w, err)
return
}
}
}
// internalServerError is called to repot an internal server error.
// The error message will be reported in the server logs. If PanicOnError
// is set, this will panic instead, which makes debugging easier.
func (h *Handler) internalServerError(w http.ResponseWriter, err error) {
log.Printf("ERROR: %v", err)
if h.opt.PanicOnError {
panic(fmt.Sprintf("internal server error: %v", err))
}
httpDefaultError(w, http.StatusInternalServerError)
}