Skip to content

Commit bda74e6

Browse files
committed
Remove AcceptOrigins and add AcceptInsecureOrigin
Closes #39
1 parent 0746583 commit bda74e6

File tree

3 files changed

+72
-46
lines changed

3 files changed

+72
-46
lines changed

accept.go

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,26 @@ func AcceptSubprotocols(protocols ...string) AcceptOption {
2929
return acceptSubprotocols(protocols)
3030
}
3131

32-
type acceptOrigins []string
32+
type acceptInsecureOrigin struct{}
3333

34-
func (o acceptOrigins) acceptOption() {}
34+
func (o acceptInsecureOrigin) acceptOption() {}
3535

36-
// AcceptOrigins lists the origins that Accept will accept.
37-
// Accept will always accept r.Host as the origin. Use this
38-
// option when you want to accept an origin with a different domain
39-
// than the one the WebSocket server is running on.
36+
// AcceptInsecureOrigin disables Accept's origin verification
37+
// behaviour. By default Accept only allows the handshake to
38+
// succeed if the javascript that is initiating the handshake
39+
// is on the same domain as the server. This is to prevent CSRF
40+
// when secure data is stored in cookies.
4041
//
41-
// Use this option with caution to avoid exposing your WebSocket
42-
// server to a CSRF attack.
4342
// See https://stackoverflow.com/a/37837709/4283659
44-
// TODO remove in favour of AcceptInsecureOrigin
45-
func AcceptOrigins(origins ...string) AcceptOption {
46-
return acceptOrigins(origins)
43+
//
44+
// Use this if you want a WebSocket server any javascript can
45+
// connect to or you want to perform Origin verification yourself
46+
// and allow some whitelist of domains.
47+
//
48+
// Ensure you understand exactly what the above means before you use
49+
// this option in conjugation with cookies containing secure data.
50+
func AcceptInsecureOrigin() AcceptOption {
51+
return acceptInsecureOrigin{}
4752
}
4853

4954
func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
@@ -87,11 +92,11 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
8792
// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
8893
func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) {
8994
var subprotocols []string
90-
origins := []string{r.Host}
95+
verifyOrigin := true
9196
for _, opt := range opts {
9297
switch opt := opt.(type) {
93-
case acceptOrigins:
94-
origins = []string(opt)
98+
case acceptInsecureOrigin:
99+
verifyOrigin = false
95100
case acceptSubprotocols:
96101
subprotocols = []string(opt)
97102
}
@@ -102,12 +107,12 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
102107
return nil, err
103108
}
104109

105-
origins = append(origins, r.Host)
106-
107-
err = authenticateOrigin(r, origins)
108-
if err != nil {
109-
http.Error(w, err.Error(), http.StatusForbidden)
110-
return nil, err
110+
if verifyOrigin {
111+
err = authenticateOrigin(r)
112+
if err != nil {
113+
http.Error(w, err.Error(), http.StatusForbidden)
114+
return nil, err
115+
}
111116
}
112117

113118
hj, ok := w.(http.Hijacker)
@@ -173,7 +178,7 @@ func handleKey(w http.ResponseWriter, r *http.Request) {
173178
w.Header().Set("Sec-WebSocket-Accept", responseKey)
174179
}
175180

176-
func authenticateOrigin(r *http.Request, origins []string) error {
181+
func authenticateOrigin(r *http.Request) error {
177182
origin := r.Header.Get("Origin")
178183
if origin == "" {
179184
return nil
@@ -182,10 +187,8 @@ func authenticateOrigin(r *http.Request, origins []string) error {
182187
if err != nil {
183188
return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err)
184189
}
185-
for _, o := range origins {
186-
if strings.EqualFold(u.Host, o) {
187-
return nil
188-
}
190+
if strings.EqualFold(u.Host, r.Host) {
191+
return nil
189192
}
190-
return xerrors.Errorf("request origin %q is not authorized", r.Header.Get("Origin"))
193+
return xerrors.Errorf("request origin %q is not authorized", origin)
191194
}

accept_test.go

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -140,37 +140,39 @@ func Test_authenticateOrigin(t *testing.T) {
140140
t.Parallel()
141141

142142
testCases := []struct {
143-
name string
144-
origin string
145-
authorizedOrigins []string
146-
success bool
143+
name string
144+
origin string
145+
host string
146+
success bool
147147
}{
148148
{
149149
name: "none",
150150
success: true,
151+
host: "example.com",
151152
},
152153
{
153154
name: "invalid",
154155
origin: "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}",
156+
host: "example.com",
155157
success: false,
156158
},
157159
{
158-
name: "unauthorized",
159-
origin: "https://example.com",
160-
authorizedOrigins: []string{"example1.com"},
161-
success: false,
160+
name: "unauthorized",
161+
origin: "https://example.com",
162+
host: "example1.com",
163+
success: false,
162164
},
163165
{
164-
name: "authorized",
165-
origin: "https://example.com",
166-
authorizedOrigins: []string{"example.com"},
167-
success: true,
166+
name: "authorized",
167+
origin: "https://example.com",
168+
host: "example.com",
169+
success: true,
168170
},
169171
{
170-
name: "authorizedCaseInsensitive",
171-
origin: "https://examplE.com",
172-
authorizedOrigins: []string{"example.com"},
173-
success: true,
172+
name: "authorizedCaseInsensitive",
173+
origin: "https://examplE.com",
174+
host: "example.com",
175+
success: true,
174176
},
175177
}
176178

@@ -179,10 +181,10 @@ func Test_authenticateOrigin(t *testing.T) {
179181
t.Run(tc.name, func(t *testing.T) {
180182
t.Parallel()
181183

182-
r := httptest.NewRequest("GET", "/", nil)
184+
r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil)
183185
r.Header.Set("Origin", tc.origin)
184186

185-
err := authenticateOrigin(r, tc.authorizedOrigins)
187+
err := authenticateOrigin(r)
186188
if (err == nil) != tc.success {
187189
t.Fatalf("unexpected error value: %+v", err)
188190
}

websocket_test.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,30 @@ func TestHandshake(t *testing.T) {
143143
},
144144
},
145145
{
146-
name: "authorizedOrigin",
146+
name: "acceptSecureOrigin",
147147
server: func(w http.ResponseWriter, r *http.Request) error {
148-
c, err := websocket.Accept(w, r, websocket.AcceptOrigins("har.bar.com", "example.com"))
148+
c, err := websocket.Accept(w, r, websocket.AcceptInsecureOrigin())
149+
if err != nil {
150+
return err
151+
}
152+
defer c.Close(websocket.StatusInternalError, "")
153+
return nil
154+
},
155+
client: func(ctx context.Context, u string) error {
156+
h := http.Header{}
157+
h.Set("Origin", "https://127.0.0.1")
158+
c, _, err := websocket.Dial(ctx, u, websocket.DialHeader(h))
159+
if err != nil {
160+
return err
161+
}
162+
defer c.Close(websocket.StatusInternalError, "")
163+
return nil
164+
},
165+
},
166+
{
167+
name: "acceptInsecureOrigin",
168+
server: func(w http.ResponseWriter, r *http.Request) error {
169+
c, err := websocket.Accept(w, r, websocket.AcceptInsecureOrigin())
149170
if err != nil {
150171
return err
151172
}

0 commit comments

Comments
 (0)