Skip to content

Commit 23a3ede

Browse files
authored
Merge pull request #131 from nhooyr/netconn
Modify NetConn to take a context as the first argument
2 parents 016b716 + 31b47c3 commit 23a3ede

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

netconn.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ import (
2121
// Every Write to the net.Conn will correspond to a message write of
2222
// the given type on *websocket.Conn.
2323
//
24-
// If a message is read that is not of the correct type, an error
25-
// will be thrown.
24+
// The passed ctx bounds the lifetime of the net.Conn. If cancelled,
25+
// all reads and writes on the net.Conn will be cancelled.
26+
//
27+
// If a message is read that is not of the correct type, the connection
28+
// will be closed with StatusUnsupportedData and an error will be returned.
2629
//
2730
// Close will close the *websocket.Conn with StatusNormalClosure.
2831
//
@@ -35,20 +38,20 @@ import (
3538
//
3639
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
3740
// io.EOF when reading.
38-
func NetConn(c *Conn, msgType MessageType) net.Conn {
41+
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
3942
nc := &netConn{
4043
c: c,
4144
msgType: msgType,
4245
}
4346

4447
var cancel context.CancelFunc
45-
nc.writeContext, cancel = context.WithCancel(context.Background())
48+
nc.writeContext, cancel = context.WithCancel(ctx)
4649
nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
4750
if !nc.writeTimer.Stop() {
4851
<-nc.writeTimer.C
4952
}
5053

51-
nc.readContext, cancel = context.WithCancel(context.Background())
54+
nc.readContext, cancel = context.WithCancel(ctx)
5255
nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
5356
if !nc.readTimer.Stop() {
5457
<-nc.readTimer.C

websocket_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ func TestConn(t *testing.T) {
264264
{
265265
name: "netConn",
266266
server: func(ctx context.Context, c *websocket.Conn) error {
267-
nc := websocket.NetConn(c, websocket.MessageBinary)
267+
nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
268268
defer nc.Close()
269269

270270
nc.SetWriteDeadline(time.Time{})
@@ -290,7 +290,7 @@ func TestConn(t *testing.T) {
290290
return nil
291291
},
292292
client: func(ctx context.Context, c *websocket.Conn) error {
293-
nc := websocket.NetConn(c, websocket.MessageBinary)
293+
nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
294294

295295
nc.SetReadDeadline(time.Time{})
296296
time.Sleep(1)
@@ -317,7 +317,7 @@ func TestConn(t *testing.T) {
317317
{
318318
name: "netConn/badReadMsgType",
319319
server: func(ctx context.Context, c *websocket.Conn) error {
320-
nc := websocket.NetConn(c, websocket.MessageBinary)
320+
nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
321321

322322
nc.SetDeadline(time.Now().Add(time.Second * 15))
323323

@@ -337,7 +337,7 @@ func TestConn(t *testing.T) {
337337
{
338338
name: "netConn/badRead",
339339
server: func(ctx context.Context, c *websocket.Conn) error {
340-
nc := websocket.NetConn(c, websocket.MessageBinary)
340+
nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
341341
defer nc.Close()
342342

343343
nc.SetDeadline(time.Now().Add(time.Second * 15))

0 commit comments

Comments
 (0)