Gitlab Community Edition Instance

Commit b3467d96 authored by mhellka's avatar mhellka
Browse files

Rewrote CORS filter

parent 9348843d
Pipeline #128124 passed with stages
in 8 minutes and 55 seconds
package de.gwdg.cdstar.rest.servlet;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Set;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.FilterRegistration;
import javax.servlet.ServletContainerInitializer;
import javax.servlet.ServletContext;
......@@ -13,6 +14,7 @@ import javax.servlet.ServletRegistration;
import javax.servlet.annotation.WebListener;
import de.gwdg.cdstar.Utils;
import de.gwdg.cdstar.rest.utils.SessionHelper;
import de.gwdg.cdstar.runtime.RuntimeContext;
@WebListener
......@@ -43,14 +45,39 @@ public class CDStarServletInitializer implements ServletContainerInitializer {
servlet.addMapping("/*");
servlet.setAsyncSupported(true);
addFilter(ctx, "CDSTAR: Method Override Filter", new MethodOverrideFilter());
addFilter(ctx, "CDSTAR: CORS Filter", new CORSFilter());
setupMethodOverride(ctx);
setupCORS(ctx);
}
private void addFilter(ServletContext ctx, String name, Filter filter) {
final FilterRegistration.Dynamic conf = ctx.addFilter(name, filter);
private void setupMethodOverride(ServletContext ctx) {
FilterRegistration.Dynamic conf = ctx.addFilter("CDSTAR: Method Override Filter", MethodOverrideFilter.class);
conf.setAsyncSupported(true);
conf.addMappingForServletNames(EnumSet.of(DispatcherType.REQUEST), true, CDSTAR_SERVLET_NAME);
}
private void setupCORS(ServletContext ctx) {
Set<String> allow = new HashSet<>();
Set<String> expose = new HashSet<>();
// Default headers
allow.addAll(Arrays.asList("Cache-Control", "Content-Type", "Authorization"));
allow.add(SessionHelper.HEADER_TRANSACTION);
allow.add(MethodOverrideFilter.HTTP_OVERRIDE_HEADER);
expose.add("Location");
// TUS headers (make this configurable)
allow.addAll(Arrays.asList("Upload-Offset", "Upload-Length", "Tus-Resumable", "Upload-Metadata"));
expose.addAll(Arrays.asList("Upload-Offset", "Upload-Length", "Tus-Version", "Tus-Resumable", "Tus-Max-Size", "Tus-Extension", "Upload-Metadata"));
// Add CORS filter
FilterRegistration.Dynamic conf = ctx.addFilter("CDSTAR: CORS Filter", CORSFilter.class);
conf.setAsyncSupported(true);
conf.addMappingForServletNames(EnumSet.of(DispatcherType.REQUEST), true, CDSTAR_SERVLET_NAME);
conf.setInitParameter(CORSFilter.ALLOWED_ORIGINS_PARAM, "*");
conf.setInitParameter(CORSFilter.ALLOWED_HEADERS_PARAM, String.join(",", allow));
conf.setInitParameter(CORSFilter.EXPOSED_HEADERS_PARAM, String.join(",", expose));
conf.setInitParameter(CORSFilter.ALLOWED_METHODS_PARAM, "HEAD,GET,POST,PUT,DELETE,PATCH");
conf.setInitParameter(CORSFilter.PREFLIGHT_MAX_AGE_PARAM, "3600");
}
}
package de.gwdg.cdstar.rest.servlet;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Pattern;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
......@@ -19,98 +24,234 @@ import javax.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import de.gwdg.cdstar.Utils;
import de.gwdg.cdstar.rest.utils.SessionHelper;
/**
* Filter to add CORS simple-request headers and handle CORS preflight requests.
*
* By default, allow all origins. The Access-Control-Allow-Credentials header is
* NOT set, so clients are forced to set the Authorization header manually and
* cannot used browser-cached credentials to access a CDSTAR instance.
*/
public class CORSFilter implements Filter {
private static final Logger log = LoggerFactory.getLogger(CORSFilter.class);
private String allowHeaders;
private String exposeHeaders;
private String allowMethods;
private String maxAge;
public static final String ALLOW_CREDENTIALS_PARAM = "allowCredentials";
public static final String EXPOSED_HEADERS_PARAM = "exposedHeaders";
public static final String PREFLIGHT_MAX_AGE_PARAM = "preflightMaxAge";
public static final String ALLOWED_HEADERS_PARAM = "allowedHeaders";
public static final String ALLOWED_METHODS_PARAM = "allowedMethods";
public static final String ALLOWED_ORIGINS_PARAM = "allowedOrigins";
// Request headers
private static final String ORIGIN_HEADER = "Origin";
private static final String ACCESS_CONTROL_REQUEST_METHOD_HEADER = "Access-Control-Request-Method";
private static final String ACCESS_CONTROL_REQUEST_HEADERS_HEADER = "Access-Control-Request-Headers";
// Response headers
private static final String ACCESS_CONTROL_ALLOW_ORIGIN_HEADER = "Access-Control-Allow-Origin";
private static final String ACCESS_CONTROL_ALLOW_METHODS_HEADER = "Access-Control-Allow-Methods";
private static final String ACCESS_CONTROL_ALLOW_HEADERS_HEADER = "Access-Control-Allow-Headers";
private static final String ACCESS_CONTROL_MAX_AGE_HEADER = "Access-Control-Max-Age";
private static final String ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER = "Access-Control-Allow-Credentials";
private static final String ACCESS_CONTROL_EXPOSE_HEADERS_HEADER = "Access-Control-Expose-Headers";
private static final Set<String> SAVE_HEADERS = new HashSet<>(
Arrays.asList("accept", "accept-language", "content-language", "content-fype"));
@SuppressWarnings("unused")
private static final Set<String> AUTO_EXPOSED = new HashSet<>(
Arrays.asList("cache-control", "content-language", "content-length", "content-zype", "expires",
"last-modified", "pragma"));
private boolean allOriginsAllowed;
private Set<String> allowedOrigins = new HashSet<>(0);
private Set<Pattern> allowedOriginPatterns = new HashSet<>(0);
private String allowedMethodsHeader;
private boolean allMethodsAllowed;
private Set<String> allowedMethods = new HashSet<>(8);
private String allowedHeadersHeader;
private boolean allHeadersAllowed;
private Set<String> allowedHeaders = new HashSet<>(0);
private boolean allowCredentials = false;
private String preflightMaxAge;
private String exposedHeadersHeader;
private String getRequiredInitParam(FilterConfig config, String name) {
return Objects.requireNonNull(config.getInitParameter(name), "Init parameter "+name+ "required");
}
@Override
public void init(FilterConfig arg0) throws ServletException {
final Set<String> allow = new HashSet<>();
final Set<String> expose = new HashSet<>();
allow.addAll(Arrays.asList("Cache-Control", "Content-Type", "Authorization", "X-HTTP-Method-Override"));
allow.add(SessionHelper.HEADER_TRANSACTION);
expose.addAll(Arrays.asList("Location"));
// TODO: Find a way for plugins to add their headers and settings to this filter
final List<String> tusHeaders = Arrays.asList("Upload-Offset", "Upload-Length", "Tus-Version", "Tus-Resumable",
"Tus-Extension", "Tus-Max-Size", "Upload-Metadata");
allow.addAll(tusHeaders);
expose.addAll(tusHeaders);
setExposeHeaders(expose);
setAllowHeaders(allow);
setAllowMethods(Arrays.asList("HEAD", "GET", "POST", "PUT", "DELETE", "PATCH"));
setMaxAge(3600);
public void init(FilterConfig config) throws ServletException {
// Access-Control-Allow-Origin
for (String origin : hlistSplit(getRequiredInitParam(config, ALLOWED_ORIGINS_PARAM))) {
if (origin.equals("*")) {
allOriginsAllowed = true;
} else if (origin.startsWith("^") && origin.endsWith("$")) {
Pattern rx = Pattern.compile(origin);
allowedOriginPatterns.add(rx);
} else {
allowedOrigins.add(origin);
}
}
// Access-Control-Allow-Methods
for (String method : hlistSplit(getRequiredInitParam(config, ALLOWED_METHODS_PARAM))) {
if (method.equals("*")) {
allMethodsAllowed = true;
} else {
allowedMethods.add(method.toUpperCase());
}
}
allowedMethodsHeader = allMethodsAllowed ? "*" : String.join(", ", allowedMethods);
// Access-Control-Allow-Headers
for (String header : hlistSplit(getRequiredInitParam(config, ALLOWED_HEADERS_PARAM))) {
if (header.equals("*")) {
allHeadersAllowed = true;
} else {
allowedHeaders.add(header);
}
}
allowedHeadersHeader = allHeadersAllowed ? "*" : String.join(", ", allowedHeaders);
// Some headers are always allowed, but NOT included in the
// Access-Control-Allow-Headers by default, as that would bypass
// default value restrictions for these headers. Allow them explicitly
// to remove the default restrictions.
allowedHeaders.addAll(SAVE_HEADERS);
// Access-Control-Allow-Credentials
allowCredentials = "true".equalsIgnoreCase(config.getInitParameter(ALLOW_CREDENTIALS_PARAM));
// Access-Control-Expose-Headers
exposedHeadersHeader = config.getInitParameter(EXPOSED_HEADERS_PARAM);
if (exposedHeadersHeader != null)
exposedHeadersHeader = String.join(", ", new HashSet<String>(hlistSplit(exposedHeadersHeader)));
// Access-Control-MAx-Age
preflightMaxAge = config.getInitParameter(PREFLIGHT_MAX_AGE_PARAM);
if (preflightMaxAge == null)
preflightMaxAge = "1800"; // Default is 30 minutes
}
private void setAllowHeaders(Collection<String> values) {
allowHeaders = String.join(", ", values);
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
handle((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
private void setExposeHeaders(Collection<String> values) {
exposeHeaders = String.join(", ", values);
private void handle(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
String origin = request.getHeader(ORIGIN_HEADER);
if (origin == null) {
// Not a CORS request
chain.doFilter(request, response);
return;
}
if (!originAllowed(origin)) {
handleError(request, response, 403, "CORS failed. Origin not allowed.");
return;
}
if (isPreflightRequest(request)) {
handlePreflightResponse(request, response, origin);
} else {
handleSimpleRequest(request, response, origin);
chain.doFilter(request, response);
}
}
private void setAllowMethods(Collection<String> values) {
allowMethods = String.join(", ", values);
private boolean originAllowed(String origin) {
if (allOriginsAllowed)
return true;
origin = origin.toLowerCase();
if (allowedOrigins.contains(origin))
return true;
for (Pattern p : allowedOriginPatterns)
if (p.matcher(origin).matches())
return true;
return false;
}
private void setMaxAge(long seconds) {
maxAge = Long.toString(Math.max(60, seconds));
private boolean isPreflightRequest(HttpServletRequest request) {
return "OPTIONS".equalsIgnoreCase(request.getMethod())
&& request.getHeader(ACCESS_CONTROL_REQUEST_METHOD_HEADER) != null;
}
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
final HttpServletRequest rq = (HttpServletRequest) request;
final HttpServletResponse rs = (HttpServletResponse) response;
final String origin = rq.getHeader("Origin");
if (origin != null) {
if (log.isDebugEnabled())
log.debug("CORS request (origin={}, client={}, method={}, path={})",
Utils.repr(origin), rq.getRemoteAddr(), rq.getMethod(), rq.getPathInfo());
// CORS pre-flight request
if ("OPTIONS".equals(rq.getMethod()) &&
rq.getHeader("Access-Control-Request-Method") != null) {
rs.setHeader("Access-Control-Allow-Origin", "*");
rs.setHeader("Access-Control-Max-Age", maxAge);
rs.setHeader("Access-Control-Allow-Methods", allowMethods);
rs.setHeader("Access-Control-Allow-Headers", allowHeaders);
rs.setStatus(200);
return;
private void handlePreflightResponse(HttpServletRequest request, HttpServletResponse response, String origin) {
String accessControlRequestMethod = request.getHeader(ACCESS_CONTROL_REQUEST_METHOD_HEADER).toUpperCase();
if (!(allMethodsAllowed || allowedMethods.contains(accessControlRequestMethod))) {
log.debug("Preflight failed. Method not allowed: {} {}", accessControlRequestMethod, allowedMethods);
handleError(request, response, 403, "CORS failed. Requested method not allowed.");
return;
}
String accessControlRequestHeaders = request.getHeader(ACCESS_CONTROL_REQUEST_HEADERS_HEADER);
if (!allHeadersAllowed && accessControlRequestHeaders != null) {
for (String header : hlistSplit(accessControlRequestHeaders)) {
if (!(allowedHeaders.contains(header))) {
log.debug("Preflight failed. Header not allowed: {}", accessControlRequestHeaders);
handleError(request, response, 403, "CORS failed: Requested header not allowed.");
return;
}
}
}
response.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN_HEADER, origin);
response.addHeader("Vary", ORIGIN_HEADER);
response.setHeader(ACCESS_CONTROL_MAX_AGE_HEADER, preflightMaxAge);
response.setHeader(ACCESS_CONTROL_ALLOW_METHODS_HEADER, allowedMethodsHeader);
response.setHeader(ACCESS_CONTROL_ALLOW_HEADERS_HEADER, allowedHeadersHeader);
if (allowCredentials)
response.setHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER, "true");
}
private void handleSimpleRequest(HttpServletRequest request, HttpServletResponse response, String origin) {
response.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN_HEADER, origin);
response.addHeader("Vary", ORIGIN_HEADER);
if (allowCredentials)
response.setHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER, "true");
if (exposedHeadersHeader != null)
response.setHeader(ACCESS_CONTROL_EXPOSE_HEADERS_HEADER, exposedHeadersHeader);
}
// CORS simple or preflight-aproved request
rs.setHeader("Access-Control-Allow-Origin", "*");
rs.setHeader("Access-Control-Expose-Headers", exposeHeaders);
private void handleError(HttpServletRequest request, HttpServletResponse response, int status, String string) {
response.setStatus(status);
response.setContentType("text/plain");
try (ServletOutputStream body = response.getOutputStream()) {
body.write("CORS failed. Requested method not allowed.\n".getBytes(StandardCharsets.US_ASCII));
} catch (IOException e) {
// Who cares?
}
}
chain.doFilter(request, response);
/**
* Same as {@code new ArrayList(str.trim().toLowerCase().split("(,|\\s)+"));}
* but faster. Also, returns an empty list if the parameter is null.
*/
private List<String> hlistSplit(String str) {
if (str == null)
return Collections.emptyList();
List<String> results = new ArrayList<>(8);
StringBuilder sb = new StringBuilder(str.length());
for (int i = 0; i < str.length(); i++) {
char c = str.charAt(i);
if (c == ',' || Character.isWhitespace(c)) {
if (sb.length() > 0) {
results.add(sb.toString());
sb.setLength(0);
}
continue;
} else {
sb.append(Character.toLowerCase(c));
}
}
if (sb.length() > 0)
results.add(sb.toString());
return results;
}
@Override
public void destroy() {
}
}
......@@ -22,7 +22,7 @@ public class CORSFilterTest extends BaseRestTest {
public void testSimpleRequestContainsCorsHeader() throws Exception {
target("/v3/test/").request().header("Origin", "https://example.com").get();
assertStatus(Status.OK);
assertHeaderEquals("Access-Control-Allow-Origin", "*");
assertHeaderEquals("Access-Control-Allow-Origin", "https://example.com");
}
@Test
......@@ -32,7 +32,7 @@ public class CORSFilterTest extends BaseRestTest {
.header("Access-Control-Request-Method", "DELETE")
.options();
assertStatus(Status.OK);
assertHeaderEquals("Access-Control-Allow-Origin", "*");
assertHeaderEquals("Access-Control-Allow-Origin", "https://example.com");
assertTrue(Utils.notNullOrEmpty(getLastResponse().getHeaderString("Access-Control-Max-Age")));
assertTrue(Utils.notNullOrEmpty(getLastResponse().getHeaderString("Access-Control-Allow-Methods")));
assertTrue(Utils.notNullOrEmpty(getLastResponse().getHeaderString("Access-Control-Allow-Headers")));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment