@@ -29,21 +29,26 @@ func AcceptSubprotocols(protocols ...string) AcceptOption {
29
29
return acceptSubprotocols (protocols )
30
30
}
31
31
32
- type acceptOrigins [] string
32
+ type acceptInsecureOrigin struct {}
33
33
34
- func (o acceptOrigins ) acceptOption () {}
34
+ func (o acceptInsecureOrigin ) acceptOption () {}
35
35
36
- // AcceptOrigins lists the origins that Accept will accept.
37
- // Accept will always accept r.Host as the origin. Use this
38
- // option when you want to accept an origin with a different domain
39
- // than the one the WebSocket server is running on.
36
+ // AcceptInsecureOrigin disables Accept's origin verification
37
+ // behaviour. By default Accept only allows the handshake to
38
+ // succeed if the javascript that is initiating the handshake
39
+ // is on the same domain as the server. This is to prevent CSRF
40
+ // when secure data is stored in cookies.
40
41
//
41
- // Use this option with caution to avoid exposing your WebSocket
42
- // server to a CSRF attack.
43
42
// See https://stackoverflow.com/a/37837709/4283659
44
- // TODO remove in favour of AcceptInsecureOrigin
45
- func AcceptOrigins (origins ... string ) AcceptOption {
46
- return acceptOrigins (origins )
43
+ //
44
+ // Use this if you want a WebSocket server any javascript can
45
+ // connect to or you want to perform Origin verification yourself
46
+ // and allow some whitelist of domains.
47
+ //
48
+ // Ensure you understand exactly what the above means before you use
49
+ // this option in conjugation with cookies containing secure data.
50
+ func AcceptInsecureOrigin () AcceptOption {
51
+ return acceptInsecureOrigin {}
47
52
}
48
53
49
54
func verifyClientRequest (w http.ResponseWriter , r * http.Request ) error {
@@ -87,11 +92,11 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
87
92
// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
88
93
func Accept (w http.ResponseWriter , r * http.Request , opts ... AcceptOption ) (* Conn , error ) {
89
94
var subprotocols []string
90
- origins := [] string { r . Host }
95
+ verifyOrigin := true
91
96
for _ , opt := range opts {
92
97
switch opt := opt .(type ) {
93
- case acceptOrigins :
94
- origins = [] string ( opt )
98
+ case acceptInsecureOrigin :
99
+ verifyOrigin = false
95
100
case acceptSubprotocols :
96
101
subprotocols = []string (opt )
97
102
}
@@ -102,12 +107,12 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
102
107
return nil , err
103
108
}
104
109
105
- origins = append ( origins , r . Host )
106
-
107
- err = authenticateOrigin ( r , origins )
108
- if err != nil {
109
- http . Error ( w , err . Error (), http . StatusForbidden )
110
- return nil , err
110
+ if verifyOrigin {
111
+ err = authenticateOrigin ( r )
112
+ if err != nil {
113
+ http . Error ( w , err . Error (), http . StatusForbidden )
114
+ return nil , err
115
+ }
111
116
}
112
117
113
118
hj , ok := w .(http.Hijacker )
@@ -173,7 +178,7 @@ func handleKey(w http.ResponseWriter, r *http.Request) {
173
178
w .Header ().Set ("Sec-WebSocket-Accept" , responseKey )
174
179
}
175
180
176
- func authenticateOrigin (r * http.Request , origins [] string ) error {
181
+ func authenticateOrigin (r * http.Request ) error {
177
182
origin := r .Header .Get ("Origin" )
178
183
if origin == "" {
179
184
return nil
@@ -182,10 +187,8 @@ func authenticateOrigin(r *http.Request, origins []string) error {
182
187
if err != nil {
183
188
return xerrors .Errorf ("failed to parse Origin header %q: %w" , origin , err )
184
189
}
185
- for _ , o := range origins {
186
- if strings .EqualFold (u .Host , o ) {
187
- return nil
188
- }
190
+ if strings .EqualFold (u .Host , r .Host ) {
191
+ return nil
189
192
}
190
- return xerrors .Errorf ("request origin %q is not authorized" , r . Header . Get ( "Origin" ) )
193
+ return xerrors .Errorf ("request origin %q is not authorized" , origin )
191
194
}
0 commit comments