diff --git a/src/mime/multipart/formdata.go b/src/mime/multipart/formdata.go index fca5f9e15fb..41bc886d167 100644 --- a/src/mime/multipart/formdata.go +++ b/src/mime/multipart/formdata.go @@ -7,6 +7,7 @@ package multipart import ( "bytes" "errors" + "internal/godebug" "io" "math" "net/textproto" @@ -31,25 +32,61 @@ func (r *Reader) ReadForm(maxMemory int64) (*Form, error) { return r.readForm(maxMemory) } +var multipartFiles = godebug.New("multipartfiles") + func (r *Reader) readForm(maxMemory int64) (_ *Form, err error) { form := &Form{make(map[string][]string), make(map[string][]*FileHeader)} + var ( + file *os.File + fileOff int64 + ) + numDiskFiles := 0 + combineFiles := multipartFiles.Value() != "distinct" defer func() { + if file != nil { + if cerr := file.Close(); err == nil { + err = cerr + } + } + if combineFiles && numDiskFiles > 1 { + for _, fhs := range form.File { + for _, fh := range fhs { + fh.tmpshared = true + } + } + } if err != nil { form.RemoveAll() + if file != nil { + os.Remove(file.Name()) + } } }() - // Reserve an additional 10 MB for non-file parts. - maxValueBytes := maxMemory + int64(10<<20) - if maxValueBytes <= 0 { + // maxFileMemoryBytes is the maximum bytes of file data we will store in memory. + // Data past this limit is written to disk. + // This limit strictly applies to content, not metadata (filenames, MIME headers, etc.), + // since metadata is always stored in memory, not disk. + // + // maxMemoryBytes is the maximum bytes we will store in memory, including file content, + // non-file part values, metdata, and map entry overhead. + // + // We reserve an additional 10 MB in maxMemoryBytes for non-file data. + // + // The relationship between these parameters, as well as the overly-large and + // unconfigurable 10 MB added on to maxMemory, is unfortunate but difficult to change + // within the constraints of the API as documented. + maxFileMemoryBytes := maxMemory + maxMemoryBytes := maxMemory + int64(10<<20) + if maxMemoryBytes <= 0 { if maxMemory < 0 { - maxValueBytes = 0 + maxMemoryBytes = 0 } else { - maxValueBytes = math.MaxInt64 + maxMemoryBytes = math.MaxInt64 } } for { - p, err := r.NextPart() + p, err := r.nextPart(false, maxMemoryBytes) if err == io.EOF { break } @@ -63,16 +100,27 @@ func (r *Reader) readForm(maxMemory int64) (_ *Form, err error) { } filename := p.FileName() + // Multiple values for the same key (one map entry, longer slice) are cheaper + // than the same number of values for different keys (many map entries), but + // using a consistent per-value cost for overhead is simpler. + maxMemoryBytes -= int64(len(name)) + maxMemoryBytes -= 100 // map overhead + if maxMemoryBytes < 0 { + // We can't actually take this path, since nextPart would already have + // rejected the MIME headers for being too large. Check anyway. + return nil, ErrMessageTooLarge + } + var b bytes.Buffer if filename == "" { // value, store as string in memory - n, err := io.CopyN(&b, p, maxValueBytes+1) + n, err := io.CopyN(&b, p, maxMemoryBytes+1) if err != nil && err != io.EOF { return nil, err } - maxValueBytes -= n - if maxValueBytes < 0 { + maxMemoryBytes -= n + if maxMemoryBytes < 0 { return nil, ErrMessageTooLarge } form.Value[name] = append(form.Value[name], b.String()) @@ -80,35 +128,45 @@ func (r *Reader) readForm(maxMemory int64) (_ *Form, err error) { } // file, store in memory or on disk + maxMemoryBytes -= mimeHeaderSize(p.Header) + if maxMemoryBytes < 0 { + return nil, ErrMessageTooLarge + } fh := &FileHeader{ Filename: filename, Header: p.Header, } - n, err := io.CopyN(&b, p, maxMemory+1) + n, err := io.CopyN(&b, p, maxFileMemoryBytes+1) if err != nil && err != io.EOF { return nil, err } - if n > maxMemory { - // too big, write to disk and flush buffer - file, err := os.CreateTemp("", "multipart-") - if err != nil { - return nil, err + if n > maxFileMemoryBytes { + if file == nil { + file, err = os.CreateTemp(r.tempDir, "multipart-") + if err != nil { + return nil, err + } } + numDiskFiles++ size, err := io.Copy(file, io.MultiReader(&b, p)) - if cerr := file.Close(); err == nil { - err = cerr - } if err != nil { - os.Remove(file.Name()) return nil, err } fh.tmpfile = file.Name() fh.Size = size + fh.tmpoff = fileOff + fileOff += size + if !combineFiles { + if err := file.Close(); err != nil { + return nil, err + } + file = nil + } } else { fh.content = b.Bytes() fh.Size = int64(len(fh.content)) - maxMemory -= n - maxValueBytes -= n + maxFileMemoryBytes -= n + maxMemoryBytes -= n } form.File[name] = append(form.File[name], fh) } @@ -116,6 +174,17 @@ func (r *Reader) readForm(maxMemory int64) (_ *Form, err error) { return form, nil } +func mimeHeaderSize(h textproto.MIMEHeader) (size int64) { + for k, vs := range h { + size += int64(len(k)) + size += 100 // map entry overhead + for _, v := range vs { + size += int64(len(v)) + } + } + return size +} + // Form is a parsed multipart form. // Its File parts are stored either in memory or on disk, // and are accessible via the *FileHeader's Open method. @@ -133,7 +202,7 @@ func (f *Form) RemoveAll() error { for _, fh := range fhs { if fh.tmpfile != "" { e := os.Remove(fh.tmpfile) - if e != nil && err == nil { + if e != nil && !errors.Is(e, os.ErrNotExist) && err == nil { err = e } } @@ -148,15 +217,25 @@ type FileHeader struct { Header textproto.MIMEHeader Size int64 - content []byte - tmpfile string + content []byte + tmpfile string + tmpoff int64 + tmpshared bool } // Open opens and returns the FileHeader's associated File. func (fh *FileHeader) Open() (File, error) { if b := fh.content; b != nil { r := io.NewSectionReader(bytes.NewReader(b), 0, int64(len(b))) - return sectionReadCloser{r}, nil + return sectionReadCloser{r, nil}, nil + } + if fh.tmpshared { + f, err := os.Open(fh.tmpfile) + if err != nil { + return nil, err + } + r := io.NewSectionReader(f, fh.tmpoff, fh.Size) + return sectionReadCloser{r, f}, nil } return os.Open(fh.tmpfile) } @@ -175,8 +254,12 @@ type File interface { type sectionReadCloser struct { *io.SectionReader + io.Closer } func (rc sectionReadCloser) Close() error { + if rc.Closer != nil { + return rc.Closer.Close() + } return nil } diff --git a/src/mime/multipart/formdata_test.go b/src/mime/multipart/formdata_test.go index 8a4eabcee03..8a862be7174 100644 --- a/src/mime/multipart/formdata_test.go +++ b/src/mime/multipart/formdata_test.go @@ -5,8 +5,11 @@ package multipart import ( + "bytes" + "fmt" "io" "math" + "net/textproto" "os" "strings" "testing" @@ -207,8 +210,8 @@ Content-Disposition: form-data; name="largetext" maxMemory int64 err error }{ - {"smaller", 50, nil}, - {"exact-fit", 25, nil}, + {"smaller", 50 + int64(len("largetext")) + 100, nil}, + {"exact-fit", 25 + int64(len("largetext")) + 100, nil}, {"too-large", 0, ErrMessageTooLarge}, } for _, tc := range testCases { @@ -223,7 +226,7 @@ Content-Disposition: form-data; name="largetext" defer f.RemoveAll() } if tc.err != err { - t.Fatalf("ReadForm error - got: %v; expected: %v", tc.err, err) + t.Fatalf("ReadForm error - got: %v; expected: %v", err, tc.err) } if err == nil { if g := f.Value["largetext"][0]; g != largeTextValue { @@ -233,3 +236,135 @@ Content-Disposition: form-data; name="largetext" }) } } + +// TestReadForm_MetadataTooLarge verifies that we account for the size of field names, +// MIME headers, and map entry overhead while limiting the memory consumption of parsed forms. +func TestReadForm_MetadataTooLarge(t *testing.T) { + for _, test := range []struct { + name string + f func(*Writer) + }{{ + name: "large name", + f: func(fw *Writer) { + name := strings.Repeat("a", 10<<20) + w, _ := fw.CreateFormField(name) + w.Write([]byte("value")) + }, + }, { + name: "large MIME header", + f: func(fw *Writer) { + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", `form-data; name="a"`) + h.Set("X-Foo", strings.Repeat("a", 10<<20)) + w, _ := fw.CreatePart(h) + w.Write([]byte("value")) + }, + }, { + name: "many parts", + f: func(fw *Writer) { + for i := 0; i < 110000; i++ { + w, _ := fw.CreateFormField("f") + w.Write([]byte("v")) + } + }, + }} { + t.Run(test.name, func(t *testing.T) { + var buf bytes.Buffer + fw := NewWriter(&buf) + test.f(fw) + if err := fw.Close(); err != nil { + t.Fatal(err) + } + fr := NewReader(&buf, fw.Boundary()) + _, err := fr.ReadForm(0) + if err != ErrMessageTooLarge { + t.Errorf("fr.ReadForm() = %v, want ErrMessageTooLarge", err) + } + }) + } +} + +// TestReadForm_ManyFiles_Combined tests that a multipart form containing many files only +// results in a single on-disk file. +func TestReadForm_ManyFiles_Combined(t *testing.T) { + const distinct = false + testReadFormManyFiles(t, distinct) +} + +// TestReadForm_ManyFiles_Distinct tests that setting GODEBUG=multipartfiles=distinct +// results in every file in a multipart form being placed in a distinct on-disk file. +func TestReadForm_ManyFiles_Distinct(t *testing.T) { + t.Setenv("GODEBUG", "multipartfiles=distinct") + const distinct = true + testReadFormManyFiles(t, distinct) +} + +func testReadFormManyFiles(t *testing.T, distinct bool) { + var buf bytes.Buffer + fw := NewWriter(&buf) + const numFiles = 10 + for i := 0; i < numFiles; i++ { + name := fmt.Sprint(i) + w, err := fw.CreateFormFile(name, name) + if err != nil { + t.Fatal(err) + } + w.Write([]byte(name)) + } + if err := fw.Close(); err != nil { + t.Fatal(err) + } + fr := NewReader(&buf, fw.Boundary()) + fr.tempDir = t.TempDir() + form, err := fr.ReadForm(0) + if err != nil { + t.Fatal(err) + } + for i := 0; i < numFiles; i++ { + name := fmt.Sprint(i) + if got := len(form.File[name]); got != 1 { + t.Fatalf("form.File[%q] has %v entries, want 1", name, got) + } + fh := form.File[name][0] + file, err := fh.Open() + if err != nil { + t.Fatalf("form.File[%q].Open() = %v", name, err) + } + if distinct { + if _, ok := file.(*os.File); !ok { + t.Fatalf("form.File[%q].Open: %T, want *os.File", name, file) + } + } + got, err := io.ReadAll(file) + file.Close() + if string(got) != name || err != nil { + t.Fatalf("read form.File[%q]: %q, %v; want %q, nil", name, string(got), err, name) + } + } + dir, err := os.Open(fr.tempDir) + if err != nil { + t.Fatal(err) + } + defer dir.Close() + names, err := dir.Readdirnames(0) + if err != nil { + t.Fatal(err) + } + wantNames := 1 + if distinct { + wantNames = numFiles + } + if len(names) != wantNames { + t.Fatalf("temp dir contains %v files; want 1", len(names)) + } + if err := form.RemoveAll(); err != nil { + t.Fatalf("form.RemoveAll() = %v", err) + } + names, err = dir.Readdirnames(0) + if err != nil { + t.Fatal(err) + } + if len(names) != 0 { + t.Fatalf("temp dir contains %v files; want 0", len(names)) + } +} diff --git a/src/mime/multipart/multipart.go b/src/mime/multipart/multipart.go index b3a904f0aff..86ea926346e 100644 --- a/src/mime/multipart/multipart.go +++ b/src/mime/multipart/multipart.go @@ -128,12 +128,12 @@ func (r *stickyErrorReader) Read(p []byte) (n int, _ error) { return n, r.err } -func newPart(mr *Reader, rawPart bool) (*Part, error) { +func newPart(mr *Reader, rawPart bool, maxMIMEHeaderSize int64) (*Part, error) { bp := &Part{ Header: make(map[string][]string), mr: mr, } - if err := bp.populateHeaders(); err != nil { + if err := bp.populateHeaders(maxMIMEHeaderSize); err != nil { return nil, err } bp.r = partReader{bp} @@ -149,12 +149,16 @@ func newPart(mr *Reader, rawPart bool) (*Part, error) { return bp, nil } -func (p *Part) populateHeaders() error { +func (p *Part) populateHeaders(maxMIMEHeaderSize int64) error { r := textproto.NewReader(p.mr.bufReader) - header, err := r.ReadMIMEHeader() + header, err := readMIMEHeader(r, maxMIMEHeaderSize) if err == nil { p.Header = header } + // TODO: Add a distinguishable error to net/textproto. + if err != nil && err.Error() == "message too large" { + err = ErrMessageTooLarge + } return err } @@ -311,6 +315,7 @@ func (p *Part) Close() error { // isn't supported. type Reader struct { bufReader *bufio.Reader + tempDir string // used in tests currentPart *Part partsRead int @@ -321,6 +326,10 @@ type Reader struct { dashBoundary []byte // "--boundary" } +// maxMIMEHeaderSize is the maximum size of a MIME header we will parse, +// including header keys, values, and map overhead. +const maxMIMEHeaderSize = 10 << 20 + // NextPart returns the next part in the multipart or an error. // When there are no more parts, the error io.EOF is returned. // @@ -328,7 +337,7 @@ type Reader struct { // has a value of "quoted-printable", that header is instead // hidden and the body is transparently decoded during Read calls. func (r *Reader) NextPart() (*Part, error) { - return r.nextPart(false) + return r.nextPart(false, maxMIMEHeaderSize) } // NextRawPart returns the next part in the multipart or an error. @@ -337,10 +346,10 @@ func (r *Reader) NextPart() (*Part, error) { // Unlike NextPart, it does not have special handling for // "Content-Transfer-Encoding: quoted-printable". func (r *Reader) NextRawPart() (*Part, error) { - return r.nextPart(true) + return r.nextPart(true, maxMIMEHeaderSize) } -func (r *Reader) nextPart(rawPart bool) (*Part, error) { +func (r *Reader) nextPart(rawPart bool, maxMIMEHeaderSize int64) (*Part, error) { if r.currentPart != nil { r.currentPart.Close() } @@ -365,7 +374,7 @@ func (r *Reader) nextPart(rawPart bool) (*Part, error) { if r.isBoundaryDelimiterLine(line) { r.partsRead++ - bp, err := newPart(r, rawPart) + bp, err := newPart(r, rawPart, maxMIMEHeaderSize) if err != nil { return nil, err } diff --git a/src/mime/multipart/readmimeheader.go b/src/mime/multipart/readmimeheader.go new file mode 100644 index 00000000000..6836928c9e8 --- /dev/null +++ b/src/mime/multipart/readmimeheader.go @@ -0,0 +1,14 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package multipart + +import ( + "net/textproto" + _ "unsafe" // for go:linkname +) + +// readMIMEHeader is defined in package net/textproto. +// +//go:linkname readMIMEHeader net/textproto.readMIMEHeader +func readMIMEHeader(r *textproto.Reader, lim int64) (textproto.MIMEHeader, error) diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go index 686a8699fb0..23e49d6b8e3 100644 --- a/src/net/http/request_test.go +++ b/src/net/http/request_test.go @@ -1097,7 +1097,7 @@ func testMissingFile(t *testing.T, req *Request) { t.Errorf("FormFile file = %v, want nil", f) } if fh != nil { - t.Errorf("FormFile file header = %q, want nil", fh) + t.Errorf("FormFile file header = %v, want nil", fh) } if err != ErrMissingFile { t.Errorf("FormFile err = %q, want ErrMissingFile", err) diff --git a/src/net/textproto/reader.go b/src/net/textproto/reader.go index 4e4999b3c95..8e800088c1f 100644 --- a/src/net/textproto/reader.go +++ b/src/net/textproto/reader.go @@ -7,8 +7,10 @@ package textproto import ( "bufio" "bytes" + "errors" "fmt" "io" + "math" "strconv" "strings" "sync" @@ -477,6 +479,12 @@ var colon = []byte(":") // "Long-Key": {"Even Longer Value"}, // } func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { + return readMIMEHeader(r, math.MaxInt64) +} + +// readMIMEHeader is a version of ReadMIMEHeader which takes a limit on the header size. +// It is called by the mime/multipart package. +func readMIMEHeader(r *Reader, lim int64) (MIMEHeader, error) { // Avoid lots of small slice allocations later by allocating one // large one ahead of time which we'll cut up into smaller // slices. If this isn't big enough later, we allocate small ones. @@ -526,9 +534,19 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { } // Skip initial spaces in value. - value := strings.TrimLeft(string(v), " \t") + value := string(bytes.TrimLeft(v, " \t")) vv := m[key] + if vv == nil { + lim -= int64(len(key)) + lim -= 100 // map entry overhead + } + lim -= int64(len(value)) + if lim < 0 { + // TODO: This should be a distinguishable error (ErrMessageTooLarge) + // to allow mime/multipart to detect it. + return m, errors.New("message too large") + } if vv == nil && len(strs) > 0 { // More than likely this will be a single-element key. // Most headers aren't multi-valued.