blob: 55f51dd0302dd238abe5ec4c4d070d88eccf3b61 [file] [log] [blame]
/*
* Copyright (c) 2011-2018 Oracle and/or its affiliates. All rights reserved.
* Copyright 2004 The Apache Software Foundation
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.catalina.filters;
import org.apache.catalina.LogFacade;
import java.io.IOException;
import java.io.Serializable;
import java.security.SecureRandom;
import java.text.MessageFormat;
import java.util.*;
import java.util.logging.Logger;
import jakarta.servlet.FilterChain;
import jakarta.servlet.FilterConfig;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpServletResponseWrapper;
import jakarta.servlet.http.HttpSession;
/**
* Provides basic CSRF protection for a web application. The filter assumes
* that:
* <ul>
* <li>The filter is mapped to /*</li>
* <li>{@link HttpServletResponse#encodeRedirectURL(String)} and
* {@link HttpServletResponse#encodeURL(String)} are used to encode all URLs
* returned to the client
* </ul>
*/
public class CsrfPreventionFilter extends FilterBase {
protected static final Logger log = LogFacade.getLogger();
private String randomClass = SecureRandom.class.getName();
private Random randomSource;
private final Set<String> entryPoints = new HashSet<String>();
private int nonceCacheSize = 5;
@Override
protected Logger getLogger() {
return log;
}
/**
* Entry points are URLs that will not be tested for the presence of a valid
* nonce. They are used to provide a way to navigate back to a protected
* application after navigating away from it. Entry points will be limited
* to HTTP GET requests and should not trigger any security sensitive
* actions.
*
* @param entryPoints Comma separated list of URLs to be configured as
* entry points.
*/
public void setEntryPoints(String entryPoints) {
String values[] = entryPoints.split(",");
for (String value : values) {
this.entryPoints.add(value.trim());
}
}
/**
* Sets the number of previously issued nonces that will be cached on a LRU
* basis to support parallel requests, limited use of the refresh and back
* in the browser and similar behaviors that may result in the submission
* of a previous nonce rather than the current one. If not set, the default
* value of 5 will be used.
*
* @param nonceCacheSize The number of nonces to cache
*/
public void setNonceCacheSize(int nonceCacheSize) {
this.nonceCacheSize = nonceCacheSize;
}
/**
* Specify the class to use to generate the nonces. Must be in instance of
* {@link Random}.
*
* @param randomClass The name of the class to use
*/
public void setRandomClass(String randomClass) {
this.randomClass = randomClass;
}
@Override
public void init(FilterConfig filterConfig) throws ServletException {
// Set the parameters
super.init(filterConfig);
String msg = MessageFormat.format(rb.getString(LogFacade.UNABLE_CREATE_RANDOM_SOURCE_EXCEPTION), randomClass);
try {
Class<?> clazz = Class.forName(randomClass);
randomSource = (Random) clazz.newInstance();
} catch (ClassNotFoundException e) {
ServletException se = new ServletException(msg, e);
throw se;
} catch (InstantiationException e) {
ServletException se = new ServletException(msg, e);
throw se;
} catch (IllegalAccessException e) {
ServletException se = new ServletException(msg, e);
throw se;
}
}
@Override
public void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {
ServletResponse wResponse = null;
if (request instanceof HttpServletRequest &&
response instanceof HttpServletResponse) {
HttpServletRequest req = (HttpServletRequest) request;
HttpServletResponse res = (HttpServletResponse) response;
boolean skipNonceCheck = false;
if (Constants.METHOD_GET.equals(req.getMethod())) {
String path = req.getServletPath();
if (req.getPathInfo() != null) {
path = path + req.getPathInfo();
}
if (entryPoints.contains(path)) {
skipNonceCheck = true;
}
}
HttpSession session = req.getSession(false);
@SuppressWarnings("unchecked")
LruCache<String> nonceCache = (session == null) ? null
: (LruCache<String>) session.getAttribute(
Constants.CSRF_NONCE_SESSION_ATTR_NAME);
if (!skipNonceCheck) {
String previousNonce =
req.getParameter(Constants.CSRF_NONCE_REQUEST_PARAM);
if (nonceCache == null || previousNonce == null ||
!nonceCache.contains(previousNonce)) {
res.sendError(HttpServletResponse.SC_FORBIDDEN);
return;
}
}
if (nonceCache == null) {
nonceCache = new LruCache<String>(nonceCacheSize);
if (session == null) {
session = req.getSession(true);
}
session.setAttribute(
Constants.CSRF_NONCE_SESSION_ATTR_NAME, nonceCache);
}
String newNonce = generateNonce();
nonceCache.add(newNonce);
wResponse = new CsrfResponseWrapper(res, newNonce);
} else {
wResponse = response;
}
chain.doFilter(request, wResponse);
}
@Override
protected boolean isConfigProblemFatal() {
return true;
}
/**
* Generate a once time token (nonce) for authenticating subsequent
* requests. This will also add the token to the session. The nonce
* generation is a simplified version of ManagerBase.generateSessionId().
*
*/
protected String generateNonce() {
byte random[] = new byte[16];
// Render the result as a String of hexadecimal digits
StringBuilder buffer = new StringBuilder();
randomSource.nextBytes(random);
for (int j = 0; j < random.length; j++) {
byte b1 = (byte) ((random[j] & 0xf0) >> 4);
byte b2 = (byte) (random[j] & 0x0f);
if (b1 < 10) {
buffer.append((char) ('0' + b1));
} else {
buffer.append((char) ('A' + (b1 - 10)));
}
if (b2 < 10) {
buffer.append((char) ('0' + b2));
} else {
buffer.append((char) ('A' + (b2 - 10)));
}
}
return buffer.toString();
}
protected static class CsrfResponseWrapper
extends HttpServletResponseWrapper {
private final String nonce;
public CsrfResponseWrapper(HttpServletResponse response, String nonce) {
super(response);
this.nonce = nonce;
}
@Override
public String encodeRedirectURL(String url) {
return addNonce(super.encodeRedirectURL(url));
}
@Override
public String encodeURL(String url) {
return addNonce(super.encodeURL(url));
}
/**
* Return the specified URL with the nonce added to the query string.
*
* @param url URL to be modified
*/
private String addNonce(String url) {
if ((url == null) || (nonce == null)) {
return (url);
}
String path = url;
String query = "";
String anchor = "";
int pound = path.indexOf('#');
if (pound >= 0) {
anchor = path.substring(pound);
path = path.substring(0, pound);
}
int question = path.indexOf('?');
if (question >= 0) {
query = path.substring(question);
path = path.substring(0, question);
}
StringBuilder sb = new StringBuilder(path);
if (query.length() >0) {
sb.append(query);
sb.append('&');
} else {
sb.append('?');
}
sb.append(Constants.CSRF_NONCE_REQUEST_PARAM);
sb.append('=');
sb.append(nonce);
sb.append(anchor);
return (sb.toString());
}
}
protected static class LruCache<T> implements Serializable {
private static final long serialVersionUID = 1L;
// Although the internal implementation uses a Map, this cache
// implementation is only concerned with the keys.
private final Map<T,T> cache;
public LruCache(final int cacheSize) {
cache = new FixSizeLinkedHashMap<T,T>(cacheSize);
}
public void add(T key) {
synchronized (cache) {
cache.put(key, null);
}
}
public boolean contains(T key) {
synchronized (cache) {
return cache.containsKey(key);
}
}
}
private static class FixSizeLinkedHashMap<K, V> extends LinkedHashMap<K, V> {
private static final long serialVersionUID = 1L;
private int cacheSize;
private FixSizeLinkedHashMap(int cs) {
super();
cacheSize = cs;
}
@Override
protected boolean removeEldestEntry(Map.Entry<K,V> eldest) {
if (size() > cacheSize) {
return true;
}
return false;
}
}
}