diff --git a/auth.go b/auth.go index 55814ee..2ef234c 100644 --- a/auth.go +++ b/auth.go @@ -6,14 +6,25 @@ import ( ) func Authorize(r *http.Request) error { + + htpasswd, err := NewHtpasswdFromFile("/tmp/restic/.htpasswd") + if err != nil { + return errors.New("internal server error") + } + username, password, ok := r.BasicAuth() if !ok { return errors.New("malformed basic auth credentials") } - if username != "user" || password != "pass" { + if !htpasswd.Validate(username, password) { return errors.New("unknown user") } + repo, err := RepositoryName(r.RequestURI) + if err != nil || repo != username { + return errors.New("wrong repository") + } + return nil } diff --git a/context.go b/context.go index de87f9e..9864a29 100644 --- a/context.go +++ b/context.go @@ -20,7 +20,16 @@ func NewContext(path string) Context { return Context{path} } -func (c *Context) Repository(name string) (Repository, error) { +func (c *Context) Repository(name string) (*Repository, error) { name, err := ParseRepositoryName(name) - return Repository{filepath.Join(c.path, name)}, err + if err != nil { + return nil, err + } + + repo, err := NewRepository(filepath.Join(c.path, name)) + if err != nil { + return nil, err + } + + return repo, nil } diff --git a/handlers.go b/handlers.go index 6ec1909..a1bbdac 100644 --- a/handlers.go +++ b/handlers.go @@ -11,13 +11,13 @@ type Handler func(w http.ResponseWriter, r *http.Request, c *Context) func HeadConfig(w http.ResponseWriter, r *http.Request, c *Context) { uri := r.RequestURI - name, errrn := RepositoryName(uri) - if errrn != nil { + name, err := RepositoryName(uri) + if err != nil { http.NotFound(w, r) return } - repo, errr := c.Repository(name) - if errr != nil { + repo, err := c.Repository(name) + if err != nil { http.NotFound(w, r) return } @@ -29,18 +29,18 @@ func HeadConfig(w http.ResponseWriter, r *http.Request, c *Context) { func GetConfig(w http.ResponseWriter, r *http.Request, c *Context) { uri := r.RequestURI - name, errrn := RepositoryName(uri) - if errrn != nil { + name, err := RepositoryName(uri) + if err != nil { http.NotFound(w, r) return } - repo, errr := c.Repository(name) - if errr != nil { + repo, err := c.Repository(name) + if err != nil { http.NotFound(w, r) return } - config, errrc := repo.ReadConfig() - if errrc != nil { + config, err := repo.ReadConfig() + if err != nil { http.NotFound(w, r) return } @@ -49,37 +49,37 @@ func GetConfig(w http.ResponseWriter, r *http.Request, c *Context) { func PostConfig(w http.ResponseWriter, r *http.Request, c *Context) { uri := r.RequestURI - name, errrn := RepositoryName(uri) - if errrn != nil { + name, err := RepositoryName(uri) + if err != nil { http.NotFound(w, r) return } - repo, errr := c.Repository(name) - if errr != nil { + repo, err := c.Repository(name) + if err != nil { http.NotFound(w, r) return } - config, errc := ioutil.ReadAll(r.Body) + config, err := ioutil.ReadAll(r.Body) + if err != nil { + http.NotFound(w, r) + return + } + errc := repo.WriteConfig(config) if errc != nil { http.NotFound(w, r) return } - errwc := repo.WriteConfig(config) - if errwc != nil { - http.NotFound(w, r) - return - } } func ListBlob(w http.ResponseWriter, r *http.Request, c *Context) { uri := r.RequestURI - name, errrn := RepositoryName(uri) - if errrn != nil { + name, err := RepositoryName(uri) + if err != nil { http.NotFound(w, r) return } - repo, errr := c.Repository(name) - if errr != nil { + repo, err := c.Repository(name) + if err != nil { http.NotFound(w, r) return } @@ -88,13 +88,13 @@ func ListBlob(w http.ResponseWriter, r *http.Request, c *Context) { http.NotFound(w, r) return } - blobs, errb := repo.ListBlob(bt) - if errb != nil { + blobs, err := repo.ListBlob(bt) + if err != nil { http.NotFound(w, r) return } - json, errj := json.Marshal(blobs) - if errj != nil { + json, err := json.Marshal(blobs) + if err != nil { http.NotFound(w, r) return } @@ -103,13 +103,13 @@ func ListBlob(w http.ResponseWriter, r *http.Request, c *Context) { func HeadBlob(w http.ResponseWriter, r *http.Request, c *Context) { uri := r.RequestURI - name, errrn := RepositoryName(uri) - if errrn != nil { + name, err := RepositoryName(uri) + if err != nil { http.NotFound(w, r) return } - repo, errr := c.Repository(name) - if errr != nil { + repo, err := c.Repository(name) + if err != nil { http.NotFound(w, r) return } @@ -131,13 +131,13 @@ func HeadBlob(w http.ResponseWriter, r *http.Request, c *Context) { func GetBlob(w http.ResponseWriter, r *http.Request, c *Context) { uri := r.RequestURI - name, errrn := RepositoryName(uri) - if errrn != nil { + name, err := RepositoryName(uri) + if err != nil { http.NotFound(w, r) return } - repo, errr := c.Repository(name) - if errr != nil { + repo, err := c.Repository(name) + if err != nil { http.NotFound(w, r) return } @@ -151,8 +151,8 @@ func GetBlob(w http.ResponseWriter, r *http.Request, c *Context) { http.NotFound(w, r) return } - blob, errb := repo.ReadBlob(bt, id) - if errb != nil { + blob, errr := repo.ReadBlob(bt, id) + if errr != nil { http.NotFound(w, r) return } @@ -161,13 +161,13 @@ func GetBlob(w http.ResponseWriter, r *http.Request, c *Context) { func PostBlob(w http.ResponseWriter, r *http.Request, c *Context) { uri := r.RequestURI - name, errrn := RepositoryName(uri) - if errrn != nil { + name, err := RepositoryName(uri) + if err != nil { http.NotFound(w, r) return } - repo, errr := c.Repository(name) - if errr != nil { + repo, err := c.Repository(name) + if err != nil { http.NotFound(w, r) return } @@ -181,13 +181,13 @@ func PostBlob(w http.ResponseWriter, r *http.Request, c *Context) { http.NotFound(w, r) return } - blob, errb := ioutil.ReadAll(r.Body) - if errb != nil { + blob, err := ioutil.ReadAll(r.Body) + if err != nil { http.NotFound(w, r) return } - errwb := repo.WriteBlob(bt, id, blob) - if errwb != nil { + errw := repo.WriteBlob(bt, id, blob) + if errw != nil { http.NotFound(w, r) return } @@ -196,13 +196,13 @@ func PostBlob(w http.ResponseWriter, r *http.Request, c *Context) { func DeleteBlob(w http.ResponseWriter, r *http.Request, c *Context) { uri := r.RequestURI - name, errrn := RepositoryName(uri) - if errrn != nil { + name, err := RepositoryName(uri) + if err != nil { http.NotFound(w, r) return } - repo, errr := c.Repository(name) - if errr != nil { + repo, err := c.Repository(name) + if err != nil { http.NotFound(w, r) return } diff --git a/repository.go b/repository.go index f050e64..379f7ed 100644 --- a/repository.go +++ b/repository.go @@ -14,25 +14,25 @@ type Repository struct { path string } -// Creates the file structure of the Repository -func (r *Repository) Init() error { +func NewRepository(path string) (*Repository, error) { dirs := []string{ - r.path, - filepath.Join(r.path, string(backend.Data)), - filepath.Join(r.path, string(backend.Snapshot)), - filepath.Join(r.path, string(backend.Index)), - filepath.Join(r.path, string(backend.Lock)), - filepath.Join(r.path, string(backend.Key)), + path, + filepath.Join(path, string(backend.Data)), + filepath.Join(path, string(backend.Snapshot)), + filepath.Join(path, string(backend.Index)), + filepath.Join(path, string(backend.Lock)), + filepath.Join(path, string(backend.Key)), } for _, d := range dirs { - if _, errs := os.Stat(d); errs != nil { - errmk := os.MkdirAll(d, backend.Modes.Dir) - if errmk != nil { - return errmk + _, err := os.Stat(d) + if err != nil { + err := os.MkdirAll(d, backend.Modes.Dir) + if err != nil { + return nil, err } } } - return nil + return &Repository{path}, nil } func (r *Repository) HasConfig() bool { diff --git a/router.go b/router.go index 367f22a..4a90365 100644 --- a/router.go +++ b/router.go @@ -12,6 +12,10 @@ type Router struct { Context } +func NewRouter(context Context) *Router { + return &Router{context} +} + func (router Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { m := r.Method u := r.RequestURI diff --git a/server.go b/server.go index fd53066..fc5d621 100644 --- a/server.go +++ b/server.go @@ -1,19 +1,31 @@ package main import ( - //"io/ioutil" + "flag" "log" "net/http" + "path/filepath" +) + +const ( + HTTP = ":8000" + HTTPS = ":8443" ) func main() { - context := NewContext("/tmp/restic") - - repo, _ := context.Repository("user") - repo.Init() + var path = flag.String("path", "/tmp/restic", "specifies the path of the data directory") + var tls = flag.Bool("tls", false, "turns on tls support") + flag.Parse() + context := NewContext(*path) router := Router{context} - port := ":8000" - log.Printf("start server on port %s", port) - http.ListenAndServe(port, router) + if !*tls { + log.Printf("start server on port %s", HTTP) + http.ListenAndServe(HTTP, router) + } else { + log.Printf("start server on port %s", HTTPS) + privateKey := filepath.Join(*path, "private_key") + publicKey := filepath.Join(*path, "public_key") + http.ListenAndServeTLS(HTTPS, privateKey, publicKey, router) + } }