Skip to content

Commit d67546a

Browse files
authored
Merge pull request #55 from nhooyr/better
Remove AcceptOrigins and add AcceptInsecureOrigin
2 parents 932d16d + 696af24 commit d67546a

File tree

10 files changed

+154
-103
lines changed

10 files changed

+154
-103
lines changed

accept.go

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,26 @@ func AcceptSubprotocols(protocols ...string) AcceptOption {
2929
return acceptSubprotocols(protocols)
3030
}
3131

32-
type acceptOrigins []string
32+
type acceptInsecureOrigin struct{}
3333

34-
func (o acceptOrigins) acceptOption() {}
34+
func (o acceptInsecureOrigin) acceptOption() {}
3535

36-
// AcceptOrigins lists the origins that Accept will accept.
37-
// Accept will always accept r.Host as the origin. Use this
38-
// option when you want to accept an origin with a different domain
39-
// than the one the WebSocket server is running on.
36+
// AcceptInsecureOrigin disables Accept's origin verification
37+
// behaviour. By default Accept only allows the handshake to
38+
// succeed if the javascript that is initiating the handshake
39+
// is on the same domain as the server. This is to prevent CSRF
40+
// when secure data is stored in cookies.
4041
//
41-
// Use this option with caution to avoid exposing your WebSocket
42-
// server to a CSRF attack.
4342
// See https://stackoverflow.com/a/37837709/4283659
44-
func AcceptOrigins(origins ...string) AcceptOption {
45-
return acceptOrigins(origins)
43+
//
44+
// Use this if you want a WebSocket server any javascript can
45+
// connect to or you want to perform Origin verification yourself
46+
// and allow some whitelist of domains.
47+
//
48+
// Ensure you understand exactly what the above means before you use
49+
// this option in conjugation with cookies containing secure data.
50+
func AcceptInsecureOrigin() AcceptOption {
51+
return acceptInsecureOrigin{}
4652
}
4753

4854
func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
@@ -86,11 +92,11 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
8692
// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
8793
func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) {
8894
var subprotocols []string
89-
origins := []string{r.Host}
95+
verifyOrigin := true
9096
for _, opt := range opts {
9197
switch opt := opt.(type) {
92-
case acceptOrigins:
93-
origins = []string(opt)
98+
case acceptInsecureOrigin:
99+
verifyOrigin = false
94100
case acceptSubprotocols:
95101
subprotocols = []string(opt)
96102
}
@@ -101,12 +107,12 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
101107
return nil, err
102108
}
103109

104-
origins = append(origins, r.Host)
105-
106-
err = authenticateOrigin(r, origins)
107-
if err != nil {
108-
http.Error(w, err.Error(), http.StatusForbidden)
109-
return nil, err
110+
if verifyOrigin {
111+
err = authenticateOrigin(r)
112+
if err != nil {
113+
http.Error(w, err.Error(), http.StatusForbidden)
114+
return nil, err
115+
}
110116
}
111117

112118
hj, ok := w.(http.Hijacker)
@@ -172,7 +178,7 @@ func handleKey(w http.ResponseWriter, r *http.Request) {
172178
w.Header().Set("Sec-WebSocket-Accept", responseKey)
173179
}
174180

175-
func authenticateOrigin(r *http.Request, origins []string) error {
181+
func authenticateOrigin(r *http.Request) error {
176182
origin := r.Header.Get("Origin")
177183
if origin == "" {
178184
return nil
@@ -181,10 +187,8 @@ func authenticateOrigin(r *http.Request, origins []string) error {
181187
if err != nil {
182188
return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err)
183189
}
184-
for _, o := range origins {
185-
if strings.EqualFold(u.Host, o) {
186-
return nil
187-
}
190+
if strings.EqualFold(u.Host, r.Host) {
191+
return nil
188192
}
189-
return xerrors.Errorf("request origin %q is not authorized", r.Header.Get("Origin"))
193+
return xerrors.Errorf("request origin %q is not authorized", origin)
190194
}

accept_test.go

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -140,37 +140,39 @@ func Test_authenticateOrigin(t *testing.T) {
140140
t.Parallel()
141141

142142
testCases := []struct {
143-
name string
144-
origin string
145-
authorizedOrigins []string
146-
success bool
143+
name string
144+
origin string
145+
host string
146+
success bool
147147
}{
148148
{
149149
name: "none",
150150
success: true,
151+
host: "example.com",
151152
},
152153
{
153154
name: "invalid",
154155
origin: "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}",
156+
host: "example.com",
155157
success: false,
156158
},
157159
{
158-
name: "unauthorized",
159-
origin: "https://example.com",
160-
authorizedOrigins: []string{"example1.com"},
161-
success: false,
160+
name: "unauthorized",
161+
origin: "https://example.com",
162+
host: "example1.com",
163+
success: false,
162164
},
163165
{
164-
name: "authorized",
165-
origin: "https://example.com",
166-
authorizedOrigins: []string{"example.com"},
167-
success: true,
166+
name: "authorized",
167+
origin: "https://example.com",
168+
host: "example.com",
169+
success: true,
168170
},
169171
{
170-
name: "authorizedCaseInsensitive",
171-
origin: "https://examplE.com",
172-
authorizedOrigins: []string{"example.com"},
173-
success: true,
172+
name: "authorizedCaseInsensitive",
173+
origin: "https://examplE.com",
174+
host: "example.com",
175+
success: true,
174176
},
175177
}
176178

@@ -179,10 +181,10 @@ func Test_authenticateOrigin(t *testing.T) {
179181
t.Run(tc.name, func(t *testing.T) {
180182
t.Parallel()
181183

182-
r := httptest.NewRequest("GET", "/", nil)
184+
r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil)
183185
r.Header.Set("Origin", tc.origin)
184186

185-
err := authenticateOrigin(r, tc.authorizedOrigins)
187+
err := authenticateOrigin(r)
186188
if (err == nil) != tc.success {
187189
t.Fatalf("unexpected error value: %+v", err)
188190
}

datatype.go

Lines changed: 0 additions & 12 deletions
This file was deleted.

datatype_string.go

Lines changed: 0 additions & 25 deletions
This file was deleted.

example_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,18 @@ import (
1414

1515
func ExampleAccept_echo() {
1616
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
17-
c, err := websocket.Accept(w, r)
17+
c, err := websocket.Accept(w, r, websocket.AcceptSubprotocols("echo"))
1818
if err != nil {
1919
log.Printf("server handshake failed: %v", err)
2020
return
2121
}
2222
defer c.Close(websocket.StatusInternalError, "")
2323

24+
if c.Subprotocol() == "" {
25+
c.Close(websocket.StatusPolicyViolation, "cannot communicate with the default protocol")
26+
return
27+
}
28+
2429
echo := func() error {
2530
ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
2631
defer cancel()

json.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func (jc *JSONConn) read(ctx context.Context, v interface{}) error {
2828
return err
2929
}
3030

31-
if typ != DataText {
31+
if typ != MessageText {
3232
return xerrors.Errorf("unexpected frame type for json (expected DataText): %v", typ)
3333
}
3434

@@ -39,6 +39,7 @@ func (jc *JSONConn) read(ctx context.Context, v interface{}) error {
3939
if err != nil {
4040
return xerrors.Errorf("failed to decode json: %w", err)
4141
}
42+
4243
return nil
4344
}
4445

@@ -52,7 +53,7 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error {
5253
}
5354

5455
func (jc JSONConn) write(ctx context.Context, v interface{}) error {
55-
w := jc.Conn.Write(ctx, DataText)
56+
w := jc.Conn.Write(ctx, MessageText)
5657

5758
e := json.NewEncoder(w)
5859
err := e.Encode(v)

messagetype.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package websocket
2+
3+
// MessageType represents the Opcode of a WebSocket data frame.
4+
type MessageType int
5+
6+
//go:generate go run golang.org/x/tools/cmd/stringer -type=MessageType
7+
8+
// MessageType constants.
9+
const (
10+
MessageText MessageType = MessageType(opText)
11+
MessageBinary MessageType = MessageType(opBinary)
12+
)

messagetype_string.go

Lines changed: 25 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)