Allow WebSocket listeners to configure browser origins (#1342)

Adds an `origin` query option for `ws://` listener URLs so peer
operators can
explicitly allow browser-hosted WebSocket clients.

- `ws://host:port` keeps the existing same-origin behavior
- `ws://host:port?origin=demo.example.org` allows that origin host
- `ws://host:port?origin=https://demo.example.org` allows that scheme
and host
- repeated `origin=` parameters allow multiple origin patterns
- `origin=*` intentionally disables origin verification for public
WebSocket
  peer endpoints

## Problem
I've implemented a WASM based browser demo yggdrasil node to found that
it
cannot directly dial any existing public `ws://` or `wss://` peers.

Browsers always include an `Origin` header in WebSocket handshakes, and
the
JavaScript `WebSocket()` constructor does not allow applications to
override or
remove arbitrary handshake headers.  
This means a browser demo served from an origin such as
`http://127.0.0.1:8000` cannot connect to a public peer whose WebSocket
server
only accepts same-origin handshakes.
This commit is contained in:
Ascii Moth
2026-05-13 00:40:59 +04:00
committed by GitHub
parent aaf263957b
commit 2cc8e7506e
2 changed files with 178 additions and 7 deletions

View File

@@ -30,8 +30,9 @@ type linkWSListener struct {
}
type wsServer struct {
ch chan *linkWSConn
ctx context.Context
ch chan *linkWSConn
ctx context.Context
acceptOptions *websocket.AcceptOptions
}
func (l *linkWSListener) Accept() (net.Conn, error) {
@@ -60,9 +61,7 @@ func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
Subprotocols: []string{"ygg-ws"},
})
c, err := websocket.Accept(w, r, s.acceptOptions)
if err != nil {
return
}
@@ -87,6 +86,25 @@ func (l *links) newLinkWS() *linkWS {
return lt
}
func wsAcceptOptions(url *url.URL) *websocket.AcceptOptions {
opts := &websocket.AcceptOptions{
Subprotocols: []string{"ygg-ws"},
}
for _, origin := range url.Query()["origin"] {
switch origin {
case "":
continue
case "*":
opts.InsecureSkipVerify = true
opts.OriginPatterns = nil
return opts
default:
opts.OriginPatterns = append(opts.OriginPatterns, origin)
}
}
return opts
}
func (l *linkWS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
return l.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) {
u := *url
@@ -129,8 +147,9 @@ func (l *linkWS) listen(ctx context.Context, url *url.URL, _ string) (net.Listen
httpServer := &http.Server{
Handler: &wsServer{
ch: ch,
ctx: ctx,
ch: ch,
ctx: ctx,
acceptOptions: wsAcceptOptions(url),
},
BaseContext: func(_ net.Listener) context.Context { return ctx },
ReadTimeout: time.Second * 10,

152
src/core/link_ws_test.go Normal file
View File

@@ -0,0 +1,152 @@
package core
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/coder/websocket"
)
func TestWSAcceptOptionsOriginQuery(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
rawurl string
insecureSkipVerify bool
originPatterns []string
}{
{
name: "default same origin policy",
rawurl: "ws://0.0.0.0:9001",
},
{
name: "host origin pattern",
rawurl: "ws://0.0.0.0:9001?origin=demo.example.org",
originPatterns: []string{"demo.example.org"},
},
{
name: "scheme origin pattern",
rawurl: "ws://0.0.0.0:9001?origin=https%3A%2F%2Fdemo.example.org",
originPatterns: []string{"https://demo.example.org"},
},
{
name: "multiple origin patterns",
rawurl: "ws://0.0.0.0:9001?origin=demo.example.org&origin=https%3A%2F%2Fdemo2.example.org",
originPatterns: []string{"demo.example.org", "https://demo2.example.org"},
},
{
name: "wildcard disables verification",
rawurl: "ws://0.0.0.0:9001?origin=*",
insecureSkipVerify: true,
},
{
name: "wildcard overrides other patterns",
rawurl: "ws://0.0.0.0:9001?origin=demo.example.org&origin=*",
insecureSkipVerify: true,
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
u, err := url.Parse(tc.rawurl)
if err != nil {
t.Fatal(err)
}
opts := wsAcceptOptions(u)
if got := opts.InsecureSkipVerify; got != tc.insecureSkipVerify {
t.Fatalf("InsecureSkipVerify = %v, want %v", got, tc.insecureSkipVerify)
}
if strings.Join(opts.OriginPatterns, ",") != strings.Join(tc.originPatterns, ",") {
t.Fatalf("OriginPatterns = %#v, want %#v", opts.OriginPatterns, tc.originPatterns)
}
if strings.Join(opts.Subprotocols, ",") != "ygg-ws" {
t.Fatalf("Subprotocols = %#v, want [ygg-ws]", opts.Subprotocols)
}
})
}
}
func TestWSServerOriginPolicy(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
rawurl string
origin string
success bool
}{
{
name: "default rejects cross origin",
rawurl: "ws://127.0.0.1:0",
origin: "https://demo.example.org",
success: false,
},
{
name: "configured origin accepts cross origin",
rawurl: "ws://127.0.0.1:0?origin=demo.example.org",
origin: "https://demo.example.org",
success: true,
},
{
name: "wildcard accepts cross origin",
rawurl: "ws://127.0.0.1:0?origin=*",
origin: "https://unexpected.example.org",
success: true,
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
u, err := url.Parse(tc.rawurl)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch := make(chan *linkWSConn, 1)
server := httptest.NewServer(&wsServer{
ch: ch,
ctx: ctx,
acceptOptions: wsAcceptOptions(u),
})
defer server.Close()
dialURL := "ws" + strings.TrimPrefix(server.URL, "http")
c, resp, err := websocket.Dial(ctx, dialURL, &websocket.DialOptions{
HTTPHeader: http.Header{
"Origin": []string{tc.origin},
},
Subprotocols: []string{"ygg-ws"},
})
if err != nil && resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
if tc.success {
if err != nil {
t.Fatalf("websocket dial failed: %v", err)
}
_ = c.Close(websocket.StatusNormalClosure, "")
select {
case conn := <-ch:
_ = conn.Close()
case <-time.After(time.Second):
t.Fatal("timed out waiting for accepted connection")
}
} else if err == nil {
_ = c.Close(websocket.StatusNormalClosure, "")
t.Fatal("websocket dial succeeded, want origin rejection")
}
})
}
}