| 
									
										
										
										
											2017-07-30 14:30:18 +02:00
										 |  |  | package restserver | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2018-04-02 12:50:45 +02:00
										 |  |  | 	"bytes" | 
					
						
							|  |  |  | 	"crypto/rand" | 
					
						
							|  |  |  | 	"encoding/hex" | 
					
						
							|  |  |  | 	"io" | 
					
						
							|  |  |  | 	"io/ioutil" | 
					
						
							|  |  |  | 	"net/http" | 
					
						
							|  |  |  | 	"net/http/httptest" | 
					
						
							|  |  |  | 	"os" | 
					
						
							| 
									
										
										
										
											2017-07-30 14:30:18 +02:00
										 |  |  | 	"path/filepath" | 
					
						
							| 
									
										
										
										
											2018-04-02 12:50:45 +02:00
										 |  |  | 	"strings" | 
					
						
							| 
									
										
										
										
											2017-07-30 14:30:18 +02:00
										 |  |  | 	"testing" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestJoin(t *testing.T) { | 
					
						
							|  |  |  | 	var tests = []struct { | 
					
						
							| 
									
										
										
										
											2020-09-13 12:08:46 +02:00
										 |  |  | 		base   string | 
					
						
							|  |  |  | 		names  []string | 
					
						
							|  |  |  | 		result string | 
					
						
							| 
									
										
										
										
											2017-07-30 14:30:18 +02:00
										 |  |  | 	}{ | 
					
						
							| 
									
										
										
										
											2020-09-13 12:08:46 +02:00
										 |  |  | 		{"/", []string{"foo", "bar"}, "/foo/bar"}, | 
					
						
							|  |  |  | 		{"/srv/server", []string{"foo", "bar"}, "/srv/server/foo/bar"}, | 
					
						
							|  |  |  | 		{"/srv/server", []string{"foo", "..", "bar"}, "/srv/server/foo/bar"}, | 
					
						
							|  |  |  | 		{"/srv/server", []string{"..", "bar"}, "/srv/server/bar"}, | 
					
						
							|  |  |  | 		{"/srv/server", []string{".."}, "/srv/server"}, | 
					
						
							|  |  |  | 		{"/srv/server", []string{"..", ".."}, "/srv/server"}, | 
					
						
							|  |  |  | 		{"/srv/server", []string{"repo", "data"}, "/srv/server/repo/data"}, | 
					
						
							|  |  |  | 		{"/srv/server", []string{"repo", "data", "..", ".."}, "/srv/server/repo/data"}, | 
					
						
							|  |  |  | 		{"/srv/server", []string{"repo", "data", "..", "data", "..", "..", ".."}, "/srv/server/repo/data/data"}, | 
					
						
							| 
									
										
										
										
											2017-07-30 14:30:18 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, test := range tests { | 
					
						
							|  |  |  | 		t.Run("", func(t *testing.T) { | 
					
						
							| 
									
										
										
										
											2020-09-13 12:08:46 +02:00
										 |  |  | 			got, err := join(filepath.FromSlash(test.base), test.names...) | 
					
						
							| 
									
										
										
										
											2017-07-30 14:30:18 +02:00
										 |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				t.Fatal(err) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			want := filepath.FromSlash(test.result) | 
					
						
							|  |  |  | 			if got != want { | 
					
						
							|  |  |  | 				t.Fatalf("wrong result returned, want %v, got %v", want, got) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2018-03-20 20:45:59 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | func TestIsUserPath(t *testing.T) { | 
					
						
							|  |  |  | 	var tests = []struct { | 
					
						
							|  |  |  | 		username string | 
					
						
							|  |  |  | 		path     string | 
					
						
							|  |  |  | 		result   bool | 
					
						
							|  |  |  | 	}{ | 
					
						
							|  |  |  | 		{"foo", "/", false}, | 
					
						
							|  |  |  | 		{"foo", "/foo", true}, | 
					
						
							|  |  |  | 		{"foo", "/foo/", true}, | 
					
						
							|  |  |  | 		{"foo", "/foo/bar", true}, | 
					
						
							|  |  |  | 		{"foo", "/foobar", false}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, test := range tests { | 
					
						
							|  |  |  | 		result := isUserPath(test.username, test.path) | 
					
						
							|  |  |  | 		if result != test.result { | 
					
						
							|  |  |  | 			t.Errorf("isUserPath(%q, %q) was incorrect, got: %v, want: %v.", test.username, test.path, result, test.result) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2018-04-02 12:50:45 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | // declare a few helper functions | 
					
						
							|  |  |  | 
 | 
					
						
							// wantFunc is a single response check: it inspects the recorded HTTP
// response rec and reports anything incorrect through tb (e.g. tb.Errorf).
type wantFunc func(tb testing.TB, rec *httptest.ResponseRecorder)
					
						
							|  |  |  | 
 | 
					
						
							// newRequest returns a new HTTP request with the given method, path and body.
// If the request cannot be constructed, the test is aborted with t.Fatal.
// It is marked as a test helper so such a failure is reported at the caller's
// line rather than inside this function.
func newRequest(t testing.TB, method, path string, body io.Reader) *http.Request {
	t.Helper()

	req, err := http.NewRequest(method, path, body)
	if err != nil {
		t.Fatal(err)
	}
	return req
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // wantCode returns a function which checks that the response has the correct HTTP status code. | 
					
						
							|  |  |  | func wantCode(code int) wantFunc { | 
					
						
							|  |  |  | 	return func(t testing.TB, res *httptest.ResponseRecorder) { | 
					
						
							|  |  |  | 		if res.Code != code { | 
					
						
							|  |  |  | 			t.Errorf("wrong response code, want %v, got %v", code, res.Code) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // wantBody returns a function which checks that the response has the data in the body. | 
					
						
							|  |  |  | func wantBody(body string) wantFunc { | 
					
						
							|  |  |  | 	return func(t testing.TB, res *httptest.ResponseRecorder) { | 
					
						
							|  |  |  | 		if res.Body == nil { | 
					
						
							|  |  |  | 			t.Errorf("body is nil, want %q", body) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if !bytes.Equal(res.Body.Bytes(), []byte(body)) { | 
					
						
							|  |  |  | 			t.Errorf("wrong response body, want:\n  %q\ngot:\n  %q", body, res.Body.Bytes()) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // checkRequest uses f to process the request and runs the checker functions on the result. | 
					
						
							|  |  |  | func checkRequest(t testing.TB, f http.HandlerFunc, req *http.Request, want []wantFunc) { | 
					
						
							|  |  |  | 	rr := httptest.NewRecorder() | 
					
						
							|  |  |  | 	f(rr, req) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, fn := range want { | 
					
						
							|  |  |  | 		fn(t, rr) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // TestRequest is a sequence of HTTP requests with (optional) tests for the response. | 
					
						
							|  |  |  | type TestRequest struct { | 
					
						
							|  |  |  | 	req  *http.Request | 
					
						
							|  |  |  | 	want []wantFunc | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // createOverwriteDeleteSeq returns a sequence which will create a new file at | 
					
						
							|  |  |  | // path, and then try to overwrite and delete it. | 
					
						
							|  |  |  | func createOverwriteDeleteSeq(t testing.TB, path string) []TestRequest { | 
					
						
							|  |  |  | 	// add a file, try to overwrite and delete it | 
					
						
							|  |  |  | 	req := []TestRequest{ | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			req:  newRequest(t, "GET", path, nil), | 
					
						
							|  |  |  | 			want: []wantFunc{wantCode(http.StatusNotFound)}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			req:  newRequest(t, "POST", path, strings.NewReader("foobar test config")), | 
					
						
							|  |  |  | 			want: []wantFunc{wantCode(http.StatusOK)}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			req: newRequest(t, "GET", path, nil), | 
					
						
							|  |  |  | 			want: []wantFunc{ | 
					
						
							|  |  |  | 				wantCode(http.StatusOK), | 
					
						
							|  |  |  | 				wantBody("foobar test config"), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			req:  newRequest(t, "POST", path, strings.NewReader("other config")), | 
					
						
							|  |  |  | 			want: []wantFunc{wantCode(http.StatusForbidden)}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			req: newRequest(t, "GET", path, nil), | 
					
						
							|  |  |  | 			want: []wantFunc{ | 
					
						
							|  |  |  | 				wantCode(http.StatusOK), | 
					
						
							|  |  |  | 				wantBody("foobar test config"), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			req:  newRequest(t, "DELETE", path, nil), | 
					
						
							|  |  |  | 			want: []wantFunc{wantCode(http.StatusForbidden)}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			req: newRequest(t, "GET", path, nil), | 
					
						
							|  |  |  | 			want: []wantFunc{ | 
					
						
							|  |  |  | 				wantCode(http.StatusOK), | 
					
						
							|  |  |  | 				wantBody("foobar test config"), | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return req | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // TestResticHandler runs tests on the restic handler code, especially in append-only mode. | 
					
						
							|  |  |  | func TestResticHandler(t *testing.T) { | 
					
						
							|  |  |  | 	buf := make([]byte, 32) | 
					
						
							|  |  |  | 	_, err := io.ReadFull(rand.Reader, buf) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatal(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	randomID := hex.EncodeToString(buf) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var tests = []struct { | 
					
						
							|  |  |  | 		seq []TestRequest | 
					
						
							|  |  |  | 	}{ | 
					
						
							|  |  |  | 		{createOverwriteDeleteSeq(t, "/config")}, | 
					
						
							|  |  |  | 		{createOverwriteDeleteSeq(t, "/data/"+randomID)}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			// ensure we can add and remove lock files | 
					
						
							|  |  |  | 			[]TestRequest{ | 
					
						
							|  |  |  | 				{ | 
					
						
							|  |  |  | 					req:  newRequest(t, "GET", "/locks/"+randomID, nil), | 
					
						
							|  |  |  | 					want: []wantFunc{wantCode(http.StatusNotFound)}, | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 				{ | 
					
						
							|  |  |  | 					req:  newRequest(t, "POST", "/locks/"+randomID, strings.NewReader("lock file")), | 
					
						
							|  |  |  | 					want: []wantFunc{wantCode(http.StatusOK)}, | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 				{ | 
					
						
							|  |  |  | 					req: newRequest(t, "GET", "/locks/"+randomID, nil), | 
					
						
							|  |  |  | 					want: []wantFunc{ | 
					
						
							|  |  |  | 						wantCode(http.StatusOK), | 
					
						
							|  |  |  | 						wantBody("lock file"), | 
					
						
							|  |  |  | 					}, | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 				{ | 
					
						
							|  |  |  | 					req:  newRequest(t, "POST", "/locks/"+randomID, strings.NewReader("other lock file")), | 
					
						
							|  |  |  | 					want: []wantFunc{wantCode(http.StatusForbidden)}, | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 				{ | 
					
						
							|  |  |  | 					req:  newRequest(t, "DELETE", "/locks/"+randomID, nil), | 
					
						
							|  |  |  | 					want: []wantFunc{wantCode(http.StatusOK)}, | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 				{ | 
					
						
							|  |  |  | 					req:  newRequest(t, "GET", "/locks/"+randomID, nil), | 
					
						
							|  |  |  | 					want: []wantFunc{wantCode(http.StatusNotFound)}, | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// setup rclone with a local backend in a temporary directory | 
					
						
							|  |  |  | 	tempdir, err := ioutil.TempDir("", "rclone-restic-test-") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatal(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// make sure the tempdir is properly removed | 
					
						
							|  |  |  | 	defer func() { | 
					
						
							|  |  |  | 		err := os.RemoveAll(tempdir) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			t.Fatal(err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	}() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-04-12 19:55:44 -06:00
										 |  |  | 	// set append-only mode and configure path | 
					
						
							| 
									
										
										
										
											2018-04-15 08:31:50 -06:00
										 |  |  | 	mux := NewHandler(Server{ | 
					
						
							| 
									
										
										
										
											2018-04-12 19:55:44 -06:00
										 |  |  | 		AppendOnly: true, | 
					
						
							|  |  |  | 		Path:       tempdir, | 
					
						
							|  |  |  | 	}) | 
					
						
							| 
									
										
										
										
											2018-04-02 12:50:45 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// create the repo | 
					
						
							|  |  |  | 	checkRequest(t, mux.ServeHTTP, | 
					
						
							|  |  |  | 		newRequest(t, "POST", "/?create=true", nil), | 
					
						
							|  |  |  | 		[]wantFunc{wantCode(http.StatusOK)}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, test := range tests { | 
					
						
							|  |  |  | 		t.Run("", func(t *testing.T) { | 
					
						
							|  |  |  | 			for i, seq := range test.seq { | 
					
						
							|  |  |  | 				t.Logf("request %v: %v %v", i, seq.req.Method, seq.req.URL.Path) | 
					
						
							|  |  |  | 				checkRequest(t, mux.ServeHTTP, seq.req, seq.want) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } |