Skip to content

Commit f4997d7

Browse files
committed
Server selects first acceptable compression offer
Unacceptable offers are declined without rejecting the request.
1 parent cc2d7bd commit f4997d7

File tree

4 files changed

+88
-65
lines changed

4 files changed

+88
-65
lines changed

accept.go

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
118118
w.Header().Set("Sec-WebSocket-Protocol", subproto)
119119
}
120120

121-
copts, err := acceptCompression(r, w, opts.CompressionMode)
122-
if err != nil {
123-
return nil, err
121+
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
122+
if ok {
123+
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
124124
}
125125

126126
w.WriteHeader(http.StatusSwitchingProtocols)
@@ -230,26 +230,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string {
230230
return ""
231231
}
232232

233-
func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
233+
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
234234
if mode == CompressionDisabled {
235-
return nil, nil
235+
return nil, false
236236
}
237-
238-
for _, ext := range websocketExtensions(r.Header) {
237+
for _, ext := range extensions {
239238
switch ext.name {
240239
case "permessage-deflate":
241-
return acceptDeflate(w, ext, mode)
240+
if copts, ok := acceptDeflate(ext, mode); ok {
241+
return copts, true
242+
}
242243
// Disabled for now, see https://github.com/nhooyr/websocket/issues/218
243244
// case "x-webkit-deflate-frame":
244245
// return acceptWebkitDeflate(w, ext, mode)
245246
}
246247
}
247-
return nil, nil
248+
return nil, false
248249
}
249250

250-
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
251+
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
251252
copts := mode.opts()
252-
253253
for _, p := range ext.params {
254254
switch p {
255255
case "client_no_context_takeover":
@@ -258,23 +258,17 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
258258
case "server_no_context_takeover":
259259
copts.serverNoContextTakeover = true
260260
continue
261-
case "server_max_window_bits=15":
261+
case "client_max_window_bits",
262+
"server_max_window_bits=15":
262263
continue
263264
}
264-
265-
if strings.HasPrefix(p, "client_max_window_bits") {
266-
// We cannot adjust the read sliding window so cannot make use of this.
265+
if strings.HasPrefix(p, "client_max_window_bits=") {
266+
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
267267
continue
268268
}
269-
270-
err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
271-
http.Error(w, err.Error(), http.StatusBadRequest)
272-
return nil, err
269+
return nil, false
273270
}
274-
275-
copts.setHeader(w.Header())
276-
277-
return copts, nil
271+
return copts, true
278272
}
279273

280274
func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {

accept_test.go

Lines changed: 68 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,47 @@ func TestAccept(t *testing.T) {
4545
t.Run("badCompression", func(t *testing.T) {
4646
t.Parallel()
4747

48-
w := mockHijacker{
49-
ResponseWriter: httptest.NewRecorder(),
48+
newRequest := func(extensions string) *http.Request {
49+
r := httptest.NewRequest("GET", "/", nil)
50+
r.Header.Set("Connection", "Upgrade")
51+
r.Header.Set("Upgrade", "websocket")
52+
r.Header.Set("Sec-WebSocket-Version", "13")
53+
r.Header.Set("Sec-WebSocket-Key", "meow123")
54+
r.Header.Set("Sec-WebSocket-Extensions", extensions)
55+
return r
56+
}
57+
newResponseWriter := func() http.ResponseWriter {
58+
return mockHijacker{
59+
ResponseWriter: httptest.NewRecorder(),
60+
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
61+
return nil, nil, errors.New("hijack error")
62+
},
63+
}
5064
}
51-
r := httptest.NewRequest("GET", "/", nil)
52-
r.Header.Set("Connection", "Upgrade")
53-
r.Header.Set("Upgrade", "websocket")
54-
r.Header.Set("Sec-WebSocket-Version", "13")
55-
r.Header.Set("Sec-WebSocket-Key", "meow123")
56-
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")
5765

58-
_, err := Accept(w, r, &AcceptOptions{
59-
CompressionMode: CompressionContextTakeover,
66+
t.Run("withoutFallback", func(t *testing.T) {
67+
t.Parallel()
68+
69+
w := newResponseWriter()
70+
r := newRequest("permessage-deflate; harharhar")
71+
_, _ = Accept(w, r, &AcceptOptions{
72+
CompressionMode: CompressionNoContextTakeover,
73+
})
74+
assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
75+
})
76+
t.Run("withFallback", func(t *testing.T) {
77+
t.Parallel()
78+
79+
w := newResponseWriter()
80+
r := newRequest("permessage-deflate; harharhar, permessage-deflate")
81+
_, _ = Accept(w, r, &AcceptOptions{
82+
CompressionMode: CompressionNoContextTakeover,
83+
})
84+
assert.Equal(t, "extension header",
85+
w.Header().Get("Sec-WebSocket-Extensions"),
86+
CompressionNoContextTakeover.opts().String(),
87+
)
6088
})
61-
assert.Contains(t, err, `unsupported permessage-deflate parameter`)
6289
})
6390

6491
t.Run("requireHttpHijacker", func(t *testing.T) {
@@ -321,42 +348,53 @@ func Test_authenticateOrigin(t *testing.T) {
321348
}
322349
}
323350

324-
func Test_acceptCompression(t *testing.T) {
351+
func Test_selectDeflate(t *testing.T) {
325352
t.Parallel()
326353

327354
testCases := []struct {
328-
name string
329-
mode CompressionMode
330-
reqSecWebSocketExtensions string
331-
respSecWebSocketExtensions string
332-
expCopts *compressionOptions
333-
error bool
355+
name string
356+
mode CompressionMode
357+
header string
358+
expCopts *compressionOptions
359+
expOK bool
334360
}{
335361
{
336362
name: "disabled",
337363
mode: CompressionDisabled,
338364
expCopts: nil,
365+
expOK: false,
339366
},
340367
{
341368
name: "noClientSupport",
342369
mode: CompressionNoContextTakeover,
343370
expCopts: nil,
371+
expOK: false,
344372
},
345373
{
346-
name: "permessage-deflate",
347-
mode: CompressionNoContextTakeover,
348-
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
349-
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
374+
name: "permessage-deflate",
375+
mode: CompressionNoContextTakeover,
376+
header: "permessage-deflate; client_max_window_bits",
350377
expCopts: &compressionOptions{
351378
clientNoContextTakeover: true,
352379
serverNoContextTakeover: true,
353380
},
381+
expOK: true,
382+
},
383+
{
384+
name: "permessage-deflate/unknown-parameter",
385+
mode: CompressionNoContextTakeover,
386+
header: "permessage-deflate; meow",
387+
expOK: false,
354388
},
355389
{
356-
name: "permessage-deflate/error",
357-
mode: CompressionNoContextTakeover,
358-
reqSecWebSocketExtensions: "permessage-deflate; meow",
359-
error: true,
390+
name: "permessage-deflate/unknown-parameter",
391+
mode: CompressionNoContextTakeover,
392+
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
393+
expCopts: &compressionOptions{
394+
clientNoContextTakeover: true,
395+
serverNoContextTakeover: true,
396+
},
397+
expOK: true,
360398
},
361399
// {
362400
// name: "x-webkit-deflate-frame",
@@ -381,19 +419,11 @@ func Test_acceptCompression(t *testing.T) {
381419
t.Run(tc.name, func(t *testing.T) {
382420
t.Parallel()
383421

384-
r := httptest.NewRequest(http.MethodGet, "/", nil)
385-
r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)
386-
387-
w := httptest.NewRecorder()
388-
copts, err := acceptCompression(r, w, tc.mode)
389-
if tc.error {
390-
assert.Error(t, err)
391-
return
392-
}
393-
394-
assert.Success(t, err)
422+
h := http.Header{}
423+
h.Set("Sec-WebSocket-Extensions", tc.header)
424+
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
425+
assert.Equal(t, "selected options", tc.expOK, ok)
395426
assert.Equal(t, "compression options", tc.expCopts, copts)
396-
assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
397427
})
398428
}
399429
}

compress.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ package websocket
55
import (
66
"compress/flate"
77
"io"
8-
"net/http"
98
"sync"
109
)
1110

@@ -58,15 +57,15 @@ type compressionOptions struct {
5857
serverNoContextTakeover bool
5958
}
6059

61-
func (copts *compressionOptions) setHeader(h http.Header) {
60+
func (copts *compressionOptions) String() string {
6261
s := "permessage-deflate"
6362
if copts.clientNoContextTakeover {
6463
s += "; client_no_context_takeover"
6564
}
6665
if copts.serverNoContextTakeover {
6766
s += "; server_no_context_takeover"
6867
}
69-
h.Set("Sec-WebSocket-Extensions", s)
68+
return s
7069
}
7170

7271
// These bytes are required to get flate.Reader to return.

dial.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
162162
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
163163
}
164164
if copts != nil {
165-
copts.setHeader(req.Header)
165+
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
166166
}
167167

168168
resp, err := opts.HTTPClient.Do(req)

0 commit comments

Comments
 (0)