6
"io"
6
"io"
7
"math"
7
"math"
8
"net"
8
"net"
9
- "sync"
9
+ "sync/atomic "
10
"time"
10
"time"
11
)
11
)
12
12
@@ -28,9 +28,10 @@ import (
28
//
28
//
29
// Close will close the *websocket.Conn with StatusNormalClosure.
29
// Close will close the *websocket.Conn with StatusNormalClosure.
30
//
30
//
31
- // When a deadline is hit, the connection will be closed. This is
31
+ // When a deadline is hit and there is an active read or write goroutine, the
32
- // different from most net.Conn implementations where only the
32
+ // connection will be closed. This is different from most net.Conn implementations
33
- // reading/writing goroutines are interrupted but the connection is kept alive.
33
+ // where only the reading/writing goroutines are interrupted but the connection
34
+ // is kept alive.
34
//
35
//
35
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
36
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
36
// and "websocket/unknown-addr" for String.
37
// and "websocket/unknown-addr" for String.
@@ -41,17 +42,43 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
41
nc := & netConn {
42
nc := & netConn {
42
c : c ,
43
c : c ,
43
msgType : msgType ,
44
msgType : msgType ,
45
+ readMu : newMu (c ),
46
+ writeMu : newMu (c ),
44
}
47
}
45
48
46
- var cancel context.CancelFunc
49
+ var writeCancel context.CancelFunc
47
- nc .writeContext , cancel = context .WithCancel (ctx )
50
+ nc .writeCtx , writeCancel = context .WithCancel (ctx )
48
- nc .writeTimer = time .AfterFunc (math .MaxInt64 , cancel )
51
+ var readCancel context.CancelFunc
52
+ nc .readCtx , readCancel = context .WithCancel (ctx )
53
+
54
+ nc .writeTimer = time .AfterFunc (math .MaxInt64 , func () {
55
+ if ! nc .writeMu .tryLock () {
56
+ // If the lock cannot be acquired, then there is an
57
+ // active write goroutine and so we should cancel the context.
58
+ writeCancel ()
59
+ return
60
+ }
61
+ defer nc .writeMu .unlock ()
62
+
63
+ // Prevents future writes from writing until the deadline is reset.
64
+ atomic .StoreInt64 (& nc .writeExpired , 1 )
65
+ })
49
if ! nc .writeTimer .Stop () {
66
if ! nc .writeTimer .Stop () {
50
<- nc .writeTimer .C
67
<- nc .writeTimer .C
51
}
68
}
52
69
53
- nc .readContext , cancel = context .WithCancel (ctx )
70
+ nc .readTimer = time .AfterFunc (math .MaxInt64 , func () {
54
- nc .readTimer = time .AfterFunc (math .MaxInt64 , cancel )
71
+ if ! nc .readMu .tryLock () {
72
+ // If the lock cannot be acquired, then there is an
73
+ // active read goroutine and so we should cancel the context.
74
+ readCancel ()
75
+ return
76
+ }
77
+ defer nc .readMu .unlock ()
78
+
79
+ // Prevents future reads from reading until the deadline is reset.
80
+ atomic .StoreInt64 (& nc .readExpired , 1 )
81
+ })
55
if ! nc .readTimer .Stop () {
82
if ! nc .readTimer .Stop () {
56
<- nc .readTimer .C
83
<- nc .readTimer .C
57
}
84
}
@@ -64,59 +91,72 @@ type netConn struct {
64
msgType MessageType
91
msgType MessageType
65
92
66
writeTimer * time.Timer
93
writeTimer * time.Timer
67
- writeContext context.Context
94
+ writeMu * mu
95
+ writeExpired int64
96
+ writeCtx context.Context
68
97
69
readTimer * time.Timer
98
readTimer * time.Timer
70
- readContext context. Context
99
+ readMu * mu
71
-
100
+ readExpired int64
72
- readMu sync. Mutex
101
+ readCtx context. Context
73
- eofed bool
102
+ readEOFed bool
74
- reader io.Reader
103
+ reader io.Reader
75
}
104
}
76
105
77
var _ net.Conn = & netConn {}
106
var _ net.Conn = & netConn {}
78
107
79
- func (c * netConn ) Close () error {
108
+ func (nc * netConn ) Close () error {
80
- return c .c .Close (StatusNormalClosure , "" )
109
+ return nc .c .Close (StatusNormalClosure , "" )
81
}
110
}
82
111
83
- func (c * netConn ) Write (p []byte ) (int , error ) {
112
+ func (nc * netConn ) Write (p []byte ) (int , error ) {
84
- err := c .c .Write (c .writeContext , c .msgType , p )
113
+ nc .writeMu .forceLock ()
114
+ defer nc .writeMu .unlock ()
115
+
116
+ if atomic .LoadInt64 (& nc .writeExpired ) == 1 {
117
+ return 0 , fmt .Errorf ("failed to write: %w" , context .DeadlineExceeded )
118
+ }
119
+
120
+ err := nc .c .Write (nc .writeCtx , nc .msgType , p )
85
if err != nil {
121
if err != nil {
86
return 0 , err
122
return 0 , err
87
}
123
}
88
return len (p ), nil
124
return len (p ), nil
89
}
125
}
90
126
91
- func (c * netConn ) Read (p []byte ) (int , error ) {
127
+ func (nc * netConn ) Read (p []byte ) (int , error ) {
92
- c .readMu .Lock ()
128
+ nc .readMu .forceLock ()
93
- defer c .readMu .Unlock ()
129
+ defer nc .readMu .unlock ()
130
+
131
+ if atomic .LoadInt64 (& nc .readExpired ) == 1 {
132
+ return 0 , fmt .Errorf ("failed to read: %w" , context .DeadlineExceeded )
133
+ }
94
134
95
- if c . eofed {
135
+ if nc . readEOFed {
96
return 0 , io .EOF
136
return 0 , io .EOF
97
}
137
}
98
138
99
- if c .reader == nil {
139
+ if nc .reader == nil {
100
- typ , r , err := c .c .Reader (c . readContext )
140
+ typ , r , err := nc .c .Reader (nc . readCtx )
101
if err != nil {
141
if err != nil {
102
switch CloseStatus (err ) {
142
switch CloseStatus (err ) {
103
case StatusNormalClosure , StatusGoingAway :
143
case StatusNormalClosure , StatusGoingAway :
104
- c . eofed = true
144
+ nc . readEOFed = true
105
return 0 , io .EOF
145
return 0 , io .EOF
106
}
146
}
107
return 0 , err
147
return 0 , err
108
}
148
}
109
- if typ != c .msgType {
149
+ if typ != nc .msgType {
110
- err := fmt .Errorf ("unexpected frame type read (expected %v): %v" , c .msgType , typ )
150
+ err := fmt .Errorf ("unexpected frame type read (expected %v): %v" , nc .msgType , typ )
111
- c .c .Close (StatusUnsupportedData , err .Error ())
151
+ nc .c .Close (StatusUnsupportedData , err .Error ())
112
return 0 , err
152
return 0 , err
113
}
153
}
114
- c .reader = r
154
+ nc .reader = r
115
}
155
}
116
156
117
- n , err := c .reader .Read (p )
157
+ n , err := nc .reader .Read (p )
118
if err == io .EOF {
158
if err == io .EOF {
119
- c .reader = nil
159
+ nc .reader = nil
120
err = nil
160
err = nil
121
}
161
}
122
return n , err
162
return n , err
@@ -133,34 +173,36 @@ func (a websocketAddr) String() string {
133
return "websocket/unknown-addr"
173
return "websocket/unknown-addr"
134
}
174
}
135
175
136
- func (c * netConn ) RemoteAddr () net.Addr {
176
+ func (nc * netConn ) RemoteAddr () net.Addr {
137
return websocketAddr {}
177
return websocketAddr {}
138
}
178
}
139
179
140
- func (c * netConn ) LocalAddr () net.Addr {
180
+ func (nc * netConn ) LocalAddr () net.Addr {
141
return websocketAddr {}
181
return websocketAddr {}
142
}
182
}
143
183
144
- func (c * netConn ) SetDeadline (t time.Time ) error {
184
+ func (nc * netConn ) SetDeadline (t time.Time ) error {
145
- c .SetWriteDeadline (t )
185
+ nc .SetWriteDeadline (t )
146
- c .SetReadDeadline (t )
186
+ nc .SetReadDeadline (t )
147
return nil
187
return nil
148
}
188
}
149
189
150
- func (c * netConn ) SetWriteDeadline (t time.Time ) error {
190
+ func (nc * netConn ) SetWriteDeadline (t time.Time ) error {
191
+ atomic .StoreInt64 (& nc .writeExpired , 0 )
151
if t .IsZero () {
192
if t .IsZero () {
152
- c .writeTimer .Stop ()
193
+ nc .writeTimer .Stop ()
153
} else {
194
} else {
154
- c .writeTimer .Reset (t .Sub (time .Now ()))
195
+ nc .writeTimer .Reset (t .Sub (time .Now ()))
155
}
196
}
156
return nil
197
return nil
157
}
198
}
158
199
159
- func (c * netConn ) SetReadDeadline (t time.Time ) error {
200
+ func (nc * netConn ) SetReadDeadline (t time.Time ) error {
201
+ atomic .StoreInt64 (& nc .readExpired , 0 )
160
if t .IsZero () {
202
if t .IsZero () {
161
- c .readTimer .Stop ()
203
+ nc .readTimer .Stop ()
162
} else {
204
} else {
163
- c .readTimer .Reset (t .Sub (time .Now ()))
205
+ nc .readTimer .Reset (t .Sub (time .Now ()))
164
}
206
}
165
return nil
207
return nil
166
}
208
}
0 commit comments