diff --git a/cmd/restic/cmd_copy.go b/cmd/restic/cmd_copy.go index d6a5efe57..f209015f0 100644 --- a/cmd/restic/cmd_copy.go +++ b/cmd/restic/cmd_copy.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "iter" "time" "github.com/restic/restic/internal/data" @@ -75,33 +76,33 @@ func (opts *CopyOptions) AddFlags(f *pflag.FlagSet) { func collectAllSnapshots(ctx context.Context, opts CopyOptions, srcSnapshotLister restic.Lister, srcRepo restic.Repository, dstSnapshotByOriginal map[restic.ID][]*data.Snapshot, args []string, printer progress.Printer, -) (selectedSnapshots []*data.Snapshot) { - - selectedSnapshots = make([]*data.Snapshot, 0, 10) - for sn := range FindFilteredSnapshots(ctx, srcSnapshotLister, srcRepo, &opts.SnapshotFilter, args, printer) { - // check whether the destination has a snapshot with the same persistent ID which has similar snapshot fields - srcOriginal := *sn.ID() - if sn.Original != nil { - srcOriginal = *sn.Original - } - if originalSns, ok := dstSnapshotByOriginal[srcOriginal]; ok { - isCopy := false - for _, originalSn := range originalSns { - if similarSnapshots(originalSn, sn) { - printer.V("\n%v", sn) - printer.V("skipping source snapshot %s, was already copied to snapshot %s", sn.ID().Str(), originalSn.ID().Str()) - isCopy = true - break +) iter.Seq[*data.Snapshot] { + return func(yield func(*data.Snapshot) bool) { + for sn := range FindFilteredSnapshots(ctx, srcSnapshotLister, srcRepo, &opts.SnapshotFilter, args, printer) { + // check whether the destination has a snapshot with the same persistent ID which has similar snapshot fields + srcOriginal := *sn.ID() + if sn.Original != nil { + srcOriginal = *sn.Original + } + if originalSns, ok := dstSnapshotByOriginal[srcOriginal]; ok { + isCopy := false + for _, originalSn := range originalSns { + if similarSnapshots(originalSn, sn) { + printer.V("\n%v", sn) + printer.V("skipping source snapshot %s, was already copied to snapshot %s", sn.ID().Str(), originalSn.ID().Str()) + isCopy = true + break + } + } + if isCopy { + continue } } - if isCopy { - continue + if !yield(sn) { + return } } - selectedSnapshots = append(selectedSnapshots, sn) } - - return selectedSnapshots } func runCopy(ctx context.Context, opts CopyOptions, gopts global.Options, args []string, term ui.Terminal) error { @@ -189,7 +190,7 @@ func similarSnapshots(sna *data.Snapshot, snb *data.Snapshot) bool { // copyTreeBatched copies multiple snapshots in one go. Snapshots are written after // data equivalent to at least 10 packfiles was written. func copyTreeBatched(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Repository, - selectedSnapshots []*data.Snapshot, printer progress.Printer) error { + selectedSnapshots iter.Seq[*data.Snapshot], printer progress.Printer) error { // remember already processed trees across all snapshots visitedTrees := srcRepo.NewAssociatedBlobSet() @@ -197,16 +198,23 @@ func copyTreeBatched(ctx context.Context, srcRepo restic.Repository, dstRepo res targetSize := uint64(dstRepo.PackSize()) * 100 minDuration := 1 * time.Minute - for len(selectedSnapshots) > 0 { + // use pull-based iterator to allow iteration in multiple steps + next, stop := iter.Pull(selectedSnapshots) + defer stop() + + for { var batch []*data.Snapshot batchSize := uint64(0) startTime := time.Now() // call WithBlobUploader() once and then loop over all selectedSnapshots err := dstRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { - for len(selectedSnapshots) > 0 && (batchSize < targetSize || time.Since(startTime) < minDuration) { - sn := selectedSnapshots[0] - selectedSnapshots = selectedSnapshots[1:] + for batchSize < targetSize || time.Since(startTime) < minDuration { + sn, ok := next() + if !ok { + break + } + batch = append(batch, sn) printer.P("\n%v", sn) @@ -225,6 +233,11 @@ func copyTreeBatched(ctx context.Context, srcRepo restic.Repository, dstRepo res return err } + // if no snapshots were processed in this batch, we're done + if len(batch) == 0 { + break + } + // add a newline to separate saved snapshot messages from the other messages if len(batch) > 1 { printer.P("")