@@ -2,6 +2,7 @@ package main
2
2
3
3
import (
4
4
"context"
5
+ "errors"
5
6
"io"
6
7
"io/ioutil"
7
8
"log"
@@ -12,21 +13,36 @@ import (
12
13
"nhooyr.io/websocket"
13
14
)
14
15
16
+ // chatServer enables broadcasting to a set of subscribers.
15
17
type chatServer struct {
16
18
subscribersMu sync.RWMutex
17
- subscribers map [chan []byte ]struct {}
19
+ subscribers map [chan <- []byte ]struct {}
18
20
}
19
21
22
+ // subscribeHandler accepts the WebSocket connection and then subscribes
23
+ // it to all future messages.
20
24
func (cs * chatServer ) subscribeHandler (w http.ResponseWriter , r * http.Request ) {
21
25
c , err := websocket .Accept (w , r , nil )
22
26
if err != nil {
23
27
log .Print (err )
24
28
return
25
29
}
26
30
27
- cs .subscribe (r .Context (), c )
31
+ err = cs .subscribe (r .Context (), c )
32
+ if errors .Is (err , context .Canceled ) {
33
+ return
34
+ }
35
+ if websocket .CloseStatus (err ) == websocket .StatusNormalClosure ||
36
+ websocket .CloseStatus (err ) == websocket .StatusGoingAway {
37
+ return
38
+ }
39
+ if err != nil {
40
+ log .Print (err )
41
+ }
28
42
}
29
43
44
+ // publishHandler reads the request body with a limit of 8192 bytes and then publishes
45
+ // the received message.
30
46
func (cs * chatServer ) publishHandler (w http.ResponseWriter , r * http.Request ) {
31
47
if r .Method != "POST" {
32
48
http .Error (w , http .StatusText (http .StatusMethodNotAllowed ), http .StatusMethodNotAllowed )
@@ -35,12 +51,44 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
35
51
body := io .LimitReader (r .Body , 8192 )
36
52
msg , err := ioutil .ReadAll (body )
37
53
if err != nil {
54
+ http .Error (w , http .StatusText (http .StatusRequestEntityTooLarge ), http .StatusRequestEntityTooLarge )
38
55
return
39
56
}
40
57
41
58
cs .publish (msg )
42
59
}
43
60
61
+ // subscribe subscribes the given WebSocket to all broadcast messages.
62
+ // It creates a msgs chan with a buffer of 16 to give some room to slower
63
+ // connections and then registers it. It then listens for all messages
64
+ // and writes them to the WebSocket. If the context is cancelled or
65
+ // an error occurs, it returns and deletes the subscription.
66
+ //
67
+ // It uses CloseRead to keep reading from the connection to process control
68
+ // messages and cancel the context if the connection drops.
69
+ func (cs * chatServer ) subscribe (ctx context.Context , c * websocket.Conn ) error {
70
+ ctx = c .CloseRead (ctx )
71
+
72
+ msgs := make (chan []byte , 16 )
73
+ cs .addSubscriber (msgs )
74
+ defer cs .deleteSubscriber (msgs )
75
+
76
+ for {
77
+ select {
78
+ case msg := <- msgs :
79
+ err := writeTimeout (ctx , time .Second * 5 , c , msg )
80
+ if err != nil {
81
+ return err
82
+ }
83
+ case <- ctx .Done ():
84
+ return ctx .Err ()
85
+ }
86
+ }
87
+ }
88
+
89
+ // publish publishes the msg to all subscribers.
90
+ // It never blocks and so messages to slow subscribers
91
+ // are dropped.
44
92
func (cs * chatServer ) publish (msg []byte ) {
45
93
cs .subscribersMu .RLock ()
46
94
defer cs .subscribersMu .RUnlock ()
@@ -53,41 +101,24 @@ func (cs *chatServer) publish(msg []byte) {
53
101
}
54
102
}
55
103
56
- func (cs * chatServer ) addSubscriber (msgs chan []byte ) {
104
+ // addSubscriber registers a subscriber with a channel
105
+ // on which to send messages.
106
+ func (cs * chatServer ) addSubscriber (msgs chan <- []byte ) {
57
107
cs .subscribersMu .Lock ()
58
108
if cs .subscribers == nil {
59
- cs .subscribers = make (map [chan []byte ]struct {})
109
+ cs .subscribers = make (map [chan <- []byte ]struct {})
60
110
}
61
111
cs .subscribers [msgs ] = struct {}{}
62
112
cs .subscribersMu .Unlock ()
63
113
}
64
114
115
+ // deleteSubscriber deletes the subscriber with the given msgs channel.
65
116
func (cs * chatServer ) deleteSubscriber (msgs chan []byte ) {
66
117
cs .subscribersMu .Lock ()
67
118
delete (cs .subscribers , msgs )
68
119
cs .subscribersMu .Unlock ()
69
120
}
70
121
71
- func (cs * chatServer ) subscribe (ctx context.Context , c * websocket.Conn ) error {
72
- ctx = c .CloseRead (ctx )
73
-
74
- msgs := make (chan []byte , 16 )
75
- cs .addSubscriber (msgs )
76
- defer cs .deleteSubscriber (msgs )
77
-
78
- for {
79
- select {
80
- case msg := <- msgs :
81
- err := writeTimeout (ctx , time .Second * 5 , c , msg )
82
- if err != nil {
83
- return err
84
- }
85
- case <- ctx .Done ():
86
- return ctx .Err ()
87
- }
88
- }
89
- }
90
-
91
122
func writeTimeout (ctx context.Context , timeout time.Duration , c * websocket.Conn , msg []byte ) error {
92
123
ctx , cancel := context .WithTimeout (ctx , timeout )
93
124
defer cancel ()
0 commit comments