mirror of
https://github.com/yggdrasil-network/yggdrasil-go.git
synced 2026-05-20 21:06:30 +03:00
Allow WebSocket listeners to configure browser origins (#1342)
Adds an `origin` query option for `ws://` listener URLs so peer operators can explicitly allow browser-hosted WebSocket clients. - `ws://host:port` keeps the existing same-origin behavior - `ws://host:port?origin=demo.example.org` allows that origin host - `ws://host:port?origin=https://demo.example.org` allows that scheme and host - repeated `origin=` parameters allow multiple origin patterns - `origin=*` intentionally disables origin verification for public WebSocket peer endpoints ## Problem I've implemented a WASM based browser demo yggdrasil node to found that it cannot directly dial any existing public `ws://` or `wss://` peers. Browsers always include an `Origin` header in WebSocket handshakes, and the JavaScript `WebSocket()` constructor does not allow applications to override or remove arbitrary handshake headers. This means a browser demo served from an origin such as `http://127.0.0.1:8000` cannot connect to a public peer whose WebSocket server only accepts same-origin handshakes.
This commit is contained in:
@@ -30,8 +30,9 @@ type linkWSListener struct {
|
||||
}
|
||||
|
||||
type wsServer struct {
|
||||
ch chan *linkWSConn
|
||||
ctx context.Context
|
||||
ch chan *linkWSConn
|
||||
ctx context.Context
|
||||
acceptOptions *websocket.AcceptOptions
|
||||
}
|
||||
|
||||
func (l *linkWSListener) Accept() (net.Conn, error) {
|
||||
@@ -60,9 +61,7 @@ func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
Subprotocols: []string{"ygg-ws"},
|
||||
})
|
||||
c, err := websocket.Accept(w, r, s.acceptOptions)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -87,6 +86,25 @@ func (l *links) newLinkWS() *linkWS {
|
||||
return lt
|
||||
}
|
||||
|
||||
func wsAcceptOptions(url *url.URL) *websocket.AcceptOptions {
|
||||
opts := &websocket.AcceptOptions{
|
||||
Subprotocols: []string{"ygg-ws"},
|
||||
}
|
||||
for _, origin := range url.Query()["origin"] {
|
||||
switch origin {
|
||||
case "":
|
||||
continue
|
||||
case "*":
|
||||
opts.InsecureSkipVerify = true
|
||||
opts.OriginPatterns = nil
|
||||
return opts
|
||||
default:
|
||||
opts.OriginPatterns = append(opts.OriginPatterns, origin)
|
||||
}
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func (l *linkWS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
|
||||
return l.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) {
|
||||
u := *url
|
||||
@@ -129,8 +147,9 @@ func (l *linkWS) listen(ctx context.Context, url *url.URL, _ string) (net.Listen
|
||||
|
||||
httpServer := &http.Server{
|
||||
Handler: &wsServer{
|
||||
ch: ch,
|
||||
ctx: ctx,
|
||||
ch: ch,
|
||||
ctx: ctx,
|
||||
acceptOptions: wsAcceptOptions(url),
|
||||
},
|
||||
BaseContext: func(_ net.Listener) context.Context { return ctx },
|
||||
ReadTimeout: time.Second * 10,
|
||||
|
||||
152
src/core/link_ws_test.go
Normal file
152
src/core/link_ws_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
func TestWSAcceptOptionsOriginQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
rawurl string
|
||||
insecureSkipVerify bool
|
||||
originPatterns []string
|
||||
}{
|
||||
{
|
||||
name: "default same origin policy",
|
||||
rawurl: "ws://0.0.0.0:9001",
|
||||
},
|
||||
{
|
||||
name: "host origin pattern",
|
||||
rawurl: "ws://0.0.0.0:9001?origin=demo.example.org",
|
||||
originPatterns: []string{"demo.example.org"},
|
||||
},
|
||||
{
|
||||
name: "scheme origin pattern",
|
||||
rawurl: "ws://0.0.0.0:9001?origin=https%3A%2F%2Fdemo.example.org",
|
||||
originPatterns: []string{"https://demo.example.org"},
|
||||
},
|
||||
{
|
||||
name: "multiple origin patterns",
|
||||
rawurl: "ws://0.0.0.0:9001?origin=demo.example.org&origin=https%3A%2F%2Fdemo2.example.org",
|
||||
originPatterns: []string{"demo.example.org", "https://demo2.example.org"},
|
||||
},
|
||||
{
|
||||
name: "wildcard disables verification",
|
||||
rawurl: "ws://0.0.0.0:9001?origin=*",
|
||||
insecureSkipVerify: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard overrides other patterns",
|
||||
rawurl: "ws://0.0.0.0:9001?origin=demo.example.org&origin=*",
|
||||
insecureSkipVerify: true,
|
||||
},
|
||||
} {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
u, err := url.Parse(tc.rawurl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := wsAcceptOptions(u)
|
||||
if got := opts.InsecureSkipVerify; got != tc.insecureSkipVerify {
|
||||
t.Fatalf("InsecureSkipVerify = %v, want %v", got, tc.insecureSkipVerify)
|
||||
}
|
||||
if strings.Join(opts.OriginPatterns, ",") != strings.Join(tc.originPatterns, ",") {
|
||||
t.Fatalf("OriginPatterns = %#v, want %#v", opts.OriginPatterns, tc.originPatterns)
|
||||
}
|
||||
if strings.Join(opts.Subprotocols, ",") != "ygg-ws" {
|
||||
t.Fatalf("Subprotocols = %#v, want [ygg-ws]", opts.Subprotocols)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSServerOriginPolicy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
rawurl string
|
||||
origin string
|
||||
success bool
|
||||
}{
|
||||
{
|
||||
name: "default rejects cross origin",
|
||||
rawurl: "ws://127.0.0.1:0",
|
||||
origin: "https://demo.example.org",
|
||||
success: false,
|
||||
},
|
||||
{
|
||||
name: "configured origin accepts cross origin",
|
||||
rawurl: "ws://127.0.0.1:0?origin=demo.example.org",
|
||||
origin: "https://demo.example.org",
|
||||
success: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard accepts cross origin",
|
||||
rawurl: "ws://127.0.0.1:0?origin=*",
|
||||
origin: "https://unexpected.example.org",
|
||||
success: true,
|
||||
},
|
||||
} {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
u, err := url.Parse(tc.rawurl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ch := make(chan *linkWSConn, 1)
|
||||
server := httptest.NewServer(&wsServer{
|
||||
ch: ch,
|
||||
ctx: ctx,
|
||||
acceptOptions: wsAcceptOptions(u),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
dialURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
c, resp, err := websocket.Dial(ctx, dialURL, &websocket.DialOptions{
|
||||
HTTPHeader: http.Header{
|
||||
"Origin": []string{tc.origin},
|
||||
},
|
||||
Subprotocols: []string{"ygg-ws"},
|
||||
})
|
||||
if err != nil && resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
if tc.success {
|
||||
if err != nil {
|
||||
t.Fatalf("websocket dial failed: %v", err)
|
||||
}
|
||||
_ = c.Close(websocket.StatusNormalClosure, "")
|
||||
select {
|
||||
case conn := <-ch:
|
||||
_ = conn.Close()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for accepted connection")
|
||||
}
|
||||
} else if err == nil {
|
||||
_ = c.Close(websocket.StatusNormalClosure, "")
|
||||
t.Fatal("websocket dial succeeded, want origin rejection")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user