repository: enforce that SaveBlob is called within WithBlobUploader

This is achieved by removing SaveBlob from the public API and only
exposing it via an uploader object that is passed to the callback of
WithBlobUploader.
Michael Eischer 2025-10-10 22:41:35 +02:00
parent ac4642b479
commit c6e33c3954
21 changed files with 172 additions and 143 deletions
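The enforcement works by making the uploader a capability that only exists inside the callback: WithBlobUploader starts the pack uploader workers, hands the callback a BlobSaver, and flushes pending packs and the index once the callback returns. Below is a minimal, self-contained sketch of this pattern; the types are simplified stand-ins, not restic's actual API.

package main

import (
    "context"
    "fmt"
)

// BlobSaver is the capability handed to the callback; code outside the
// callback has no way to construct one (simplified stand-in type).
type BlobSaver interface {
    SaveBlob(ctx context.Context, buf []byte) error
}

type Repository struct{ uploaded int }

// saver wraps the repository and is only ever created by WithBlobUploader.
type saver struct{ r *Repository }

func (s *saver) SaveBlob(_ context.Context, buf []byte) error {
    s.r.uploaded += len(buf) // stand-in for packing and uploading the blob
    return nil
}

// WithBlobUploader runs fn with a scoped saver: workers are started before
// fn and everything is flushed after fn returns, so every SaveBlob call
// necessarily happens inside this window.
func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader BlobSaver) error) error {
    // start pack uploader workers (elided)
    if err := fn(ctx, &saver{r: r}); err != nil {
        return err
    }
    // flush pending packs and write the index (elided)
    return nil
}

func main() {
    repo := &Repository{}
    _ = repo.WithBlobUploader(context.Background(), func(ctx context.Context, uploader BlobSaver) error {
        return uploader.SaveBlob(ctx, []byte("some blob"))
    })
    fmt.Println("uploaded bytes:", repo.uploaded)
}

With this shape, the type system makes it impossible to call SaveBlob without first entering WithBlobUploader; the remaining discipline (not retaining the uploader after the callback returns) is documented on the interface.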


@@ -352,7 +352,7 @@ func loadBlobs(ctx context.Context, opts DebugExamineOptions, repo restic.Reposi
 		return err
 	}
 
-	err = repo.WithBlobUploader(ctx, func(ctx context.Context) error {
+	err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 		for _, blob := range list {
 			printer.S(" loading blob %v at %v (length %v)", blob.ID, blob.Offset, blob.Length)
 			if int(blob.Offset+blob.Length) > len(pack) {
@@ -410,7 +410,7 @@ func loadBlobs(ctx context.Context, opts DebugExamineOptions, repo restic.Reposi
 				}
 			}
 			if opts.ReuploadBlobs {
-				_, _, _, err := repo.SaveBlob(ctx, blob.Type, plaintext, id, true)
+				_, _, _, err := uploader.SaveBlob(ctx, blob.Type, plaintext, id, true)
 				if err != nil {
 					return err
 				}


@@ -152,9 +152,9 @@ func runRecover(ctx context.Context, gopts GlobalOptions, term ui.Terminal) erro
 	}
 
 	var treeID restic.ID
-	err = repo.WithBlobUploader(ctx, func(ctx context.Context) error {
+	err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 		var err error
-		treeID, err = data.SaveTree(ctx, repo, tree)
+		treeID, err = data.SaveTree(ctx, uploader, tree)
 		if err != nil {
 			return errors.Fatalf("unable to save new tree to the repository: %v", err)
 		}


@@ -129,19 +129,15 @@ func runRepairSnapshots(ctx context.Context, gopts GlobalOptions, opts RepairOpt
 			node.Size = newSize
 			return node
 		},
-		RewriteFailedTree: func(_ restic.ID, path string, _ error) (restic.ID, error) {
+		RewriteFailedTree: func(_ restic.ID, path string, _ error) (*data.Tree, error) {
 			if path == "/" {
 				printer.P(" dir %q: not readable", path)
 				// remove snapshots with invalid root node
-				return restic.ID{}, nil
+				return nil, nil
 			}
 			// If a subtree fails to load, remove it
 			printer.P(" dir %q: replaced with empty directory", path)
-			emptyID, err := data.SaveTree(ctx, repo, &data.Tree{})
-			if err != nil {
-				return restic.ID{}, err
-			}
-			return emptyID, nil
+			return &data.Tree{}, nil
 		},
 		AllowUnstableSerialization: true,
 	})
@@ -150,8 +146,8 @@ func runRepairSnapshots(ctx context.Context, gopts GlobalOptions, opts RepairOpt
 	for sn := range FindFilteredSnapshots(ctx, snapshotLister, repo, &opts.SnapshotFilter, args, printer) {
 		printer.P("\n%v", sn)
 		changed, err := filterAndReplaceSnapshot(ctx, repo, sn,
-			func(ctx context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error) {
-				id, err := rewriter.RewriteTree(ctx, repo, "/", *sn.Tree)
+			func(ctx context.Context, sn *data.Snapshot, uploader restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error) {
+				id, err := rewriter.RewriteTree(ctx, repo, uploader, "/", *sn.Tree)
 				return id, nil, err
 			}, opts.DryRun, opts.Forget, nil, "repaired", printer)
 		if err != nil {


@@ -123,7 +123,7 @@ func (opts *RewriteOptions) AddFlags(f *pflag.FlagSet) {
 
 // rewriteFilterFunc returns the filtered tree ID or an error. If a snapshot summary is returned, the snapshot will
 // be updated accordingly.
-type rewriteFilterFunc func(ctx context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error)
+type rewriteFilterFunc func(ctx context.Context, sn *data.Snapshot, uploader restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error)
 
 func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *data.Snapshot, opts RewriteOptions, printer progress.Printer) (bool, error) {
 	if sn.Tree == nil {
@@ -163,8 +163,8 @@ func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *data.
 		rewriter, querySize := walker.NewSnapshotSizeRewriter(rewriteNode)
-		filter = func(ctx context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error) {
-			id, err := rewriter.RewriteTree(ctx, repo, "/", *sn.Tree)
+		filter = func(ctx context.Context, sn *data.Snapshot, uploader restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error) {
+			id, err := rewriter.RewriteTree(ctx, repo, uploader, "/", *sn.Tree)
 			if err != nil {
 				return restic.ID{}, nil, err
 			}
@@ -179,7 +179,7 @@ func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *data.
 		}
 	} else {
-		filter = func(_ context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error) {
+		filter = func(_ context.Context, sn *data.Snapshot, _ restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error) {
 			return *sn.Tree, nil, nil
 		}
 	}
@@ -193,9 +193,9 @@ func filterAndReplaceSnapshot(ctx context.Context, repo restic.Repository, sn *d
 	var filteredTree restic.ID
 	var summary *data.SnapshotSummary
-	err := repo.WithBlobUploader(ctx, func(ctx context.Context) error {
+	err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 		var err error
-		filteredTree, summary, err = filter(ctx, sn)
+		filteredTree, summary, err = filter(ctx, sn, uploader)
 		return err
 	})
 	if err != nil {


@@ -74,11 +74,10 @@ type ToNoder interface {
 type archiverRepo interface {
 	restic.Loader
-	restic.BlobSaver
+	restic.WithBlobUploader
 	restic.SaverUnpacked[restic.WriteableFileType]
 	Config() restic.Config
-	WithBlobUploader(ctx context.Context, fn func(ctx context.Context) error) error
 }
 
 // Archiver saves a directory structure to the repo.
@@ -835,8 +834,8 @@ func (arch *Archiver) loadParentTree(ctx context.Context, sn *data.Snapshot) *da
 }
 
 // runWorkers starts the worker pools, which are stopped when the context is cancelled.
-func (arch *Archiver) runWorkers(ctx context.Context, wg *errgroup.Group) {
-	arch.blobSaver = newBlobSaver(ctx, wg, arch.Repo, arch.Options.SaveBlobConcurrency)
+func (arch *Archiver) runWorkers(ctx context.Context, wg *errgroup.Group, uploader restic.BlobSaver) {
+	arch.blobSaver = newBlobSaver(ctx, wg, uploader, arch.Options.SaveBlobConcurrency)
 
 	arch.fileSaver = newFileSaver(ctx, wg,
 		arch.blobSaver.Save,
@@ -875,12 +874,12 @@ func (arch *Archiver) Snapshot(ctx context.Context, targets []string, opts Snaps
 	var rootTreeID restic.ID
-	err = arch.Repo.WithBlobUploader(ctx, func(ctx context.Context) error {
+	err = arch.Repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 		wg, wgCtx := errgroup.WithContext(ctx)
 		start := time.Now()
 		wg.Go(func() error {
-			arch.runWorkers(wgCtx, wg)
+			arch.runWorkers(wgCtx, wg, uploader)
 
 			debug.Log("starting snapshot")
 			fn, nodeCount, err := arch.saveTree(wgCtx, "/", atree, arch.loadParentTree(wgCtx, opts.ParentSnapshot), func(_ *data.Node, is ItemStats) {


@@ -56,9 +56,9 @@ func saveFile(t testing.TB, repo archiverRepo, filename string, filesystem fs.FS
 		return err
 	}
 
-	err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+	err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 		wg, ctx := errgroup.WithContext(ctx)
-		arch.runWorkers(ctx, wg)
+		arch.runWorkers(ctx, wg, uploader)
 
 		completeReading := func() {
 			completeReadingCallback = true
@@ -219,9 +219,9 @@ func TestArchiverSave(t *testing.T) {
 			arch.summary = &Summary{}
 
 			var fnr futureNodeResult
-			err := repo.WithBlobUploader(ctx, func(ctx context.Context) error {
+			err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 				wg, ctx := errgroup.WithContext(ctx)
-				arch.runWorkers(ctx, wg)
+				arch.runWorkers(ctx, wg, uploader)
 
 				node, excluded, err := arch.save(ctx, "/", filepath.Join(tempdir, "file"), nil)
 				if err != nil {
@@ -296,9 +296,9 @@ func TestArchiverSaveReaderFS(t *testing.T) {
 			arch.summary = &Summary{}
 
 			var fnr futureNodeResult
-			err = repo.WithBlobUploader(ctx, func(ctx context.Context) error {
+			err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 				wg, ctx := errgroup.WithContext(ctx)
-				arch.runWorkers(ctx, wg)
+				arch.runWorkers(ctx, wg, uploader)
 
 				node, excluded, err := arch.save(ctx, "/", filename, nil)
 				t.Logf("Save returned %v %v", node, err)
@@ -415,27 +415,29 @@ type blobCountingRepo struct {
 	saved map[restic.BlobHandle]uint
 }
 
-func (repo *blobCountingRepo) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) {
-	id, exists, size, err := repo.archiverRepo.SaveBlob(ctx, t, buf, id, storeDuplicate)
+func (repo *blobCountingRepo) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error {
+	return repo.archiverRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
+		return fn(ctx, &blobCountingSaver{saver: uploader, blobCountingRepo: repo})
+	})
+}
+
+type blobCountingSaver struct {
+	saver            restic.BlobSaver
+	blobCountingRepo *blobCountingRepo
+}
+
+func (repo *blobCountingSaver) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) {
+	id, exists, size, err := repo.saver.SaveBlob(ctx, t, buf, id, storeDuplicate)
 	if exists {
 		return id, exists, size, err
 	}
 	h := restic.BlobHandle{ID: id, Type: t}
-	repo.m.Lock()
-	repo.saved[h]++
-	repo.m.Unlock()
+	repo.blobCountingRepo.m.Lock()
+	repo.blobCountingRepo.saved[h]++
+	repo.blobCountingRepo.m.Unlock()
 	return id, exists, size, err
 }
 
-func (repo *blobCountingRepo) SaveTree(ctx context.Context, t *data.Tree) (restic.ID, error) {
-	id, err := data.SaveTree(ctx, repo.archiverRepo, t)
-	h := restic.BlobHandle{ID: id, Type: restic.TreeBlob}
-	repo.m.Lock()
-	repo.saved[h]++
-	repo.m.Unlock()
-	return id, err
-}
-
 func appendToFile(t testing.TB, filename string, data []byte) {
 	f, err := os.OpenFile(filename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
 	if err != nil {
@@ -838,9 +840,9 @@ func TestArchiverSaveDir(t *testing.T) {
 			defer back()
 
 			var treeID restic.ID
-			err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+			err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 				wg, ctx := errgroup.WithContext(ctx)
-				arch.runWorkers(ctx, wg)
+				arch.runWorkers(ctx, wg, uploader)
 
 				meta, err := testFS.OpenFile(test.target, fs.O_NOFOLLOW, true)
 				rtest.OK(t, err)
 				ft, err := arch.saveDir(ctx, "/", test.target, meta, nil, nil)
@@ -866,7 +868,7 @@
 				node.Name = targetNodeName
 				tree := &data.Tree{Nodes: []*data.Node{node}}
-				treeID, err = data.SaveTree(ctx, repo, tree)
+				treeID, err = data.SaveTree(ctx, uploader, tree)
 				if err != nil {
 					t.Fatal(err)
 				}
@@ -904,9 +906,9 @@ func TestArchiverSaveDirIncremental(t *testing.T) {
 		arch.summary = &Summary{}
 
 		var fnr futureNodeResult
-		err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+		err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 			wg, ctx := errgroup.WithContext(ctx)
-			arch.runWorkers(ctx, wg)
+			arch.runWorkers(ctx, wg, uploader)
 
 			meta, err := testFS.OpenFile(tempdir, fs.O_NOFOLLOW, true)
 			rtest.OK(t, err)
 			ft, err := arch.saveDir(ctx, "/", tempdir, meta, nil, nil)
@@ -1094,9 +1096,9 @@ func TestArchiverSaveTree(t *testing.T) {
 			}
 
 			var treeID restic.ID
-			err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+			err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 				wg, ctx := errgroup.WithContext(ctx)
-				arch.runWorkers(ctx, wg)
+				arch.runWorkers(ctx, wg, uploader)
 
 				atree, err := newTree(testFS, test.targets)
 				if err != nil {
@@ -2093,13 +2095,24 @@ type failSaveRepo struct {
 	err error
 }
 
-func (f *failSaveRepo) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) {
-	val := atomic.AddInt32(&f.cnt, 1)
-	if val >= f.failAfter {
-		return restic.Hash(buf), false, 0, f.err
+func (f *failSaveRepo) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error {
+	return f.archiverRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
+		return fn(ctx, &failSaveSaver{saver: uploader, failSaveRepo: f})
+	})
+}
+
+type failSaveSaver struct {
+	saver        restic.BlobSaver
+	failSaveRepo *failSaveRepo
+}
+
+func (f *failSaveSaver) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) {
+	val := atomic.AddInt32(&f.failSaveRepo.cnt, 1)
+	if val >= f.failSaveRepo.failAfter {
+		return restic.Hash(buf), false, 0, f.failSaveRepo.err
 	}
-	return f.archiverRepo.SaveBlob(ctx, t, buf, id, storeDuplicate)
+	return f.saver.SaveBlob(ctx, t, buf, id, storeDuplicate)
 }
 
 func TestArchiverAbortEarlyOnError(t *testing.T) {
@@ -2412,7 +2425,7 @@ func TestRacyFileTypeSwap(t *testing.T) {
 			ctx, cancel := context.WithCancel(context.Background())
 			defer cancel()
 
-			_ = repo.WithBlobUploader(ctx, func(ctx context.Context) error {
+			_ = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 				wg, ctx := errgroup.WithContext(ctx)
 
 				arch := New(repo, fs.Track{FS: statfs}, Options{})
@@ -2420,7 +2433,7 @@
 					t.Logf("archiver error as expected for %v: %v", item, err)
 					return err
 				}
-				arch.runWorkers(ctx, wg)
+				arch.runWorkers(ctx, wg, uploader)
 
 				// fs.Track will panic if the file was not closed
 				_, excluded, err := arch.save(ctx, "/", tempfile, nil)


@@ -525,18 +525,18 @@ func TestCheckerBlobTypeConfusion(t *testing.T) {
 	}
 
 	var id restic.ID
-	test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context) error {
+	test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 		var err error
-		id, err = data.SaveTree(ctx, repo, damagedTree)
+		id, err = data.SaveTree(ctx, uploader, damagedTree)
 		return err
 	}))
 
 	buf, err := repo.LoadBlob(ctx, restic.TreeBlob, id, nil)
 	test.OK(t, err)
 
-	test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context) error {
+	test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 		var err error
-		_, _, _, err = repo.SaveBlob(ctx, restic.DataBlob, buf, id, false)
+		_, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, id, false)
 		return err
 	}))
@@ -559,9 +559,9 @@
 	}
 
 	var rootID restic.ID
-	test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context) error {
+	test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 		var err error
-		rootID, err = data.SaveTree(ctx, repo, rootTree)
+		rootID, err = data.SaveTree(ctx, uploader, rootTree)
 		return err
 	}))


@@ -28,7 +28,7 @@ type fakeFileSystem struct {
 
 // saveFile reads from rd and saves the blobs in the repository. The list of
 // IDs is returned.
-func (fs *fakeFileSystem) saveFile(ctx context.Context, rd io.Reader) (blobs restic.IDs) {
+func (fs *fakeFileSystem) saveFile(ctx context.Context, uploader restic.BlobSaver, rd io.Reader) (blobs restic.IDs) {
 	if fs.buf == nil {
 		fs.buf = make([]byte, chunker.MaxSize)
 	}
@@ -50,7 +50,7 @@
 			fs.t.Fatalf("unable to save chunk in repo: %v", err)
 		}
 
-		id, _, _, err := fs.repo.SaveBlob(ctx, restic.DataBlob, chunk.Data, restic.ID{}, false)
+		id, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, chunk.Data, restic.ID{}, false)
 		if err != nil {
 			fs.t.Fatalf("error saving chunk: %v", err)
 		}
@@ -68,7 +68,7 @@ const (
 )
 
 // saveTree saves a tree of fake files in the repo and returns the ID.
-func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) restic.ID {
+func (fs *fakeFileSystem) saveTree(ctx context.Context, uploader restic.BlobSaver, seed int64, depth int) restic.ID {
 	rnd := rand.NewSource(seed)
 	numNodes := int(rnd.Int63() % maxNodes)
@@ -78,7 +78,7 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) r
 		// randomly select the type of the node, either tree (p = 1/4) or file (p = 3/4).
 		if depth > 1 && rnd.Int63()%4 == 0 {
 			treeSeed := rnd.Int63() % maxSeed
-			id := fs.saveTree(ctx, treeSeed, depth-1)
+			id := fs.saveTree(ctx, uploader, treeSeed, depth-1)
 			node := &Node{
 				Name: fmt.Sprintf("dir-%v", treeSeed),
@@ -101,13 +101,13 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) r
 			Size: uint64(fileSize),
 		}
 
-		node.Content = fs.saveFile(ctx, fakeFile(fileSeed, fileSize))
+		node.Content = fs.saveFile(ctx, uploader, fakeFile(fileSeed, fileSize))
 		tree.Nodes = append(tree.Nodes, node)
 	}
 
 	tree.Sort()
-	id, err := SaveTree(ctx, fs.repo, &tree)
+	id, err := SaveTree(ctx, uploader, &tree)
 	if err != nil {
 		fs.t.Fatalf("SaveTree returned error: %v", err)
 	}
@@ -136,8 +136,8 @@ func TestCreateSnapshot(t testing.TB, repo restic.Repository, at time.Time, dept
 	}
 
 	var treeID restic.ID
-	test.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
-		treeID = fs.saveTree(ctx, seed, depth)
+	test.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+		treeID = fs.saveTree(ctx, uploader, seed, depth)
 		return nil
 	}))
 	snapshot.Tree = &treeID


@@ -107,10 +107,10 @@ func TestEmptyLoadTree(t *testing.T) {
 	tree := data.NewTree(0)
 
 	var id restic.ID
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 		var err error
 		// save tree
-		id, err = data.SaveTree(ctx, repo, tree)
+		id, err = data.SaveTree(ctx, uploader, tree)
 		return err
 	}))


@@ -20,8 +20,8 @@ func FuzzSaveLoadBlob(f *testing.F) {
 		id := restic.Hash(blob)
 		repo, _, _ := TestRepositoryWithVersion(t, 2)
 
-		rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
-			_, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, blob, id, false)
+		rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+			_, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, blob, id, false)
 			return err
 		}))


@@ -47,9 +47,9 @@ func TestPruneMaxUnusedDuplicate(t *testing.T) {
 		{bufs[1], bufs[3]},
 		{bufs[2], bufs[3]},
 	} {
-		rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+		rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 			for _, blob := range blobs {
-				id, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, blob, restic.ID{}, true)
+				id, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, blob, restic.ID{}, true)
 				keep.Insert(restic.BlobHandle{Type: restic.DataBlob, ID: id})
 				rtest.OK(t, err)
 			}


@@ -25,12 +25,12 @@ func testPrune(t *testing.T, opts repository.PruneOptions, errOnUnused bool) {
 	createRandomBlobs(t, random, repo, 5, 0.5, true)
 	keep, _ := selectBlobs(t, random, repo, 0.5)
 
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 		// duplicate a few blobs to exercise those code paths
 		for blob := range keep {
 			buf, err := repo.LoadBlob(ctx, blob.Type, blob.ID, nil)
 			rtest.OK(t, err)
-			_, _, _, err = repo.SaveBlob(ctx, blob.Type, buf, blob.ID, true)
+			_, _, _, err = uploader.SaveBlob(ctx, blob.Type, buf, blob.ID, true)
 			rtest.OK(t, err)
 		}
 		return nil
@@ -133,13 +133,13 @@ func TestPruneSmall(t *testing.T) {
 	const numBlobsCreated = 55
 
 	keep := restic.NewBlobSet()
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 		// we need a minum of 11 packfiles, each packfile will be about 5 Mb long
 		for i := 0; i < numBlobsCreated; i++ {
 			buf := make([]byte, blobSize)
 			random.Read(buf)
-			id, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false)
+			id, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false)
 			rtest.OK(t, err)
 			keep.Insert(restic.BlobHandle{Type: restic.DataBlob, ID: id})
 		}


@@ -47,9 +47,9 @@ func Repack(
 		return nil, errors.New("repack step requires a backend connection limit of at least two")
 	}
 
-	err = dstRepo.WithBlobUploader(ctx, func(ctx context.Context) error {
+	err = dstRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 		var err error
-		obsoletePacks, err = repack(ctx, repo, dstRepo, packs, keepBlobs, p, logf)
+		obsoletePacks, err = repack(ctx, repo, dstRepo, uploader, packs, keepBlobs, p, logf)
 		return err
 	})
 	if err != nil {
@@ -62,6 +62,7 @@ func repack(
 	ctx context.Context,
 	repo restic.Repository,
 	dstRepo restic.Repository,
+	uploader restic.BlobSaver,
 	packs restic.IDSet,
 	keepBlobs repackBlobSet,
 	p *progress.Counter,
@@ -128,7 +129,7 @@ func repack(
 			}
 
 			// We do want to save already saved blobs!
-			_, _, _, err = dstRepo.SaveBlob(wgCtx, blob.Type, buf, blob.ID, true)
+			_, _, _, err = uploader.SaveBlob(wgCtx, blob.Type, buf, blob.ID, true)
 			if err != nil {
 				return err
 			}


@@ -20,7 +20,7 @@ func randomSize(random *rand.Rand, min, max int) int {
 
 func createRandomBlobs(t testing.TB, random *rand.Rand, repo restic.Repository, blobs int, pData float32, smallBlobs bool) {
 	// two loops to allow creating multiple pack files
 	for blobs > 0 {
-		rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+		rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 			for blobs > 0 {
 				blobs--
 				var (
@@ -43,7 +43,7 @@ func createRandomBlobs(t testing.TB, random *rand.Rand, repo restic.Repository,
 				buf := make([]byte, length)
 				random.Read(buf)
 
-				id, exists, _, err := repo.SaveBlob(ctx, tpe, buf, restic.ID{}, false)
+				id, exists, _, err := uploader.SaveBlob(ctx, tpe, buf, restic.ID{}, false)
 				if err != nil {
 					t.Fatalf("SaveFrom() error %v", err)
 				}
@@ -70,8 +70,8 @@ func createRandomWrongBlob(t testing.TB, random *rand.Rand, repo restic.Reposito
 	// invert first data byte
 	buf[0] ^= 0xff
 
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
-		_, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, buf, id, false)
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+		_, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, buf, id, false)
 		return err
 	}))
 	return restic.BlobHandle{ID: id, Type: restic.DataBlob}
@@ -339,8 +339,8 @@ func testRepackBlobFallback(t *testing.T, version uint) {
 	modbuf[0] ^= 0xff
 
 	// create pack with broken copy
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
-		_, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, modbuf, id, false)
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+		_, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, modbuf, id, false)
 		return err
 	}))
@@ -349,8 +349,8 @@
 	rewritePacks := findPacksForBlobs(t, repo, keepBlobs)
 
 	// create pack with valid copy
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
-		_, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, buf, id, true)
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+		_, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, buf, id, true)
 		return err
 	}))


@@ -15,7 +15,7 @@ func RepairPacks(ctx context.Context, repo *Repository, ids restic.IDSet, printe
 	bar.SetMax(uint64(len(ids)))
 	defer bar.Done()
 
-	err := repo.WithBlobUploader(ctx, func(ctx context.Context) error {
+	err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
 		// examine all data the indexes have for the pack file
 		for b := range repo.ListPacksFromIndex(ctx, ids) {
 			blobs := b.Blobs
@@ -30,7 +30,7 @@
 					printer.E("failed to load blob %v: %v", blob.ID, err)
 					return nil
 				}
-				id, _, _, err := repo.SaveBlob(ctx, blob.Type, buf, restic.ID{}, true)
+				id, _, _, err := uploader.SaveBlob(ctx, blob.Type, buf, restic.ID{}, true)
 				if !id.Equal(blob.ID) {
 					panic("pack id mismatch during upload")
 				}


@@ -559,11 +559,11 @@ func (r *Repository) removeUnpacked(ctx context.Context, t restic.FileType, id r
 	return r.be.Remove(ctx, backend.Handle{Type: t, Name: id.String()})
 }
 
-func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.Context) error) error {
+func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error {
 	wg, ctx := errgroup.WithContext(ctx)
 	r.startPackUploader(ctx, wg)
 
 	wg.Go(func() error {
-		if err := fn(ctx); err != nil {
+		if err := fn(ctx, &blobSaverRepo{repo: r}); err != nil {
 			return err
 		}
 		if err := r.flush(ctx); err != nil {
@@ -574,6 +574,14 @@ func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.C
 	return wg.Wait()
 }
 
+type blobSaverRepo struct {
+	repo *Repository
+}
+
+func (r *blobSaverRepo) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) {
+	return r.repo.saveBlob(ctx, t, buf, id, storeDuplicate)
+}
+
 func (r *Repository) startPackUploader(ctx context.Context, wg *errgroup.Group) {
 	if r.packerWg != nil {
 		panic("uploader already started")
@@ -926,14 +934,14 @@ func (r *Repository) Close() error {
 	return r.be.Close()
 }
 
-// SaveBlob saves a blob of type t into the repository.
+// saveBlob saves a blob of type t into the repository.
 // It takes care that no duplicates are saved; this can be overwritten
 // by setting storeDuplicate to true.
 // If id is the null id, it will be computed and returned.
 // Also returns if the blob was already known before.
 // If the blob was not known before, it returns the number of bytes the blob
 // occupies in the repo (compressed or not, including encryption overhead).
-func (r *Repository) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) {
+func (r *Repository) saveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) {
 	if int64(len(buf)) > math.MaxUint32 {
 		return restic.ID{}, false, 0, fmt.Errorf("blob is larger than 4GB")


@@ -51,13 +51,13 @@ func testSave(t *testing.T, version uint, calculateID bool) {
 		id := restic.Hash(data)
 
-		rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+		rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 			// save
 			inputID := restic.ID{}
 			if !calculateID {
 				inputID = id
 			}
-			sid, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, data, inputID, false)
+			sid, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, data, inputID, false)
 			rtest.OK(t, err)
 			rtest.Equals(t, id, sid)
 			return nil
@@ -97,7 +97,7 @@ func testSavePackMerging(t *testing.T, targetPercentage int, expectedPacks int)
 	})
 
 	var ids restic.IDs
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 		// add blobs with size targetPercentage / 100 * repo.PackSize to the repository
 		blobSize := repository.MinPackSize / 100
 		for range targetPercentage {
@@ -105,7 +105,7 @@ func testSavePackMerging(t *testing.T, targetPercentage int, expectedPacks int)
 			_, err := io.ReadFull(rnd, data)
 			rtest.OK(t, err)
-			sid, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, data, restic.ID{}, false)
+			sid, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, data, restic.ID{}, false)
 			rtest.OK(t, err)
 			ids = append(ids, sid)
 		}
@@ -147,9 +147,9 @@ func benchmarkSaveAndEncrypt(t *testing.B, version uint) {
 	t.ResetTimer()
 	t.SetBytes(int64(size))
 
-	_ = repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+	_ = repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 		for i := 0; i < t.N; i++ {
-			_, _, _, err = repo.SaveBlob(ctx, restic.DataBlob, data, id, true)
+			_, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, data, id, true)
 			rtest.OK(t, err)
 		}
 		return nil
@@ -168,9 +168,9 @@ func testLoadBlob(t *testing.T, version uint) {
 	rtest.OK(t, err)
 
 	var id restic.ID
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 		var err error
-		id, _, _, err = repo.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false)
+		id, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false)
 		return err
 	}))
@@ -196,9 +196,9 @@ func TestLoadBlobBroken(t *testing.T) {
 	buf := rtest.Random(42, 1000)
 
 	var id restic.ID
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 		var err error
-		id, _, _, err = repo.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false)
+		id, _, _, err = uploader.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false)
 		return err
 	}))
@@ -225,9 +225,9 @@ func benchmarkLoadBlob(b *testing.B, version uint) {
 	rtest.OK(b, err)
 
 	var id restic.ID
-	rtest.OK(b, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+	rtest.OK(b, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 		var err error
-		id, _, _, err = repo.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false)
+		id, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false)
 		return err
 	}))
@@ -361,7 +361,7 @@ func TestRepositoryLoadUnpackedRetryBroken(t *testing.T) {
 
 // saveRandomDataBlobs generates random data blobs and saves them to the repository.
 func saveRandomDataBlobs(t testing.TB, repo restic.Repository, num int, sizeMax int) {
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 		for i := 0; i < num; i++ {
 			size := rand.Int() % sizeMax
@@ -369,7 +369,7 @@ func saveRandomDataBlobs(t testing.TB, repo restic.Repository, num int, sizeMax
 			_, err := io.ReadFull(rnd, buf)
 			rtest.OK(t, err)
 
-			_, _, _, err = repo.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false)
+			_, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false)
 			rtest.OK(t, err)
 		}
 		return nil
@@ -432,9 +432,9 @@ func TestListPack(t *testing.T) {
 	buf := rtest.Random(42, 1000)
 
 	var id restic.ID
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
 		var err error
-		id, _, _, err = repo.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false)
+		id, _, _, err = uploader.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false)
 		return err
 	}))


@@ -39,8 +39,7 @@ type Repository interface {
 	// WithUploader starts the necessary workers to upload new blobs. Once the callback returns,
 	// the workers are stopped and the index is written to the repository. The callback must use
 	// the passed context and must not keep references to any of its parameters after returning.
-	WithBlobUploader(ctx context.Context, fn func(ctx context.Context) error) error
-	SaveBlob(ctx context.Context, t BlobType, buf []byte, id ID, storeDuplicate bool) (newID ID, known bool, size int, err error)
+	WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader BlobSaver) error) error
 
 	// List calls the function fn for each file of type t in the repository.
 	// When an error is returned by fn, processing stops and List() returns the
@@ -158,6 +157,10 @@ type BlobLoader interface {
 	LoadBlob(context.Context, BlobType, ID, []byte) ([]byte, error)
 }
 
+type WithBlobUploader interface {
+	WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader BlobSaver) error) error
+}
+
 type BlobSaver interface {
 	SaveBlob(context.Context, BlobType, []byte, ID, bool) (ID, bool, int, error)
 }
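The new WithBlobUploader interface also lets consumers declare exactly the capability they need, as archiverRepo above does by embedding restic.WithBlobUploader instead of restic.BlobSaver. A hedged sketch of a helper written against only these two interfaces (saveOne is illustrative and not part of the commit):

// saveOne uploads a single data blob and returns its ID. It needs only
// the narrow restic.WithBlobUploader capability, not a full Repository.
// Illustrative helper; error handling is kept minimal on purpose.
func saveOne(ctx context.Context, repo restic.WithBlobUploader, buf []byte) (id restic.ID, err error) {
	err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
		var serr error
		id, _, _, serr = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false)
		return serr
	})
	return id, err
}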


@@ -171,8 +171,8 @@ func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot, getGe
 	defer cancel()
 
 	var treeID restic.ID
-	rtest.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context) error {
-		treeID = saveDir(t, repo, snapshot.Nodes, 1000, getGenericAttributes)
+	rtest.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
+		treeID = saveDir(t, uploader, snapshot.Nodes, 1000, getGenericAttributes)
 		return nil
 	}))


@@ -11,7 +11,7 @@ import (
 )
 
 type NodeRewriteFunc func(node *data.Node, path string) *data.Node
-type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (restic.ID, error)
+type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (*data.Tree, error)
 type QueryRewrittenSizeFunc func() SnapshotSize
 
 type SnapshotSize struct {
@@ -52,8 +52,8 @@ func NewTreeRewriter(opts RewriteOpts) *TreeRewriter {
 	}
 	if rw.opts.RewriteFailedTree == nil {
 		// fail with error by default
-		rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (restic.ID, error) {
-			return restic.ID{}, err
+		rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (*data.Tree, error) {
+			return nil, err
 		}
 	}
 	return rw
@@ -82,12 +82,7 @@ func NewSnapshotSizeRewriter(rewriteNode NodeRewriteFunc) (*TreeRewriter, QueryR
 	return t, ss
 }
 
-type BlobLoadSaver interface {
-	restic.BlobSaver
-	restic.BlobLoader
-}
-
-func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID restic.ID) (newNodeID restic.ID, err error) {
+func (t *TreeRewriter) RewriteTree(ctx context.Context, loader restic.BlobLoader, saver restic.BlobSaver, nodepath string, nodeID restic.ID) (newNodeID restic.ID, err error) {
 	// check if tree was already changed
 	newID, ok := t.replaces[nodeID]
 	if ok {
@@ -95,16 +90,27 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, node
 	}
 
 	// a nil nodeID will lead to a load error
-	curTree, err := data.LoadTree(ctx, repo, nodeID)
+	curTree, err := data.LoadTree(ctx, loader, nodeID)
 	if err != nil {
-		return t.opts.RewriteFailedTree(nodeID, nodepath, err)
+		replacement, err := t.opts.RewriteFailedTree(nodeID, nodepath, err)
+		if err != nil {
+			return restic.ID{}, err
+		}
+		if replacement != nil {
+			replacementID, err := data.SaveTree(ctx, saver, replacement)
+			if err != nil {
+				return restic.ID{}, err
+			}
+			return replacementID, nil
+		}
+		return restic.ID{}, nil
 	}
 
 	if !t.opts.AllowUnstableSerialization {
 		// check that we can properly encode this tree without losing information
 		// The alternative of using json/Decoder.DisallowUnknownFields() doesn't work as we use
 		// a custom UnmarshalJSON to decode trees, see also https://github.com/golang/go/issues/41144
-		testID, err := data.SaveTree(ctx, repo, curTree)
+		testID, err := data.SaveTree(ctx, saver, curTree)
 		if err != nil {
 			return restic.ID{}, err
 		}
@@ -139,7 +145,7 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, node
 		if node.Subtree != nil {
 			subtree = *node.Subtree
 		}
-		newID, err := t.RewriteTree(ctx, repo, path, subtree)
+		newID, err := t.RewriteTree(ctx, loader, saver, path, subtree)
 		if err != nil {
 			return restic.ID{}, err
 		}
@@ -156,7 +162,7 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, node
 	}
 
 	// Save new tree
-	newTreeID, _, _, err := repo.SaveBlob(ctx, restic.TreeBlob, tree, restic.ID{}, false)
+	newTreeID, _, _, err := saver.SaveBlob(ctx, restic.TreeBlob, tree, restic.ID{}, false)
 	if t.replaces != nil {
 		t.replaces[nodeID] = newTreeID
 	}
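With this change, RewriteFailedTree callbacks return a replacement *data.Tree (or nil to drop the subtree) and RewriteTree persists the replacement itself through its saver argument, so callbacks no longer need any write access to the repository. A short sketch of a callback under the new contract, mirroring the cmd_repair_snapshots hunk above:

rewriter := NewTreeRewriter(RewriteOpts{
	RewriteFailedTree: func(_ restic.ID, path string, _ error) (*data.Tree, error) {
		if path == "/" {
			// nil tree: drop snapshots whose root cannot be loaded
			return nil, nil
		}
		// non-nil tree: RewriteTree saves it via its saver argument
		return &data.Tree{}, nil
	},
})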


@@ -285,7 +285,7 @@ func TestRewriter(t *testing.T) {
 			defer cancel()
 
 			rewriter, last := test.check(t)
-			newRoot, err := rewriter.RewriteTree(ctx, modrepo, "/", root)
+			newRoot, err := rewriter.RewriteTree(ctx, modrepo, modrepo, "/", root)
 			if err != nil {
 				t.Error(err)
 			}
@@ -335,7 +335,7 @@ func TestSnapshotSizeQuery(t *testing.T) {
 		return node
 	}
 	rewriter, querySize := NewSnapshotSizeRewriter(rewriteNode)
-	newRoot, err := rewriter.RewriteTree(ctx, modrepo, "/", root)
+	newRoot, err := rewriter.RewriteTree(ctx, modrepo, modrepo, "/", root)
 	if err != nil {
 		t.Error(err)
 	}
@@ -373,7 +373,7 @@ func TestRewriterFailOnUnknownFields(t *testing.T) {
 			return node
 		},
 	})
-	_, err := rewriter.RewriteTree(ctx, tm, "/", id)
+	_, err := rewriter.RewriteTree(ctx, tm, tm, "/", id)
 	if err == nil {
 		t.Error("missing error on unknown field")
@@ -383,7 +383,7 @@
 	rewriter = NewTreeRewriter(RewriteOpts{
 		AllowUnstableSerialization: true,
 	})
-	root, err := rewriter.RewriteTree(ctx, tm, "/", id)
+	root, err := rewriter.RewriteTree(ctx, tm, tm, "/", id)
 	test.OK(t, err)
 
 	_, expRoot := BuildTreeMap(TestTree{
 		"subfile": TestFile{},
@@ -400,21 +400,24 @@ func TestRewriterTreeLoadError(t *testing.T) {
 	// also check that load error by default cause the operation to fail
 	rewriter := NewTreeRewriter(RewriteOpts{})
-	_, err := rewriter.RewriteTree(ctx, tm, "/", id)
+	_, err := rewriter.RewriteTree(ctx, tm, tm, "/", id)
 	if err == nil {
 		t.Fatal("missing error on unloadable tree")
 	}
 
-	replacementID := restic.NewRandomID()
+	replacementTree := &data.Tree{Nodes: []*data.Node{{Name: "replacement", Type: data.NodeTypeFile, Size: 42}}}
+	replacementID, err := data.SaveTree(ctx, tm, replacementTree)
+	test.OK(t, err)
 	rewriter = NewTreeRewriter(RewriteOpts{
-		RewriteFailedTree: func(nodeID restic.ID, path string, err error) (restic.ID, error) {
+		RewriteFailedTree: func(nodeID restic.ID, path string, err error) (*data.Tree, error) {
 			if nodeID != id || path != "/" {
 				t.Fail()
 			}
-			return replacementID, nil
+			return replacementTree, nil
 		},
 	})
-	newRoot, err := rewriter.RewriteTree(ctx, tm, "/", id)
+	newRoot, err := rewriter.RewriteTree(ctx, tm, tm, "/", id)
 	test.OK(t, err)
 	test.Equals(t, replacementID, newRoot)
 }