diff --git a/accept.go b/accept.go index 6c63e730..19e388ec 100644 --- a/accept.go +++ b/accept.go @@ -123,9 +123,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con w.Header().Set("Sec-WebSocket-Protocol", subproto) } - copts, err := acceptCompression(r, w, opts.CompressionMode) - if err != nil { - return nil, err + copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode) + if ok { + w.Header().Set("Sec-WebSocket-Extensions", copts.String()) } w.WriteHeader(http.StatusSwitchingProtocols) @@ -238,25 +238,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string { return "" } -func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { +func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) { if mode == CompressionDisabled { - return nil, nil + return nil, false } - - for _, ext := range websocketExtensions(r.Header) { + for _, ext := range extensions { switch ext.name { // We used to implement x-webkit-deflate-fram too but Safari has bugs. // See https://github.com/nhooyr/websocket/issues/218 case "permessage-deflate": - return acceptDeflate(w, ext, mode) + copts, ok := acceptDeflate(ext, mode) + if ok { + return copts, true + } } } - return nil, nil + return nil, false } -func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { +func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) { copts := mode.opts() - for _, p := range ext.params { switch p { case "client_no_context_takeover": @@ -265,22 +266,18 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi case "server_no_context_takeover": copts.serverNoContextTakeover = true continue + case "client_max_window_bits", + "server_max_window_bits=15": + continue } - if strings.HasPrefix(p, "client_max_window_bits") { - // We cannot adjust the read sliding window so cannot make use of this. - // By not responding to it, we tell the client we're ignoring it. + if strings.HasPrefix(p, "client_max_window_bits=") { + // We can't adjust the deflate window, but decoding with a larger window is acceptable. continue } - - err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) - http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err + return nil, false } - - copts.setHeader(w.Header()) - - return copts, nil + return copts, true } func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { diff --git a/accept_test.go b/accept_test.go index ae17c0b4..c554bdaf 100644 --- a/accept_test.go +++ b/accept_test.go @@ -62,20 +62,50 @@ func TestAccept(t *testing.T) { t.Run("badCompression", func(t *testing.T) { t.Parallel() - w := mockHijacker{ - ResponseWriter: httptest.NewRecorder(), + newRequest := func(extensions string) *http.Request { + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Extensions", extensions) + return r + } + errHijack := errors.New("hijack error") + newResponseWriter := func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errHijack + }, + } } - r := httptest.NewRequest("GET", "/", nil) - r.Header.Set("Connection", "Upgrade") - r.Header.Set("Upgrade", "websocket") - r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") - r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") - _, err := Accept(w, r, &AcceptOptions{ - CompressionMode: CompressionContextTakeover, + t.Run("withoutFallback", func(t *testing.T) { + t.Parallel() + + w := newResponseWriter() + r := newRequest("permessage-deflate; harharhar") + _, err := Accept(w, r, &AcceptOptions{ + CompressionMode: CompressionNoContextTakeover, + }) + assert.ErrorIs(t, errHijack, err) + assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "") + }) + t.Run("withFallback", func(t *testing.T) { + t.Parallel() + + w := newResponseWriter() + r := newRequest("permessage-deflate; harharhar, permessage-deflate") + _, err := Accept(w, r, &AcceptOptions{ + CompressionMode: CompressionNoContextTakeover, + }) + assert.ErrorIs(t, errHijack, err) + assert.Equal(t, "extension header", + w.Header().Get("Sec-WebSocket-Extensions"), + CompressionNoContextTakeover.opts().String(), + ) }) - assert.Contains(t, err, `unsupported permessage-deflate parameter`) }) t.Run("requireHttpHijacker", func(t *testing.T) { @@ -344,59 +374,54 @@ func Test_authenticateOrigin(t *testing.T) { } } -func Test_acceptCompression(t *testing.T) { +func Test_selectDeflate(t *testing.T) { t.Parallel() testCases := []struct { - name string - mode CompressionMode - reqSecWebSocketExtensions string - respSecWebSocketExtensions string - expCopts *compressionOptions - error bool + name string + mode CompressionMode + header string + expCopts *compressionOptions + expOK bool }{ { name: "disabled", mode: CompressionDisabled, expCopts: nil, + expOK: false, }, { name: "noClientSupport", mode: CompressionNoContextTakeover, expCopts: nil, + expOK: false, }, { - name: "permessage-deflate", - mode: CompressionNoContextTakeover, - reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", - respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", + name: "permessage-deflate", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; client_max_window_bits", expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, }, + expOK: true, + }, + { + name: "permessage-deflate/unknown-parameter", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; meow", + expOK: false, }, { - name: "permessage-deflate/error", - mode: CompressionNoContextTakeover, - reqSecWebSocketExtensions: "permessage-deflate; meow", - error: true, + name: "permessage-deflate/unknown-parameter", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + }, + expOK: true, }, - // { - // name: "x-webkit-deflate-frame", - // mode: CompressionNoContextTakeover, - // reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", - // respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", - // expCopts: &compressionOptions{ - // clientNoContextTakeover: true, - // serverNoContextTakeover: true, - // }, - // }, - // { - // name: "x-webkit-deflate/error", - // mode: CompressionNoContextTakeover, - // reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits", - // error: true, - // }, } for _, tc := range testCases { @@ -404,19 +429,11 @@ func Test_acceptCompression(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) - - w := httptest.NewRecorder() - copts, err := acceptCompression(r, w, tc.mode) - if tc.error { - assert.Error(t, err) - return - } - - assert.Success(t, err) + h := http.Header{} + h.Set("Sec-WebSocket-Extensions", tc.header) + copts, ok := selectDeflate(websocketExtensions(h), tc.mode) + assert.Equal(t, "selected options", tc.expOK, ok) assert.Equal(t, "compression options", tc.expCopts, copts) - assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) }) } } diff --git a/compress.go b/compress.go index 61e6e268..81de751b 100644 --- a/compress.go +++ b/compress.go @@ -6,18 +6,13 @@ package websocket import ( "compress/flate" "io" - "net/http" "sync" ) // CompressionMode represents the modes available to the deflate extension. // See https://tools.ietf.org/html/rfc7692 // -// A compatibility layer is implemented for the older deflate-frame extension used -// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 -// It will work the same in every way except that we cannot signal to the peer we -// want to use no context takeover on our side, we can only signal that they should. -// But it is currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218 +// Works in all browsers except Safari which does not implement the deflate extension. type CompressionMode int const ( @@ -65,7 +60,7 @@ type compressionOptions struct { serverNoContextTakeover bool } -func (copts *compressionOptions) setHeader(h http.Header) { +func (copts *compressionOptions) String() string { s := "permessage-deflate" if copts.clientNoContextTakeover { s += "; client_no_context_takeover" @@ -73,7 +68,7 @@ func (copts *compressionOptions) setHeader(h http.Header) { if copts.serverNoContextTakeover { s += "; server_no_context_takeover" } - h.Set("Sec-WebSocket-Extensions", s) + return s } // These bytes are required to get flate.Reader to return. diff --git a/dial.go b/dial.go index 0f2735da..e72432e7 100644 --- a/dial.go +++ b/dial.go @@ -185,7 +185,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if copts != nil { - copts.setHeader(req.Header) + req.Header.Set("Sec-WebSocket-Extensions", copts.String()) } resp, err := opts.HTTPClient.Do(req) @@ -273,6 +273,10 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress copts.serverNoContextTakeover = true continue } + if strings.HasPrefix(p, "server_max_window_bits=") { + // We can't adjust the deflate window, but decoding with a larger window is acceptable. + continue + } return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } diff --git a/internal/test/assert/assert.go b/internal/test/assert/assert.go index 64c938c5..1b90cc9f 100644 --- a/internal/test/assert/assert.go +++ b/internal/test/assert/assert.go @@ -1,6 +1,7 @@ package assert import ( + "errors" "fmt" "reflect" "strings" @@ -43,3 +44,12 @@ func Contains(t testing.TB, v interface{}, sub string) { t.Fatalf("expected %q to contain %q", s, sub) } } + +// ErrorIs asserts errors.Is(got, exp) +func ErrorIs(t testing.TB, exp, got error) { + t.Helper() + + if !errors.Is(got, exp) { + t.Fatalf("expected %v but got %v", exp, got) + } +} diff --git a/ws_js.go b/ws_js.go index 9f0e19e9..e60601e3 100644 --- a/ws_js.go +++ b/ws_js.go @@ -485,12 +485,7 @@ func CloseStatus(err error) StatusCode { // CompressionMode represents the modes available to the deflate extension. // See https://tools.ietf.org/html/rfc7692 -// -// A compatibility layer is implemented for the older deflate-frame extension used -// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 -// It will work the same in every way except that we cannot signal to the peer we -// want to use no context takeover on our side, we can only signal that they should. -// It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218 +// Works in all browsers except Safari which does not implement the deflate extension. type CompressionMode int const (