@@ -6,10 +6,12 @@ package websocket
6
6
import (
7
7
"bufio"
8
8
"errors"
9
+ "io"
9
10
"net"
10
11
"net/http"
11
12
"net/http/httptest"
12
13
"strings"
14
+ "sync"
13
15
"testing"
14
16
15
17
"nhooyr.io/websocket/internal/test/assert"
@@ -142,6 +144,43 @@ func TestAccept(t *testing.T) {
142
144
_ , err := Accept (w , r , nil )
143
145
assert .Contains (t , err , `failed to hijack connection` )
144
146
})
147
+ t .Run ("closeRace" , func (t * testing.T ) {
148
+ t .Parallel ()
149
+
150
+ server , _ := net .Pipe ()
151
+
152
+ pr , pw := io .Pipe ()
153
+ rw := bufio .NewReadWriter (bufio .NewReader (pr ), bufio .NewWriter (pw ))
154
+ newResponseWriter := func () http.ResponseWriter {
155
+ return mockHijacker {
156
+ ResponseWriter : httptest .NewRecorder (),
157
+ hijack : func () (net.Conn , * bufio.ReadWriter , error ) {
158
+ return server , rw , nil
159
+ },
160
+ }
161
+ }
162
+ w := newResponseWriter ()
163
+
164
+ r := httptest .NewRequest ("GET" , "/" , nil )
165
+ r .Header .Set ("Connection" , "Upgrade" )
166
+ r .Header .Set ("Upgrade" , "websocket" )
167
+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
168
+ r .Header .Set ("Sec-WebSocket-Key" , xrand .Base64 (16 ))
169
+
170
+ c , err := Accept (w , r , nil )
171
+ wg := & sync.WaitGroup {}
172
+ wg .Add (2 )
173
+ go func () {
174
+ c .Close (StatusInternalError , "the sky is falling" )
175
+ wg .Done ()
176
+ }()
177
+ go func () {
178
+ c .CloseNow ()
179
+ wg .Done ()
180
+ }()
181
+ wg .Wait ()
182
+ assert .Success (t , err )
183
+ })
145
184
}
146
185
147
186
func Test_verifyClientHandshake (t * testing.T ) {
0 commit comments