Skip to content

Commit 0a61ffe

Browse files
committed
Make SetDeadline on NetConn not always close Conn
NetConn has to close the connection to interrupt in progress reads and writes. However, it can block reads and writes that occur after the deadline instead of closing the connection. Closes #228
1 parent 1695216 commit 0a61ffe

File tree

3 files changed

+126
-43
lines changed

3 files changed

+126
-43
lines changed

conn.go

Lines changed: 9 additions & 0 deletions
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -246,6 +246,15 @@ func (m *mu) forceLock() {
246
m.ch <- struct{}{}
246
m.ch <- struct{}{}
247
}
247
}
248

248

249+
func (m *mu) tryLock() bool {
250+
select {
251+
case m.ch <- struct{}{}:
252+
return true
253+
default:
254+
return false
255+
}
256+
}
257+
249
func (m *mu) lock(ctx context.Context) error {
258
func (m *mu) lock(ctx context.Context) error {
250
select {
259
select {
251
case <-m.c.closed:
260
case <-m.c.closed:

netconn.go

Lines changed: 85 additions & 43 deletions
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
6
"io"
6
"io"
7
"math"
7
"math"
8
"net"
8
"net"
9-
"sync"
9+
"sync/atomic"
10
"time"
10
"time"
11
)
11
)
12

12

@@ -28,9 +28,10 @@ import (
28
//
28
//
29
// Close will close the *websocket.Conn with StatusNormalClosure.
29
// Close will close the *websocket.Conn with StatusNormalClosure.
30
//
30
//
31-
// When a deadline is hit, the connection will be closed. This is
31+
// When a deadline is hit and there is an active read or write goroutine, the
32-
// different from most net.Conn implementations where only the
32+
// connection will be closed. This is different from most net.Conn implementations
33-
// reading/writing goroutines are interrupted but the connection is kept alive.
33+
// where only the reading/writing goroutines are interrupted but the connection
34+
// is kept alive.
34
//
35
//
35
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
36
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
36
// and "websocket/unknown-addr" for String.
37
// and "websocket/unknown-addr" for String.
@@ -41,17 +42,43 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
41
nc := &netConn{
42
nc := &netConn{
42
c: c,
43
c: c,
43
msgType: msgType,
44
msgType: msgType,
45+
readMu: newMu(c),
46+
writeMu: newMu(c),
44
}
47
}
45

48

46-
var cancel context.CancelFunc
49+
var writeCancel context.CancelFunc
47-
nc.writeContext, cancel = context.WithCancel(ctx)
50+
nc.writeCtx, writeCancel = context.WithCancel(ctx)
48-
nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
51+
var readCancel context.CancelFunc
52+
nc.readCtx, readCancel = context.WithCancel(ctx)
53+
54+
nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
55+
if !nc.writeMu.tryLock() {
56+
// If the lock cannot be acquired, then there is an
57+
// active write goroutine and so we should cancel the context.
58+
writeCancel()
59+
return
60+
}
61+
defer nc.writeMu.unlock()
62+
63+
// Prevents future writes from writing until the deadline is reset.
64+
atomic.StoreInt64(&nc.writeExpired, 1)
65+
})
49
if !nc.writeTimer.Stop() {
66
if !nc.writeTimer.Stop() {
50
<-nc.writeTimer.C
67
<-nc.writeTimer.C
51
}
68
}
52

69

53-
nc.readContext, cancel = context.WithCancel(ctx)
70+
nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
54-
nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
71+
if !nc.readMu.tryLock() {
72+
// If the lock cannot be acquired, then there is an
73+
// active read goroutine and so we should cancel the context.
74+
readCancel()
75+
return
76+
}
77+
defer nc.readMu.unlock()
78+
79+
// Prevents future reads from reading until the deadline is reset.
80+
atomic.StoreInt64(&nc.readExpired, 1)
81+
})
55
if !nc.readTimer.Stop() {
82
if !nc.readTimer.Stop() {
56
<-nc.readTimer.C
83
<-nc.readTimer.C
57
}
84
}
@@ -64,59 +91,72 @@ type netConn struct {
64
msgType MessageType
91
msgType MessageType
65

92

66
writeTimer *time.Timer
93
writeTimer *time.Timer
67-
writeContext context.Context
94+
writeMu *mu
95+
writeExpired int64
96+
writeCtx context.Context
68

97

69
readTimer *time.Timer
98
readTimer *time.Timer
70-
readContext context.Context
99+
readMu *mu
71-
100+
readExpired int64
72-
readMu sync.Mutex
101+
readCtx context.Context
73-
eofed bool
102+
readEOFed bool
74-
reader io.Reader
103+
reader io.Reader
75
}
104
}
76

105

77
var _ net.Conn = &netConn{}
106
var _ net.Conn = &netConn{}
78

107

79-
func (c *netConn) Close() error {
108+
func (nc *netConn) Close() error {
80-
return c.c.Close(StatusNormalClosure, "")
109+
return nc.c.Close(StatusNormalClosure, "")
81
}
110
}
82

111

83-
func (c *netConn) Write(p []byte) (int, error) {
112+
func (nc *netConn) Write(p []byte) (int, error) {
84-
err := c.c.Write(c.writeContext, c.msgType, p)
113+
nc.writeMu.forceLock()
114+
defer nc.writeMu.unlock()
115+
116+
if atomic.LoadInt64(&nc.writeExpired) == 1 {
117+
return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
118+
}
119+
120+
err := nc.c.Write(nc.writeCtx, nc.msgType, p)
85
if err != nil {
121
if err != nil {
86
return 0, err
122
return 0, err
87
}
123
}
88
return len(p), nil
124
return len(p), nil
89
}
125
}
90

126

91-
func (c *netConn) Read(p []byte) (int, error) {
127+
func (nc *netConn) Read(p []byte) (int, error) {
92-
c.readMu.Lock()
128+
nc.readMu.forceLock()
93-
defer c.readMu.Unlock()
129+
defer nc.readMu.unlock()
130+
131+
if atomic.LoadInt64(&nc.readExpired) == 1 {
132+
return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
133+
}
94

134

95-
if c.eofed {
135+
if nc.readEOFed {
96
return 0, io.EOF
136
return 0, io.EOF
97
}
137
}
98

138

99-
if c.reader == nil {
139+
if nc.reader == nil {
100-
typ, r, err := c.c.Reader(c.readContext)
140+
typ, r, err := nc.c.Reader(nc.readCtx)
101
if err != nil {
141
if err != nil {
102
switch CloseStatus(err) {
142
switch CloseStatus(err) {
103
case StatusNormalClosure, StatusGoingAway:
143
case StatusNormalClosure, StatusGoingAway:
104-
c.eofed = true
144+
nc.readEOFed = true
105
return 0, io.EOF
145
return 0, io.EOF
106
}
146
}
107
return 0, err
147
return 0, err
108
}
148
}
109-
if typ != c.msgType {
149+
if typ != nc.msgType {
110-
err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ)
150+
err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
111-
c.c.Close(StatusUnsupportedData, err.Error())
151+
nc.c.Close(StatusUnsupportedData, err.Error())
112
return 0, err
152
return 0, err
113
}
153
}
114-
c.reader = r
154+
nc.reader = r
115
}
155
}
116

156

117-
n, err := c.reader.Read(p)
157+
n, err := nc.reader.Read(p)
118
if err == io.EOF {
158
if err == io.EOF {
119-
c.reader = nil
159+
nc.reader = nil
120
err = nil
160
err = nil
121
}
161
}
122
return n, err
162
return n, err
@@ -133,34 +173,36 @@ func (a websocketAddr) String() string {
133
return "websocket/unknown-addr"
173
return "websocket/unknown-addr"
134
}
174
}
135

175

136-
func (c *netConn) RemoteAddr() net.Addr {
176+
func (nc *netConn) RemoteAddr() net.Addr {
137
return websocketAddr{}
177
return websocketAddr{}
138
}
178
}
139

179

140-
func (c *netConn) LocalAddr() net.Addr {
180+
func (nc *netConn) LocalAddr() net.Addr {
141
return websocketAddr{}
181
return websocketAddr{}
142
}
182
}
143

183

144-
func (c *netConn) SetDeadline(t time.Time) error {
184+
func (nc *netConn) SetDeadline(t time.Time) error {
145-
c.SetWriteDeadline(t)
185+
nc.SetWriteDeadline(t)
146-
c.SetReadDeadline(t)
186+
nc.SetReadDeadline(t)
147
return nil
187
return nil
148
}
188
}
149

189

150-
func (c *netConn) SetWriteDeadline(t time.Time) error {
190+
func (nc *netConn) SetWriteDeadline(t time.Time) error {
191+
atomic.StoreInt64(&nc.writeExpired, 0)
151
if t.IsZero() {
192
if t.IsZero() {
152-
c.writeTimer.Stop()
193+
nc.writeTimer.Stop()
153
} else {
194
} else {
154-
c.writeTimer.Reset(t.Sub(time.Now()))
195+
nc.writeTimer.Reset(t.Sub(time.Now()))
155
}
196
}
156
return nil
197
return nil
157
}
198
}
158

199

159-
func (c *netConn) SetReadDeadline(t time.Time) error {
200+
func (nc *netConn) SetReadDeadline(t time.Time) error {
201+
atomic.StoreInt64(&nc.readExpired, 0)
160
if t.IsZero() {
202
if t.IsZero() {
161-
c.readTimer.Stop()
203+
nc.readTimer.Stop()
162
} else {
204
} else {
163-
c.readTimer.Reset(t.Sub(time.Now()))
205+
nc.readTimer.Reset(t.Sub(time.Now()))
164
}
206
}
165
return nil
207
return nil
166
}
208
}

ws_js.go

Lines changed: 32 additions & 0 deletions
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -511,3 +511,35 @@ const (
511
// MessageBinary is for binary messages like protobufs.
511
// MessageBinary is for binary messages like protobufs.
512
MessageBinary
512
MessageBinary
513
)
513
)
514+
515+
type mu struct {
516+
c *Conn
517+
ch chan struct{}
518+
}
519+
520+
func newMu(c *Conn) *mu {
521+
return &mu{
522+
c: c,
523+
ch: make(chan struct{}, 1),
524+
}
525+
}
526+
527+
func (m *mu) forceLock() {
528+
m.ch <- struct{}{}
529+
}
530+
531+
func (m *mu) tryLock() bool {
532+
select {
533+
case m.ch <- struct{}{}:
534+
return true
535+
default:
536+
return false
537+
}
538+
}
539+
540+
func (m *mu) unlock() {
541+
select {
542+
case <-m.ch:
543+
default:
544+
}
545+
}

0 commit comments

Comments
 (0)