28
28
import javax .websocket .Encoder ;
29
29
import javax .websocket .Endpoint ;
30
30
import javax .websocket .Extension ;
31
+ import javax .websocket .server .ServerEndpointConfig ;
31
32
32
33
import io .undertow .server .HttpServerExchange ;
33
34
import io .undertow .server .HttpUpgradeListener ;
34
35
import io .undertow .servlet .api .InstanceFactory ;
35
36
import io .undertow .servlet .api .InstanceHandle ;
36
37
import io .undertow .servlet .websockets .ServletWebSocketHttpExchange ;
38
+ import io .undertow .util .PathTemplate ;
37
39
import io .undertow .websockets .core .WebSocketChannel ;
38
40
import io .undertow .websockets .core .WebSocketVersion ;
39
41
import io .undertow .websockets .core .protocol .Handshake ;
40
42
import io .undertow .websockets .jsr .ConfiguredServerEndpoint ;
41
43
import io .undertow .websockets .jsr .EncodingFactory ;
42
44
import io .undertow .websockets .jsr .EndpointSessionHandler ;
43
45
import io .undertow .websockets .jsr .ServerWebSocketContainer ;
46
+ import io .undertow .websockets .jsr .annotated .AnnotatedEndpointFactory ;
44
47
import io .undertow .websockets .jsr .handshake .HandshakeUtil ;
45
48
import io .undertow .websockets .jsr .handshake .JsrHybi07Handshake ;
46
49
import io .undertow .websockets .jsr .handshake .JsrHybi08Handshake ;
@@ -64,20 +67,38 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat
64
67
65
68
private static final Constructor <ServletWebSocketHttpExchange > exchangeConstructor ;
66
69
70
+ private static final Constructor <ConfiguredServerEndpoint > endpointConstructor ;
71
+
67
72
private static final boolean undertow10Present ;
68
73
74
+ private static final boolean undertow11Present ;
75
+
69
76
static {
70
- Class <ServletWebSocketHttpExchange > type = ServletWebSocketHttpExchange .class ;
71
- Class <?>[] paramTypes = new Class <?>[] {HttpServletRequest .class , HttpServletResponse .class , Set .class };
72
- if (ClassUtils .hasConstructor (type , paramTypes )) {
73
- exchangeConstructor = ClassUtils .getConstructorIfAvailable (type , paramTypes );
77
+ Class <ServletWebSocketHttpExchange > exchangeType = ServletWebSocketHttpExchange .class ;
78
+ Class <?>[] exchangeParamTypes = new Class <?>[] {HttpServletRequest .class , HttpServletResponse .class , Set .class };
79
+ if (ClassUtils .hasConstructor (exchangeType , exchangeParamTypes )) {
80
+ exchangeConstructor = ClassUtils .getConstructorIfAvailable (exchangeType , exchangeParamTypes );
74
81
undertow10Present = false ;
75
82
}
76
83
else {
77
- paramTypes = new Class <?>[] {HttpServletRequest .class , HttpServletResponse .class };
78
- exchangeConstructor = ClassUtils .getConstructorIfAvailable (type , paramTypes );
84
+ exchangeParamTypes = new Class <?>[] {HttpServletRequest .class , HttpServletResponse .class };
85
+ exchangeConstructor = ClassUtils .getConstructorIfAvailable (exchangeType , exchangeParamTypes );
79
86
undertow10Present = true ;
80
87
}
88
+
89
+ Class <ConfiguredServerEndpoint > endpointType = ConfiguredServerEndpoint .class ;
90
+ Class <?>[] endpointParamTypes = new Class <?>[] {ServerEndpointConfig .class , InstanceFactory .class ,
91
+ PathTemplate .class , EncodingFactory .class , AnnotatedEndpointFactory .class };
92
+ if (ClassUtils .hasConstructor (endpointType , endpointParamTypes )) {
93
+ endpointConstructor = ClassUtils .getConstructorIfAvailable (endpointType , endpointParamTypes );
94
+ undertow11Present = true ;
95
+ }
96
+ else {
97
+ endpointParamTypes = new Class <?>[] {ServerEndpointConfig .class , InstanceFactory .class ,
98
+ PathTemplate .class , EncodingFactory .class };
99
+ endpointConstructor = ClassUtils .getConstructorIfAvailable (endpointType , endpointParamTypes );
100
+ undertow11Present = false ;
101
+ }
81
102
}
82
103
83
104
private static final String [] supportedVersions = new String [] {
@@ -174,12 +195,21 @@ private ConfiguredServerEndpoint createConfiguredServerEndpoint(String selectedP
174
195
endpointRegistration .setSubprotocols (Arrays .asList (selectedProtocol ));
175
196
endpointRegistration .setExtensions (selectedExtensions );
176
197
177
- return new ConfiguredServerEndpoint (endpointRegistration , new EndpointInstanceFactory (endpoint ), null ,
178
- new EncodingFactory (
179
- Collections .<Class <?>, List <InstanceFactory <? extends Encoder >>>emptyMap (),
180
- Collections .<Class <?>, List <InstanceFactory <? extends Decoder >>>emptyMap (),
181
- Collections .<Class <?>, List <InstanceFactory <? extends Encoder >>>emptyMap (),
182
- Collections .<Class <?>, List <InstanceFactory <? extends Decoder >>>emptyMap ()));
198
+ EncodingFactory encodingFactory = new EncodingFactory (
199
+ Collections .<Class <?>, List <InstanceFactory <? extends Encoder >>>emptyMap (),
200
+ Collections .<Class <?>, List <InstanceFactory <? extends Decoder >>>emptyMap (),
201
+ Collections .<Class <?>, List <InstanceFactory <? extends Encoder >>>emptyMap (),
202
+ Collections .<Class <?>, List <InstanceFactory <? extends Decoder >>>emptyMap ());
203
+ try {
204
+ return undertow11Present ?
205
+ endpointConstructor .newInstance (endpointRegistration ,
206
+ new EndpointInstanceFactory (endpoint ), null , encodingFactory , null ) :
207
+ endpointConstructor .newInstance (endpointRegistration ,
208
+ new EndpointInstanceFactory (endpoint ), null , encodingFactory );
209
+ }
210
+ catch (Exception ex ) {
211
+ throw new HandshakeFailureException ("Failed to instantiate ConfiguredServerEndpoint" , ex );
212
+ }
183
213
}
184
214
185
215
0 commit comments