From 2cc8e7506efa0e6389656a194c72e1b5e763a0d0 Mon Sep 17 00:00:00 2001 From: Ascii Moth Date: Wed, 13 May 2026 00:40:59 +0400 Subject: [PATCH] Allow WebSocket listeners to configure browser origins (#1342) Adds an `origin` query option for `ws://` listener URLs so peer operators can explicitly allow browser-hosted WebSocket clients. - `ws://host:port` keeps the existing same-origin behavior - `ws://host:port?origin=demo.example.org` allows that origin host - `ws://host:port?origin=https://demo.example.org` allows that scheme and host - repeated `origin=` parameters allow multiple origin patterns - `origin=*` intentionally disables origin verification for public WebSocket peer endpoints ## Problem I've implemented a WASM based browser demo yggdrasil node to found that it cannot directly dial any existing public `ws://` or `wss://` peers. Browsers always include an `Origin` header in WebSocket handshakes, and the JavaScript `WebSocket()` constructor does not allow applications to override or remove arbitrary handshake headers. This means a browser demo served from an origin such as `http://127.0.0.1:8000` cannot connect to a public peer whose WebSocket server only accepts same-origin handshakes. --- src/core/link_ws.go | 33 +++++++-- src/core/link_ws_test.go | 152 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 178 insertions(+), 7 deletions(-) create mode 100644 src/core/link_ws_test.go diff --git a/src/core/link_ws.go b/src/core/link_ws.go index b4638d5c..2194a10c 100644 --- a/src/core/link_ws.go +++ b/src/core/link_ws.go @@ -30,8 +30,9 @@ type linkWSListener struct { } type wsServer struct { - ch chan *linkWSConn - ctx context.Context + ch chan *linkWSConn + ctx context.Context + acceptOptions *websocket.AcceptOptions } func (l *linkWSListener) Accept() (net.Conn, error) { @@ -60,9 +61,7 @@ func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"ygg-ws"}, - }) + c, err := websocket.Accept(w, r, s.acceptOptions) if err != nil { return } @@ -87,6 +86,25 @@ func (l *links) newLinkWS() *linkWS { return lt } +func wsAcceptOptions(url *url.URL) *websocket.AcceptOptions { + opts := &websocket.AcceptOptions{ + Subprotocols: []string{"ygg-ws"}, + } + for _, origin := range url.Query()["origin"] { + switch origin { + case "": + continue + case "*": + opts.InsecureSkipVerify = true + opts.OriginPatterns = nil + return opts + default: + opts.OriginPatterns = append(opts.OriginPatterns, origin) + } + } + return opts +} + func (l *linkWS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { return l.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) { u := *url @@ -129,8 +147,9 @@ func (l *linkWS) listen(ctx context.Context, url *url.URL, _ string) (net.Listen httpServer := &http.Server{ Handler: &wsServer{ - ch: ch, - ctx: ctx, + ch: ch, + ctx: ctx, + acceptOptions: wsAcceptOptions(url), }, BaseContext: func(_ net.Listener) context.Context { return ctx }, ReadTimeout: time.Second * 10, diff --git a/src/core/link_ws_test.go b/src/core/link_ws_test.go new file mode 100644 index 00000000..b5dd9300 --- /dev/null +++ b/src/core/link_ws_test.go @@ -0,0 +1,152 @@ +package core + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/coder/websocket" +) + +func TestWSAcceptOptionsOriginQuery(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + rawurl string + insecureSkipVerify bool + originPatterns []string + }{ + { + name: "default same origin policy", + rawurl: "ws://0.0.0.0:9001", + }, + { + name: "host origin pattern", + rawurl: "ws://0.0.0.0:9001?origin=demo.example.org", + originPatterns: []string{"demo.example.org"}, + }, + { + name: "scheme origin pattern", + rawurl: "ws://0.0.0.0:9001?origin=https%3A%2F%2Fdemo.example.org", + originPatterns: []string{"https://demo.example.org"}, + }, + { + name: "multiple origin patterns", + rawurl: "ws://0.0.0.0:9001?origin=demo.example.org&origin=https%3A%2F%2Fdemo2.example.org", + originPatterns: []string{"demo.example.org", "https://demo2.example.org"}, + }, + { + name: "wildcard disables verification", + rawurl: "ws://0.0.0.0:9001?origin=*", + insecureSkipVerify: true, + }, + { + name: "wildcard overrides other patterns", + rawurl: "ws://0.0.0.0:9001?origin=demo.example.org&origin=*", + insecureSkipVerify: true, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + u, err := url.Parse(tc.rawurl) + if err != nil { + t.Fatal(err) + } + + opts := wsAcceptOptions(u) + if got := opts.InsecureSkipVerify; got != tc.insecureSkipVerify { + t.Fatalf("InsecureSkipVerify = %v, want %v", got, tc.insecureSkipVerify) + } + if strings.Join(opts.OriginPatterns, ",") != strings.Join(tc.originPatterns, ",") { + t.Fatalf("OriginPatterns = %#v, want %#v", opts.OriginPatterns, tc.originPatterns) + } + if strings.Join(opts.Subprotocols, ",") != "ygg-ws" { + t.Fatalf("Subprotocols = %#v, want [ygg-ws]", opts.Subprotocols) + } + }) + } +} + +func TestWSServerOriginPolicy(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + rawurl string + origin string + success bool + }{ + { + name: "default rejects cross origin", + rawurl: "ws://127.0.0.1:0", + origin: "https://demo.example.org", + success: false, + }, + { + name: "configured origin accepts cross origin", + rawurl: "ws://127.0.0.1:0?origin=demo.example.org", + origin: "https://demo.example.org", + success: true, + }, + { + name: "wildcard accepts cross origin", + rawurl: "ws://127.0.0.1:0?origin=*", + origin: "https://unexpected.example.org", + success: true, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + u, err := url.Parse(tc.rawurl) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch := make(chan *linkWSConn, 1) + server := httptest.NewServer(&wsServer{ + ch: ch, + ctx: ctx, + acceptOptions: wsAcceptOptions(u), + }) + defer server.Close() + + dialURL := "ws" + strings.TrimPrefix(server.URL, "http") + c, resp, err := websocket.Dial(ctx, dialURL, &websocket.DialOptions{ + HTTPHeader: http.Header{ + "Origin": []string{tc.origin}, + }, + Subprotocols: []string{"ygg-ws"}, + }) + if err != nil && resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + if tc.success { + if err != nil { + t.Fatalf("websocket dial failed: %v", err) + } + _ = c.Close(websocket.StatusNormalClosure, "") + select { + case conn := <-ch: + _ = conn.Close() + case <-time.After(time.Second): + t.Fatal("timed out waiting for accepted connection") + } + } else if err == nil { + _ = c.Close(websocket.StatusNormalClosure, "") + t.Fatal("websocket dial succeeded, want origin rejection") + } + }) + } +}