Files
yggdrasil-go/src/core/link_ws.go
Ascii Moth 2cc8e7506e Allow WebSocket listeners to configure browser origins (#1342)
Adds an `origin` query option for `ws://` listener URLs so that peer operators can explicitly allow browser-hosted WebSocket clients.

- `ws://host:port` keeps the existing same-origin behavior
- `ws://host:port?origin=demo.example.org` allows that origin host
- `ws://host:port?origin=https://demo.example.org` allows that scheme and host
- repeated `origin=` parameters allow multiple origin patterns
- `origin=*` intentionally disables origin verification for public WebSocket peer endpoints

## Problem
I implemented a WASM-based browser demo Yggdrasil node and found that it cannot directly dial any existing public `ws://` or `wss://` peers.

Browsers always include an `Origin` header in WebSocket handshakes, and the JavaScript `WebSocket()` constructor does not allow applications to override or remove arbitrary handshake headers.
This means that a browser demo served from an origin such as `http://127.0.0.1:8000` cannot connect to a public peer whose WebSocket server only accepts same-origin handshakes.
2026-05-12 21:40:59 +01:00

168 lines
3.4 KiB
Go

package core
import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"time"

	"github.com/Arceliar/phony"
	"github.com/coder/websocket"
)
type (
	// linkWS implements the ws:// link transport. It embeds phony.Inbox and
	// the shared *links value. dial calls l.findSuitableIP and l.tcp, which are
	// presumably promoted from *links (confirm). listen uses listenconfig
	// to open the TCP socket.
	linkWS struct {
		phony.Inbox
		*links
		listenconfig *net.ListenConfig
	}

	// linkWSConn is the net.Conn handed out for ws:// links. It embeds the
	// adapter produced by websocket.NetConn.
	linkWSConn struct {
		net.Conn
	}

	// linkWSListener is the net.Listener returned by linkWS.listen.
	// Accept receives connections from ch, which the HTTP handler feeds.
	// Close shuts down httpServer and closes listener. ctx is the context
	// the listener was created with.
	linkWSListener struct {
		ch         chan *linkWSConn
		ctx        context.Context
		httpServer *http.Server
		listener   net.Listener
	}

	// wsServer is the http.Handler behind a ws:// listener. It upgrades
	// requests using acceptOptions and forwards the resulting connections
	// on ch. ctx is passed to websocket.NetConn for each connection.
	wsServer struct {
		ch            chan *linkWSConn
		ctx           context.Context
		acceptOptions *websocket.AcceptOptions
	}
)
// Accept blocks until the HTTP handler delivers a WebSocket connection that
// negotiated the ygg-ws subprotocol, then returns it as a net.Conn.
//
// It returns context.Canceled if l.ch yields nil (a closed channel). Once the
// listener's context is done, it returns l.ctx.Err(). Nothing visible in this
// file closes l.ch, so without the context case Accept would block forever
// after shutdown and leak the caller's accept loop.
// NOTE(review): this relies on l.ctx being cancelled when the listener is
// retired; confirm against the caller in the links package.
func (l *linkWSListener) Accept() (net.Conn, error) {
	select {
	case conn := <-l.ch:
		if conn == nil {
			// Return a literal nil so the caller never sees a typed-nil net.Conn.
			return nil, context.Canceled
		}
		return conn, nil
	case <-l.ctx.Done():
		return nil, l.ctx.Err()
	}
}
// Addr reports the local address of the TCP socket this listener was
// opened on.
func (l *linkWSListener) Addr() net.Addr {
	inner := l.listener
	return inner.Addr()
}
// Close stops the HTTP server and closes the TCP listener behind it.
//
// http.Server.Shutdown closes the listeners it is serving on. The explicit
// listener.Close that follows therefore normally fails with net.ErrClosed.
// That outcome is expected, so it is not reported as an error. The explicit
// close is kept so the socket is released even if the Serve goroutine has not
// started yet, and it now runs even when Shutdown fails.
//
// Shutdown is bounded by l.ctx. A Shutdown error takes precedence over any
// other close error. Per net/http, Shutdown neither waits for nor closes
// hijacked (WebSocket) connections.
func (l *linkWSListener) Close() error {
	shutdownErr := l.httpServer.Shutdown(l.ctx)
	closeErr := l.listener.Close()
	if errors.Is(closeErr, net.ErrClosed) {
		closeErr = nil
	}
	if shutdownErr != nil {
		return shutdownErr
	}
	return closeErr
}
// ServeHTTP answers plain-HTTP probes on /health and /healthz with 200 "OK".
// It upgrades every other request to a WebSocket using s.acceptOptions.
//
// Upgraded connections that did not negotiate the "ygg-ws" subprotocol are
// closed with StatusPolicyViolation. Accepted connections are wrapped as a
// binary-message net.Conn and handed to the listener's Accept through s.ch.
// If s.ctx is done before anyone receives the connection, it is closed with
// StatusGoingAway rather than leaving this handler goroutine, and the
// upgraded socket, blocked on the send forever.
func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	switch r.URL.Path {
	case "/health", "/healthz":
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("OK")) // best-effort body for health probes
		return
	}

	c, err := websocket.Accept(w, r, s.acceptOptions)
	if err != nil {
		// websocket.Accept has already written an HTTP error response to w.
		return
	}
	if c.Subprotocol() != "ygg-ws" {
		_ = c.Close(websocket.StatusPolicyViolation, "client must speak the ygg-ws subprotocol")
		return
	}

	conn := &linkWSConn{
		Conn: websocket.NetConn(s.ctx, c, websocket.MessageBinary),
	}
	select {
	case s.ch <- conn:
	case <-s.ctx.Done():
		// No receiver will ever take this connection; release it.
		_ = c.Close(websocket.StatusGoingAway, "listener shutting down")
	}
}
// newLinkWS builds the ws:// transport bound to l. Its ListenConfig uses a
// negative KeepAlive, which per net.ListenConfig disables TCP keep-alive
// probes on accepted sockets.
func (l *links) newLinkWS() *linkWS {
	return &linkWS{
		links:        l,
		listenconfig: &net.ListenConfig{KeepAlive: -1},
	}
}
// wsAcceptOptions derives the WebSocket accept options for a ws:// listener
// from the "origin" query parameters of its listen URL:
//
//   - empty origin values are ignored;
//   - "origin=*" disables origin verification (InsecureSkipVerify). Any
//     patterns collected before it are discarded, and later values are not
//     examined;
//   - every other value is appended to OriginPatterns, in query order.
//
// With no usable origin values, OriginPatterns stays nil and the library's
// default origin check applies. The "ygg-ws" subprotocol is always offered.
func wsAcceptOptions(url *url.URL) *websocket.AcceptOptions {
	var patterns []string
	for _, candidate := range url.Query()["origin"] {
		if candidate == "*" {
			return &websocket.AcceptOptions{
				Subprotocols:       []string{"ygg-ws"},
				InsecureSkipVerify: true,
			}
		}
		if candidate != "" {
			patterns = append(patterns, candidate)
		}
	}
	return &websocket.AcceptOptions{
		Subprotocols:   []string{"ygg-ws"},
		OriginPatterns: patterns,
	}
}
// dial opens an outbound ws:// peering connection to url.
//
// l.findSuitableIP invokes the callback with a hostname, a candidate IP and a
// port. The callback rewrites the URL to target that IP:port directly and
// sends hostname as the HTTP Host header (websocket.DialOptions.Host). The TCP
// connection is made through the dialer returned by l.tcp.dialerFor for
// info.sintf (presumably a source-interface binding; confirm in dialerFor).
// HTTP proxies from the environment are honoured. The "ygg-ws" subprotocol is
// requested, and the resulting WebSocket is exposed as a net.Conn carrying
// binary messages. options is currently unused.
func (l *linkWS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
	return l.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) {
		target := *url
		target.Host = net.JoinHostPort(ip.String(), fmt.Sprintf("%d", port))

		dialer, err := l.tcp.dialerFor(&net.TCPAddr{IP: ip, Port: port}, info.sintf)
		if err != nil {
			return nil, err
		}

		wsconn, _, err := websocket.Dial(ctx, target.String(), &websocket.DialOptions{
			HTTPClient: &http.Client{
				Transport: &http.Transport{
					Proxy: http.ProxyFromEnvironment,
					// Only DialContext is set. http.Transport.Dial is deprecated
					// and is ignored whenever DialContext is present.
					DialContext: dialer.DialContext,
				},
			},
			Subprotocols: []string{"ygg-ws"},
			Host:         hostname,
		})
		if err != nil {
			return nil, err
		}
		return &linkWSConn{
			Conn: websocket.NetConn(ctx, wsconn, websocket.MessageBinary),
		}, nil
	})
}
// listen binds a TCP socket on url.Host and serves WebSocket upgrades on it.
//
// The origin policy comes from the listen URL's "origin" query parameters via
// wsAcceptOptions. Upgraded connections reach the caller through the returned
// listener's Accept. The HTTP server uses ctx as its base context and 10-second
// read and write timeouts. Serve's return error is deliberately discarded.
// The trailing string argument is unused by this transport.
func (l *linkWS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {
	tcpListener, err := l.listenconfig.Listen(ctx, "tcp", url.Host)
	if err != nil {
		return nil, err
	}

	conns := make(chan *linkWSConn)
	server := &http.Server{
		Handler: &wsServer{
			ch:            conns,
			ctx:           ctx,
			acceptOptions: wsAcceptOptions(url),
		},
		BaseContext:  func(net.Listener) context.Context { return ctx },
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}
	go server.Serve(tcpListener) // nolint:errcheck

	return &linkWSListener{
		ch:         conns,
		ctx:        ctx,
		httpServer: server,
		listener:   tcpListener,
	}, nil
}