Enhancement: can now listen on a unix socket

This commit is contained in:
Adam Eijdenberg 2024-02-05 10:37:01 +11:00
parent 55f43b815c
commit e0674c6150
3 changed files with 102 additions and 3 deletions

View file

@ -0,0 +1,12 @@
Enhancement: can now listen on a unix socket
If `--listen unix:/tmp/foo` is passed, the server will listen on a unix socket. This is triggered by the prefix `unix:`.
This is useful in combination with remote port portforwarding to enable remote server to backup locally, e.g.
```bash
rest-server --listen unix:/tmp/foo &
ssh -R /tmp/foo:/tmp/foo user@host restic -r rest:http+unix:/tmp/foo:/repo backup
```
https://github.com/restic/rest-server/pull/272

View file

@ -7,6 +7,7 @@ import (
"fmt"
"log"
"net"
"strings"
"github.com/coreos/go-systemd/v22/activation"
)
@ -23,9 +24,20 @@ func findListener(addr string) (listener net.Listener, err error) {
switch len(listeners) {
case 0:
// no listeners found, listen manually
listener, err = net.Listen("tcp", addr)
if err != nil {
return nil, fmt.Errorf("listen on %v failed: %w", addr, err)
if strings.HasPrefix(addr, "unix:") { // if we want to listen on a unix socket
unixAddr, err := net.ResolveUnixAddr("unix", strings.TrimPrefix(addr, "unix:"))
if err != nil {
return nil, fmt.Errorf("unable to understand unix address %s: %w", addr, err)
}
listener, err = net.ListenUnix("unix", unixAddr)
if err != nil {
return nil, fmt.Errorf("listen on %v failed: %w", addr, err)
}
} else { // assume tcp
listener, err = net.Listen("tcp", addr)
if err != nil {
return nil, fmt.Errorf("listen on %v failed: %w", addr, err)
}
}
log.Printf("start server on %v", listener.Addr())

View file

@ -0,0 +1,75 @@
//go:build !windows
// +build !windows
package main
import (
"context"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"testing"
"time"
)
func TestUnixSocket(t *testing.T) {
td := t.TempDir()
// this is the socket we'll listen on and connect to
tempSocket := filepath.Join(td, "sock")
// create some content and parent dirs
if err := os.MkdirAll(filepath.Join(td, "data", "repo1"), 0700); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(td, "data", "repo1", "config"), []byte("foo"), 0700); err != nil {
t.Fatal(err)
}
// run the following twice, to test that the server will
// cleanup its socket file when quitting, which won't happen
// if it doesn't exit gracefully
for i := 0; i < 2; i++ {
err := testServerWithArgs([]string{
"--no-auth",
"--path", filepath.Join(td, "data"),
"--listen", fmt.Sprintf("unix:%s", tempSocket),
}, time.Second, func(ctx context.Context, _ *restServerApp) error {
// custom client that will talk HTTP to unix socket
client := http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", tempSocket)
},
},
}
for _, test := range []struct {
Path string
StatusCode int
}{
{"/repo1/", http.StatusMethodNotAllowed},
{"/repo1/config", http.StatusOK},
{"/repo2/config", http.StatusNotFound},
} {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://ignored"+test.Path, nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
resp.Body.Close()
if resp.StatusCode != test.StatusCode {
return fmt.Errorf("expected %d from server, instead got %d (path %s)", test.StatusCode, resp.StatusCode, test.Path)
}
}
return nil
})
if err != nil {
t.Fatal(err)
}
}
}