1
1
/*
2
- * Copyright 2012-2017 the original author or authors.
2
+ * Copyright 2012-2019 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
26
26
import java .util .HashSet ;
27
27
import java .util .List ;
28
28
import java .util .Set ;
29
+ import java .util .function .Predicate ;
29
30
30
31
/**
31
32
* <p>
66
67
* Rejects URLs that contain a URL encoded percent. See
67
68
* {@link #setAllowUrlEncodedPercent(boolean)}
68
69
* </li>
70
+ * <li>
71
+ * Rejects hosts that are not allowed. See
72
+ * {@link #setAllowedHostnames(Predicate)}
73
+ * </li>
69
74
* </ul>
70
75
*
71
76
* @see DefaultHttpFirewall
72
77
* @author Rob Winch
78
+ * @author Eddú Meléndez
73
79
* @since 4.2.4
74
80
*/
75
81
public class StrictHttpFirewall implements HttpFirewall {
@@ -98,6 +104,8 @@ public class StrictHttpFirewall implements HttpFirewall {
98
104
99
105
private Set <String > allowedHttpMethods = createDefaultAllowedHttpMethods ();
100
106
107
+ private Predicate <String > allowedHostnames = hostname -> true ;
108
+
101
109
public StrictHttpFirewall () {
102
110
urlBlacklistsAddAll (FORBIDDEN_SEMICOLON );
103
111
urlBlacklistsAddAll (FORBIDDEN_FORWARDSLASH );
@@ -297,6 +305,13 @@ public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
297
305
}
298
306
}
299
307
308
+ public void setAllowedHostnames (Predicate <String > allowedHostnames ) {
309
+ if (allowedHostnames == null ) {
310
+ throw new IllegalArgumentException ("allowedHostnames cannot be null" );
311
+ }
312
+ this .allowedHostnames = allowedHostnames ;
313
+ }
314
+
300
315
private void urlBlacklistsAddAll (Collection <String > values ) {
301
316
this .encodedUrlBlacklist .addAll (values );
302
317
this .decodedUrlBlacklist .addAll (values );
@@ -311,6 +326,7 @@ private void urlBlacklistsRemoveAll(Collection<String> values) {
311
326
public FirewalledRequest getFirewalledRequest (HttpServletRequest request ) throws RequestRejectedException {
312
327
rejectForbiddenHttpMethod (request );
313
328
rejectedBlacklistedUrls (request );
329
+ rejectedUntrustedHosts (request );
314
330
315
331
if (!isNormalized (request )) {
316
332
throw new RequestRejectedException ("The request was rejected because the URL was not normalized." );
@@ -352,6 +368,13 @@ private void rejectedBlacklistedUrls(HttpServletRequest request) {
352
368
}
353
369
}
354
370
371
+ private void rejectedUntrustedHosts (HttpServletRequest request ) {
372
+ String serverName = request .getServerName ();
373
+ if (serverName != null && !this .allowedHostnames .test (serverName )) {
374
+ throw new RequestRejectedException ("The request was rejected because the domain " + serverName + " is untrusted." );
375
+ }
376
+ }
377
+
355
378
@ Override
356
379
public HttpServletResponse getFirewalledResponse (HttpServletResponse response ) {
357
380
return new FirewalledResponse (response );
0 commit comments