Verify uploaded files

Restic uses the sha256 hash to calculate filenames based on the file
content. Check on the rest-server side that the uploaded file is intact
and reject it otherwise.
This commit is contained in:
Michael Eischer 2021-08-09 15:35:13 +02:00 committed by Alexander Neumann
parent 96a6f0a5c4
commit 54adcb1fc7
3 changed files with 66 additions and 30 deletions

View file

@ -0,0 +1,8 @@
Enhancement: Verify uploaded files
rest-server now verifies that the hash of content of uploaded files matches
their filename. This ensures that transmission errors are detected and forces
restic to retry the upload.
https://github.com/restic/rest-server/issues/122
https://github.com/restic/rest-server/pull/130

View file

@ -3,6 +3,7 @@ package restserver
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
@ -103,47 +104,58 @@ type TestRequest struct {
// createOverwriteDeleteSeq returns a sequence which will create a new file at // createOverwriteDeleteSeq returns a sequence which will create a new file at
// path, and then try to overwrite and delete it. // path, and then try to overwrite and delete it.
func createOverwriteDeleteSeq(t testing.TB, path string) []TestRequest { func createOverwriteDeleteSeq(t testing.TB, path string, data string) []TestRequest {
// add a file, try to overwrite and delete it // add a file, try to overwrite and delete it
req := []TestRequest{ req := []TestRequest{
{ {
req: newRequest(t, "GET", path, nil), req: newRequest(t, "GET", path, nil),
want: []wantFunc{wantCode(http.StatusNotFound)}, want: []wantFunc{wantCode(http.StatusNotFound)},
}, },
{ }
req: newRequest(t, "POST", path, strings.NewReader("foobar test config")),
if !strings.HasSuffix(path, "/config") {
req = append(req, TestRequest{
// broken upload must fail
req: newRequest(t, "POST", path, strings.NewReader(data+"broken")),
want: []wantFunc{wantCode(http.StatusBadRequest)},
})
}
req = append(req,
TestRequest{
req: newRequest(t, "POST", path, strings.NewReader(data)),
want: []wantFunc{wantCode(http.StatusOK)}, want: []wantFunc{wantCode(http.StatusOK)},
}, },
{ TestRequest{
req: newRequest(t, "GET", path, nil), req: newRequest(t, "GET", path, nil),
want: []wantFunc{ want: []wantFunc{
wantCode(http.StatusOK), wantCode(http.StatusOK),
wantBody("foobar test config"), wantBody(data),
}, },
}, },
{ TestRequest{
req: newRequest(t, "POST", path, strings.NewReader("other config")), req: newRequest(t, "POST", path, strings.NewReader(data+"other stuff")),
want: []wantFunc{wantCode(http.StatusForbidden)}, want: []wantFunc{wantCode(http.StatusForbidden)},
}, },
{ TestRequest{
req: newRequest(t, "GET", path, nil), req: newRequest(t, "GET", path, nil),
want: []wantFunc{ want: []wantFunc{
wantCode(http.StatusOK), wantCode(http.StatusOK),
wantBody("foobar test config"), wantBody(data),
}, },
}, },
{ TestRequest{
req: newRequest(t, "DELETE", path, nil), req: newRequest(t, "DELETE", path, nil),
want: []wantFunc{wantCode(http.StatusForbidden)}, want: []wantFunc{wantCode(http.StatusForbidden)},
}, },
{ TestRequest{
req: newRequest(t, "GET", path, nil), req: newRequest(t, "GET", path, nil),
want: []wantFunc{ want: []wantFunc{
wantCode(http.StatusOK), wantCode(http.StatusOK),
wantBody("foobar test config"), wantBody(data),
}, },
}, },
} )
return req return req
} }
@ -154,53 +166,59 @@ func TestResticHandler(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
randomID := hex.EncodeToString(buf) data := "random data file " + hex.EncodeToString(buf)
dataHash := sha256.Sum256([]byte(data))
fileID := hex.EncodeToString(dataHash[:])
var tests = []struct { var tests = []struct {
seq []TestRequest seq []TestRequest
}{ }{
{createOverwriteDeleteSeq(t, "/config")}, {createOverwriteDeleteSeq(t, "/config", data)},
{createOverwriteDeleteSeq(t, "/data/"+randomID)}, {createOverwriteDeleteSeq(t, "/data/"+fileID, data)},
{ {
// ensure we can add and remove lock files // ensure we can add and remove lock files
[]TestRequest{ []TestRequest{
{ {
req: newRequest(t, "GET", "/locks/"+randomID, nil), req: newRequest(t, "GET", "/locks/"+fileID, nil),
want: []wantFunc{wantCode(http.StatusNotFound)}, want: []wantFunc{wantCode(http.StatusNotFound)},
}, },
{ {
req: newRequest(t, "POST", "/locks/"+randomID, strings.NewReader("lock file")), req: newRequest(t, "POST", "/locks/"+fileID, strings.NewReader(data+"broken")),
want: []wantFunc{wantCode(http.StatusBadRequest)},
},
{
req: newRequest(t, "POST", "/locks/"+fileID, strings.NewReader(data)),
want: []wantFunc{wantCode(http.StatusOK)}, want: []wantFunc{wantCode(http.StatusOK)},
}, },
{ {
req: newRequest(t, "GET", "/locks/"+randomID, nil), req: newRequest(t, "GET", "/locks/"+fileID, nil),
want: []wantFunc{ want: []wantFunc{
wantCode(http.StatusOK), wantCode(http.StatusOK),
wantBody("lock file"), wantBody(data),
}, },
}, },
{ {
req: newRequest(t, "POST", "/locks/"+randomID, strings.NewReader("other lock file")), req: newRequest(t, "POST", "/locks/"+fileID, strings.NewReader(data+"other data")),
want: []wantFunc{wantCode(http.StatusForbidden)}, want: []wantFunc{wantCode(http.StatusForbidden)},
}, },
{ {
req: newRequest(t, "DELETE", "/locks/"+randomID, nil), req: newRequest(t, "DELETE", "/locks/"+fileID, nil),
want: []wantFunc{wantCode(http.StatusOK)}, want: []wantFunc{wantCode(http.StatusOK)},
}, },
{ {
req: newRequest(t, "GET", "/locks/"+randomID, nil), req: newRequest(t, "GET", "/locks/"+fileID, nil),
want: []wantFunc{wantCode(http.StatusNotFound)}, want: []wantFunc{wantCode(http.StatusNotFound)},
}, },
}, },
}, },
// Test subrepos // Test subrepos
{createOverwriteDeleteSeq(t, "/parent1/sub1/config")}, {createOverwriteDeleteSeq(t, "/parent1/sub1/config", "foobar")},
{createOverwriteDeleteSeq(t, "/parent1/sub1/data/"+randomID)}, {createOverwriteDeleteSeq(t, "/parent1/sub1/data/"+fileID, data)},
{createOverwriteDeleteSeq(t, "/parent1/config")}, {createOverwriteDeleteSeq(t, "/parent1/config", "foobar")},
{createOverwriteDeleteSeq(t, "/parent1/data/"+randomID)}, {createOverwriteDeleteSeq(t, "/parent1/data/"+fileID, data)},
{createOverwriteDeleteSeq(t, "/parent2/config")}, {createOverwriteDeleteSeq(t, "/parent2/config", "foobar")},
{createOverwriteDeleteSeq(t, "/parent2/data/"+randomID)}, {createOverwriteDeleteSeq(t, "/parent2/data/"+fileID, data)},
} }
// setup rclone with a local backend in a temporary directory // setup rclone with a local backend in a temporary directory

View file

@ -1,6 +1,8 @@
package repo package repo
import ( import (
"crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -569,7 +571,15 @@ func (h *Handler) saveBlob(w http.ResponseWriter, r *http.Request) {
return return
} }
written, err := io.Copy(outFile, r.Body) // calculate hash for current request
hasher := sha256.New()
written, err := io.Copy(outFile, io.TeeReader(r.Body, hasher))
// reject if file content doesn't match file name
if err == nil && hex.EncodeToString(hasher.Sum(nil)) != objectID {
err = fmt.Errorf("file content does not match hash")
}
if err != nil { if err != nil {
_ = tf.Close() _ = tf.Close()
_ = os.Remove(path) _ = os.Remove(path)