16
16
17
17
package org .springframework .web .socket .server .standard ;
18
18
19
+ import java .lang .reflect .Constructor ;
19
20
import java .util .Arrays ;
20
21
import java .util .Collections ;
21
22
import java .util .List ;
23
+ import java .util .Set ;
24
+ import java .util .concurrent .ConcurrentHashMap ;
22
25
import javax .servlet .http .HttpServletRequest ;
23
26
import javax .servlet .http .HttpServletResponse ;
24
27
import javax .websocket .Decoder ;
34
37
import io .undertow .websockets .core .WebSocketChannel ;
35
38
import io .undertow .websockets .core .WebSocketVersion ;
36
39
import io .undertow .websockets .core .protocol .Handshake ;
37
- import io .undertow .websockets .core .protocol .version07 .Hybi07Handshake ;
38
- import io .undertow .websockets .core .protocol .version08 .Hybi08Handshake ;
39
- import io .undertow .websockets .core .protocol .version13 .Hybi13Handshake ;
40
40
import io .undertow .websockets .jsr .ConfiguredServerEndpoint ;
41
41
import io .undertow .websockets .jsr .EncodingFactory ;
42
42
import io .undertow .websockets .jsr .EndpointSessionHandler ;
45
45
import io .undertow .websockets .jsr .handshake .JsrHybi07Handshake ;
46
46
import io .undertow .websockets .jsr .handshake .JsrHybi08Handshake ;
47
47
import io .undertow .websockets .jsr .handshake .JsrHybi13Handshake ;
48
+ import org .springframework .util .ClassUtils ;
48
49
import org .xnio .StreamConnection ;
49
50
50
51
import org .springframework .http .server .ServerHttpRequest ;
61
62
*/
62
63
public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
63
64
64
- private final String [] supportedVersions = new String [] {
65
+ private static final Constructor <ServletWebSocketHttpExchange > exchangeConstructor ;
66
+
67
+ private static final boolean undertow10Present ;
68
+
69
+ static {
70
+ Class <ServletWebSocketHttpExchange > type = ServletWebSocketHttpExchange .class ;
71
+ Class <?>[] paramTypes = new Class <?>[] {HttpServletRequest .class , HttpServletResponse .class , Set .class };
72
+ if (ClassUtils .hasConstructor (type , paramTypes )) {
73
+ exchangeConstructor = ClassUtils .getConstructorIfAvailable (type , paramTypes );
74
+ undertow10Present = false ;
75
+ }
76
+ else {
77
+ paramTypes = new Class <?>[] {HttpServletRequest .class , HttpServletResponse .class };
78
+ exchangeConstructor = ClassUtils .getConstructorIfAvailable (type , paramTypes );
79
+ undertow10Present = true ;
80
+ }
81
+ }
82
+
83
+ private static final String [] supportedVersions = new String [] {
65
84
WebSocketVersion .V13 .toHttpHeaderValue (),
66
85
WebSocketVersion .V08 .toHttpHeaderValue (),
67
86
WebSocketVersion .V07 .toHttpHeaderValue ()
68
87
};
69
88
70
89
90
+ private Set <WebSocketChannel > peerConnections ;
91
+
92
+
93
+ public UndertowRequestUpgradeStrategy () {
94
+ if (undertow10Present ) {
95
+ this .peerConnections = null ;
96
+ }
97
+ else {
98
+ this .peerConnections = Collections .newSetFromMap (new ConcurrentHashMap <WebSocketChannel , Boolean >());
99
+ }
100
+ }
101
+
102
+
71
103
@ Override
72
104
public String [] getSupportedVersions () {
73
- return this . supportedVersions ;
105
+ return supportedVersions ;
74
106
}
75
107
76
108
@ Override
@@ -80,7 +112,7 @@ protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse res
80
112
HttpServletRequest servletRequest = getHttpServletRequest (request );
81
113
HttpServletResponse servletResponse = getHttpServletResponse (response );
82
114
83
- final ServletWebSocketHttpExchange exchange = new ServletWebSocketHttpExchange (servletRequest , servletResponse );
115
+ final ServletWebSocketHttpExchange exchange = createHttpExchange (servletRequest , servletResponse );
84
116
exchange .putAttachment (HandshakeUtil .PATH_PARAMS , Collections .<String , String >emptyMap ());
85
117
86
118
ServerWebSocketContainer wsContainer = (ServerWebSocketContainer ) getContainer (servletRequest );
@@ -95,13 +127,27 @@ protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse res
95
127
@ Override
96
128
public void handleUpgrade (StreamConnection connection , HttpServerExchange serverExchange ) {
97
129
WebSocketChannel channel = handshake .createChannel (exchange , connection , exchange .getBufferPool ());
130
+ if (peerConnections != null ) {
131
+ peerConnections .add (channel );
132
+ }
98
133
endpointSessionHandler .onConnect (exchange , channel );
99
134
}
100
135
});
101
136
102
137
handshake .handshake (exchange );
103
138
}
104
139
140
+ private ServletWebSocketHttpExchange createHttpExchange (HttpServletRequest request , HttpServletResponse response ) {
141
+ try {
142
+ return (this .peerConnections != null ?
143
+ exchangeConstructor .newInstance (request , response , this .peerConnections ) :
144
+ exchangeConstructor .newInstance (request , response ));
145
+ }
146
+ catch (Exception ex ) {
147
+ throw new HandshakeFailureException ("Failed to instantiate ServletWebSocketHttpExchange" , ex );
148
+ }
149
+ }
150
+
105
151
private Handshake getHandshakeToUse (ServletWebSocketHttpExchange exchange , ConfiguredServerEndpoint endpoint ) {
106
152
Handshake handshake = new JsrHybi13Handshake (endpoint );
107
153
if (handshake .matches (exchange )) {
0 commit comments