diff --git a/cmd/rest-server/main.go b/cmd/rest-server/main.go index 708c228..93089fb 100644 --- a/cmd/rest-server/main.go +++ b/cmd/rest-server/main.go @@ -41,6 +41,7 @@ func init() { flags.BoolVar(&server.Debug, "debug", server.Debug, "output debug messages") flags.StringVar(&server.Listen, "listen", server.Listen, "listen address") flags.StringVar(&server.Log, "log", server.Log, "log HTTP requests in the combined log format") + flags.Int64Var(&server.MaxRepoSize, "max-size", server.MaxRepoSize, "the maximum size of the repository in bytes") flags.StringVar(&server.Path, "path", server.Path, "data directory") flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support") flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path") diff --git a/handlers.go b/handlers.go index ec66e2f..d0c665f 100644 --- a/handlers.go +++ b/handlers.go @@ -34,9 +34,12 @@ type Server struct { PrivateRepos bool Prometheus bool Debug bool + MaxRepoSize int64 + + repoSize int64 // must be accessed using sync/atomic } -func (s Server) isHashed(dir string) bool { +func (s *Server) isHashed(dir string) bool { return dir == "data" } @@ -55,7 +58,7 @@ func valid(name string) bool { var validTypes = []string{"data", "index", "keys", "locks", "snapshots", "config"} -func (s Server) isValidType(name string) bool { +func (s *Server) isValidType(name string) bool { for _, tpe := range validTypes { if name == tpe { return true @@ -86,7 +89,7 @@ func join(base string, names ...string) (string, error) { } // getRepo returns the repository location, relative to s.Path. -func (s Server) getRepo(r *http.Request) string { +func (s *Server) getRepo(r *http.Request) string { if strings.HasPrefix(fmt.Sprintf("%s", middleware.Pattern(r.Context())), "/:repo") { return pat.Param(r, "repo") } @@ -95,7 +98,7 @@ func (s Server) getRepo(r *http.Request) string { } // getPath returns the path for a file type in the repo. 
-func (s Server) getPath(r *http.Request, fileType string) (string, error) { +func (s *Server) getPath(r *http.Request, fileType string) (string, error) { if !s.isValidType(fileType) { return "", errors.New("invalid file type") } @@ -103,7 +106,7 @@ func (s Server) getPath(r *http.Request, fileType string) (string, error) { } // getFilePath returns the path for a file in the repo. -func (s Server) getFilePath(r *http.Request, fileType, name string) (string, error) { +func (s *Server) getFilePath(r *http.Request, fileType, name string) (string, error) { if !s.isValidType(fileType) { return "", errors.New("invalid file type") } @@ -120,7 +123,7 @@ func (s Server) getFilePath(r *http.Request, fileType, name string) (string, err } // getUser returns the username from the request, or an empty string if none. -func (s Server) getUser(r *http.Request) string { +func (s *Server) getUser(r *http.Request) string { username, _, ok := r.BasicAuth() if !ok { return "" @@ -129,7 +132,7 @@ func (s Server) getUser(r *http.Request) string { } // getMetricLabels returns the prometheus labels from the request. -func (s Server) getMetricLabels(r *http.Request) prometheus.Labels { +func (s *Server) getMetricLabels(r *http.Request) prometheus.Labels { labels := prometheus.Labels{ "user": s.getUser(r), "repo": s.getRepo(r), @@ -150,7 +153,7 @@ func isUserPath(username, path string) bool { // AuthHandler wraps h with a http.HandlerFunc that performs basic authentication against the user/passwords pairs // stored in f and returns the http.HandlerFunc. 
-func (s Server) AuthHandler(f *HtpasswdFile, h http.Handler) http.HandlerFunc { +func (s *Server) AuthHandler(f *HtpasswdFile, h http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { username, password, ok := r.BasicAuth() if !ok || !f.Validate(username, password) { @@ -166,7 +169,7 @@ func (s Server) AuthHandler(f *HtpasswdFile, h http.Handler) http.HandlerFunc { } // CheckConfig checks whether a configuration exists. -func (s Server) CheckConfig(w http.ResponseWriter, r *http.Request) { +func (s *Server) CheckConfig(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("CheckConfig()") } @@ -189,7 +192,7 @@ func (s Server) CheckConfig(w http.ResponseWriter, r *http.Request) { } // GetConfig allows for a config to be retrieved. -func (s Server) GetConfig(w http.ResponseWriter, r *http.Request) { +func (s *Server) GetConfig(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("GetConfig()") } @@ -212,7 +215,7 @@ func (s Server) GetConfig(w http.ResponseWriter, r *http.Request) { } // SaveConfig allows for a config to be saved. -func (s Server) SaveConfig(w http.ResponseWriter, r *http.Request) { +func (s *Server) SaveConfig(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("SaveConfig()") } @@ -250,7 +253,7 @@ func (s Server) SaveConfig(w http.ResponseWriter, r *http.Request) { } // DeleteConfig removes a config. -func (s Server) DeleteConfig(w http.ResponseWriter, r *http.Request) { +func (s *Server) DeleteConfig(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("DeleteConfig()") } @@ -285,7 +288,7 @@ const ( ) // ListBlobs lists all blobs of a given type in an arbitrary order. 
-func (s Server) ListBlobs(w http.ResponseWriter, r *http.Request) { +func (s *Server) ListBlobs(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("ListBlobs()") } @@ -299,7 +302,7 @@ func (s Server) ListBlobs(w http.ResponseWriter, r *http.Request) { } // ListBlobsV1 lists all blobs of a given type in an arbitrary order. -func (s Server) ListBlobsV1(w http.ResponseWriter, r *http.Request) { +func (s *Server) ListBlobsV1(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("ListBlobsV1()") } @@ -360,7 +363,7 @@ type Blob struct { } // ListBlobsV2 lists all blobs of a given type, together with their sizes, in an arbitrary order. -func (s Server) ListBlobsV2(w http.ResponseWriter, r *http.Request) { +func (s *Server) ListBlobsV2(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("ListBlobsV2()") } @@ -415,7 +418,7 @@ func (s Server) ListBlobsV2(w http.ResponseWriter, r *http.Request) { } // CheckBlob tests whether a blob exists. -func (s Server) CheckBlob(w http.ResponseWriter, r *http.Request) { +func (s *Server) CheckBlob(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("CheckBlob()") } @@ -439,7 +442,7 @@ func (s Server) CheckBlob(w http.ResponseWriter, r *http.Request) { } // GetBlob retrieves a blob from the repository. -func (s Server) GetBlob(w http.ResponseWriter, r *http.Request) { +func (s *Server) GetBlob(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("GetBlob()") } @@ -474,8 +477,24 @@ func (s Server) GetBlob(w http.ResponseWriter, r *http.Request) { } } +// tallySize counts the size of the contents of path. +func tallySize(path string) (int64, error) { + if path == "" { + path = "." + } + var size int64 + err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + size += info.Size() + return nil + }) + return size, err +} + // SaveBlob saves a blob to the repository. 
-func (s Server) SaveBlob(w http.ResponseWriter, r *http.Request) { +func (s *Server) SaveBlob(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("SaveBlob()") } @@ -497,12 +516,10 @@ func (s Server) SaveBlob(w http.ResponseWriter, r *http.Request) { tf, err = os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) } } - if os.IsExist(err) { http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } - if err != nil { if s.Debug { log.Print(err) @@ -511,10 +528,29 @@ func (s Server) SaveBlob(w http.ResponseWriter, r *http.Request) { return } - written, err := io.Copy(tf, r.Body) + // ensure this blob does not put us over the repo size limit (if there is one) + var outFile io.Writer = tf + if s.MaxRepoSize != 0 { + var errCode int + outFile, errCode, err = s.maxSizeWriter(r, tf) + if err != nil { + if s.Debug { + log.Println(err) + } + if errCode > 0 { + http.Error(w, http.StatusText(errCode), errCode) + } + return + } + } + + written, err := io.Copy(outFile, r.Body) if err != nil { _ = tf.Close() _ = os.Remove(path) + if s.MaxRepoSize > 0 { + s.incrementRepoSpaceUsage(-written) + } if s.Debug { log.Print(err) } @@ -525,6 +561,9 @@ func (s Server) SaveBlob(w http.ResponseWriter, r *http.Request) { if err := tf.Sync(); err != nil { _ = tf.Close() _ = os.Remove(path) + if s.MaxRepoSize > 0 { + s.incrementRepoSpaceUsage(-written) + } if s.Debug { log.Print(err) } @@ -534,6 +573,9 @@ func (s Server) SaveBlob(w http.ResponseWriter, r *http.Request) { if err := tf.Close(); err != nil { _ = os.Remove(path) + if s.MaxRepoSize > 0 { + s.incrementRepoSpaceUsage(-written) + } if s.Debug { log.Print(err) } @@ -549,7 +591,7 @@ func (s Server) SaveBlob(w http.ResponseWriter, r *http.Request) { } // DeleteBlob deletes a blob from the repository. 
-func (s Server) DeleteBlob(w http.ResponseWriter, r *http.Request) { +func (s *Server) DeleteBlob(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("DeleteBlob()") } @@ -566,9 +608,9 @@ func (s Server) DeleteBlob(w http.ResponseWriter, r *http.Request) { } var size int64 - if s.Prometheus { + if s.Prometheus || s.MaxRepoSize > 0 { stat, err := os.Stat(path) - if err != nil { + if err == nil { size = stat.Size() } } @@ -585,6 +627,9 @@ func (s Server) DeleteBlob(w http.ResponseWriter, r *http.Request) { return } + if s.MaxRepoSize > 0 { + s.incrementRepoSpaceUsage(-size) + } if s.Prometheus { labels := s.getMetricLabels(r) metricBlobDeleteTotal.With(labels).Inc() @@ -593,7 +638,7 @@ func (s Server) DeleteBlob(w http.ResponseWriter, r *http.Request) { } // CreateRepo creates repository directories. -func (s Server) CreateRepo(w http.ResponseWriter, r *http.Request) { +func (s *Server) CreateRepo(w http.ResponseWriter, r *http.Request) { if s.Debug { log.Println("CreateRepo()") } diff --git a/maxsizewriter.go b/maxsizewriter.go new file mode 100644 index 0000000..0258103 --- /dev/null +++ b/maxsizewriter.go @@ -0,0 +1,80 @@ +package restserver + +import ( + "fmt" + "io" + "net/http" + "strconv" + "sync/atomic" +) + +// maxSizeWriter limits the number of bytes written +// to the space that is currently available as given by +// the server's MaxRepoSize. This type is safe for use +// by multiple goroutines sharing the same *Server. +type maxSizeWriter struct { + io.Writer + server *Server +} + +func (w maxSizeWriter) Write(p []byte) (n int, err error) { + if int64(len(p)) > w.server.repoSpaceRemaining() { + return 0, fmt.Errorf("repository has reached maximum size (%d bytes)", w.server.MaxRepoSize) + } + n, err = w.Writer.Write(p) + w.server.incrementRepoSpaceUsage(int64(n)) + return n, err +} + +// maxSizeWriter wraps w in a writer that enforces s.MaxRepoSize. +// If there is an error, a status code and the error are returned. 
+func (s *Server) maxSizeWriter(req *http.Request, w io.Writer) (io.Writer, int, error) { + // if we haven't yet computed the size of the repo, do so now + currentSize := atomic.LoadInt64(&s.repoSize) + if currentSize == 0 { + initialSize, err := tallySize(s.Path) + if err != nil { + return nil, http.StatusInternalServerError, err + } + atomic.StoreInt64(&s.repoSize, initialSize) + currentSize = initialSize + } + + // if content-length is set and is trustworthy, we can save some time + // and issue a polite error if it declares a size that's too big; we + // expect the vast majority of clients to be honest, so this check + // can only help save time + if contentLenStr := req.Header.Get("Content-Length"); contentLenStr != "" { + contentLen, err := strconv.ParseInt(contentLenStr, 10, 64) + if err != nil { + return nil, http.StatusLengthRequired, err + } + if currentSize+contentLen > s.MaxRepoSize { + err := fmt.Errorf("incoming blob (%d bytes) would exceed maximum size of repository (%d bytes)", + contentLen, s.MaxRepoSize) + return nil, http.StatusRequestEntityTooLarge, err + } + } + + // since we can't always trust content-length, we will wrap the writer + // in a custom writer that enforces the size limit during writes + return maxSizeWriter{Writer: w, server: s}, 0, nil +} + +// repoSpaceRemaining returns how much space is available in the repo +// according to s.MaxRepoSize. s.repoSize must already be set. +// If there is no limit, -1 is returned. +func (s *Server) repoSpaceRemaining() int64 { + if s.MaxRepoSize == 0 { + return -1 + } + maxSize := s.MaxRepoSize + currentSize := atomic.LoadInt64(&s.repoSize) + return maxSize - currentSize +} + +// incrementRepoSpaceUsage increments the current repo size (which +// must already be initialized). 
+func (s *Server) incrementRepoSpaceUsage(by int64) { + atomic.AddInt64(&s.repoSize, by) +} diff --git a/mux.go b/mux.go index 51166c1..649abe6 100644 --- a/mux.go +++ b/mux.go @@ -12,7 +12,7 @@ import ( "goji.io/pat" ) -func (s Server) debugHandler(next http.Handler) http.Handler { +func (s *Server) debugHandler(next http.Handler) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { log.Printf("%s %s", r.Method, r.URL) @@ -20,7 +20,7 @@ func (s Server) debugHandler(next http.Handler) http.Handler { }) } -func (s Server) logHandler(next http.Handler) http.Handler { +func (s *Server) logHandler(next http.Handler) http.Handler { accessLog, err := os.OpenFile(s.Log, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { log.Fatalf("error: %v", err)