// blob: a5dab447b6b912a468445d37e2a65c7ed9eee436 [file] [log] [blame]
//
// ========================================================================
// Copyright (c) 1995-2017 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.servlets;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.AsyncContext;
import javax.servlet.AsyncEvent;
import javax.servlet.AsyncListener;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpSessionActivationListener;
import javax.servlet.http.HttpSessionBindingEvent;
import javax.servlet.http.HttpSessionBindingListener;
import javax.servlet.http.HttpSessionEvent;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.annotation.ManagedAttribute;
import org.eclipse.jetty.util.annotation.ManagedObject;
import org.eclipse.jetty.util.annotation.ManagedOperation;
import org.eclipse.jetty.util.annotation.Name;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
import org.eclipse.jetty.util.thread.Scheduler;
/**
 * Denial of Service filter
 * <p>
 * This filter is useful for limiting
 * exposure to abuse from request flooding, whether malicious, or as a result of
 * a misconfigured client.
 * <p>
 * The filter keeps track of the number of requests from a connection per
 * second. If a limit is exceeded, the request is either rejected, delayed, or
 * throttled.
 * <p>
 * When a request is throttled, it is placed in a priority queue. Priority is
 * given first to authenticated users and users with an HttpSession, then
 * connections which can be identified by their IP addresses. Connections with
 * no way to identify them are given lowest priority.
 * <p>
 * The {@link #extractUserId(ServletRequest request)} function should be
 * implemented, in order to uniquely identify authenticated users.
 * <p>
 * The following init parameters control the behavior of the filter:
 * <dl>
 * <dt>maxRequestsPerSec</dt>
 * <dd>the maximum number of requests from a connection per
 * second. Requests in excess of this are first delayed,
 * then throttled.</dd>
 * <dt>delayMs</dt>
 * <dd>is the delay given to all requests over the rate limit,
 * before they are considered at all. -1 (or any other negative value) means just reject the request,
 * 0 means no delay, otherwise it is the delay.</dd>
 * <dt>maxWaitMs</dt>
 * <dd>how long to block waiting for the throttle semaphore.</dd>
 * <dt>throttledRequests</dt>
 * <dd>is the number of requests over the rate limit able to be
 * considered at once.</dd>
 * <dt>throttleMs</dt>
 * <dd>how long to wait asynchronously for the throttle semaphore.</dd>
 * <dt>maxRequestMs</dt>
 * <dd>how long to allow this request to run.</dd>
 * <dt>maxIdleTrackerMs</dt>
 * <dd>how long to keep track of request rates for a connection,
 * before deciding that the user has gone away, and discarding it</dd>
 * <dt>insertHeaders</dt>
 * <dd>if true, insert the DoSFilter headers into the response. Defaults to true.</dd>
 * <dt>trackSessions</dt>
 * <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd>
 * <dt>remotePort</dt>
 * <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd>
 * <dt>ipWhitelist</dt>
 * <dd>a comma-separated list of IP addresses (or CIDR ranges A.B.C.D/M) that will not be rate limited</dd>
 * <dt>enabled</dt>
 * <dd>if false, requests are passed straight down the chain without any rate limiting. Defaults to true.</dd>
 * <dt>managedAttr</dt>
 * <dd>if set to true, then this filter is set as a {@link ServletContext} attribute with the
 * filter name as the attribute name. This allows mechanisms external to the context (e.g. JMX via
 * {@link ContextHandler#MANAGED_ATTRIBUTES}) to manage the configuration of the filter.</dd>
 * <dt>tooManyCode</dt>
 * <dd>The status code to send if there are too many requests. By default is 429 (too many requests), but 503 (Unavailable) is
 * another option</dd>
 * </dl>
 * <p>
 * This filter should be configured for {@link DispatcherType#REQUEST} and {@link DispatcherType#ASYNC} and with
 * <code>&lt;async-supported&gt;true&lt;/async-supported&gt;</code>.
 * </p>
 */
@ManagedObject("limits exposure to abuse from request flooding, whether malicious, or as a result of a misconfigured client")
public class DoSFilter implements Filter
{
    private static final Logger LOG = Log.getLogger(DoSFilter.class);

    private static final String IPv4_GROUP = "(\\d{1,3})";
    private static final Pattern IPv4_PATTERN = Pattern.compile(IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP);
    private static final String IPv6_GROUP = "(\\p{XDigit}{1,4})";
    private static final Pattern IPv6_PATTERN = Pattern.compile(IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP);
    private static final Pattern CIDR_PATTERN = Pattern.compile("([^/]+)/(\\d+)");

    private static final String __TRACKER = "DoSFilter.Tracker";
    private static final String __THROTTLED = "DoSFilter.Throttled";

    private static final int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
    private static final int __DEFAULT_DELAY_MS = 100;
    private static final int __DEFAULT_THROTTLE = 5;
    private static final int __DEFAULT_MAX_WAIT_MS = 50;
    private static final long __DEFAULT_THROTTLE_MS = 30000L;
    private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L;
    private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L;

    static final String MANAGED_ATTR_INIT_PARAM = "managedAttr";
    static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
    static final String DELAY_MS_INIT_PARAM = "delayMs";
    static final String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
    static final String MAX_WAIT_INIT_PARAM = "maxWaitMs";
    static final String THROTTLE_MS_INIT_PARAM = "throttleMs";
    static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs";
    static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs";
    static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders";
    static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions";
    static final String REMOTE_PORT_INIT_PARAM = "remotePort";
    static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist";
    static final String ENABLED_INIT_PARAM = "enabled";
    static final String TOO_MANY_CODE = "tooManyCode";

    // Tracker types double as queue priorities (higher value = served first).
    // Authenticated users and session users deliberately share the top priority.
    private static final int USER_AUTH = 2;
    private static final int USER_SESSION = 2;
    private static final int USER_IP = 1;
    private static final int USER_UNKNOWN = 0;

    // Request attribute names are unique per filter instance so that several
    // DoSFilters in the same chain do not interfere with each other.
    private final String _suspended = "DoSFilter@" + Integer.toHexString(hashCode()) + ".SUSPENDED";
    private final String _resumed = "DoSFilter@" + Integer.toHexString(hashCode()) + ".RESUMED";
    private final ConcurrentHashMap<String, RateTracker> _rateTrackers = new ConcurrentHashMap<>();
    private final List<String> _whitelist = new CopyOnWriteArrayList<>();
    private volatile int _tooManyCode;
    private volatile long _delayMs;
    private volatile long _throttleMs;
    private volatile long _maxWaitMs;
    private volatile long _maxRequestMs;
    private volatile long _maxIdleTrackerMs;
    private volatile boolean _insertHeaders;
    private volatile boolean _trackSessions;
    private volatile boolean _remotePort;
    private volatile boolean _enabled;
    // Replaced by setThrottledRequests() (possibly from a management thread)
    // while request threads acquire/release it, hence volatile.
    private volatile Semaphore _passes;
    private volatile int _throttledRequests;
    private volatile int _maxRequestsPerSec;
    private Queue<AsyncContext>[] _queues;
    private AsyncListener[] _listeners;
    private Scheduler _scheduler;

    /**
     * Creates the per-priority throttle queues, reads all init parameters
     * (falling back to the defaults), starts the scheduler and optionally
     * publishes this filter as a context attribute.
     *
     * @param filterConfig the filter configuration
     * @throws ServletException if the scheduler cannot be started
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException
    {
        _queues = new Queue[getMaxPriority() + 1];
        _listeners = new AsyncListener[_queues.length];
        for (int p = 0; p < _queues.length; p++)
        {
            _queues[p] = new ConcurrentLinkedQueue<>();
            _listeners[p] = new DoSAsyncListener(p);
        }

        _rateTrackers.clear();

        int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC;
        String parameter = filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM);
        if (parameter != null)
            maxRequests = Integer.parseInt(parameter);
        setMaxRequestsPerSec(maxRequests);

        long delay = __DEFAULT_DELAY_MS;
        parameter = filterConfig.getInitParameter(DELAY_MS_INIT_PARAM);
        if (parameter != null)
            delay = Long.parseLong(parameter);
        setDelayMs(delay);

        int throttledRequests = __DEFAULT_THROTTLE;
        parameter = filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM);
        if (parameter != null)
            throttledRequests = Integer.parseInt(parameter);
        setThrottledRequests(throttledRequests);

        long maxWait = __DEFAULT_MAX_WAIT_MS;
        parameter = filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM);
        if (parameter != null)
            maxWait = Long.parseLong(parameter);
        setMaxWaitMs(maxWait);

        long throttle = __DEFAULT_THROTTLE_MS;
        parameter = filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM);
        if (parameter != null)
            throttle = Long.parseLong(parameter);
        setThrottleMs(throttle);

        long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
        parameter = filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM);
        if (parameter != null)
            maxRequestMs = Long.parseLong(parameter);
        setMaxRequestMs(maxRequestMs);

        long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
        parameter = filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM);
        if (parameter != null)
            maxIdleTrackerMs = Long.parseLong(parameter);
        setMaxIdleTrackerMs(maxIdleTrackerMs);

        String whiteList = "";
        parameter = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
        if (parameter != null)
            whiteList = parameter;
        setWhitelist(whiteList);

        parameter = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
        setInsertHeaders(parameter == null || Boolean.parseBoolean(parameter));

        parameter = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
        setTrackSessions(parameter == null || Boolean.parseBoolean(parameter));

        parameter = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
        setRemotePort(parameter != null && Boolean.parseBoolean(parameter));

        parameter = filterConfig.getInitParameter(ENABLED_INIT_PARAM);
        setEnabled(parameter == null || Boolean.parseBoolean(parameter));

        parameter = filterConfig.getInitParameter(TOO_MANY_CODE);
        setTooManyCode(parameter == null ? 429 : Integer.parseInt(parameter));

        _scheduler = startScheduler();

        ServletContext context = filterConfig.getServletContext();
        if (context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
            context.setAttribute(filterConfig.getFilterName(), this);
    }

    /**
     * Creates and starts the scheduler used for request timeouts and
     * idle-tracker expiry.
     *
     * @return the started scheduler
     * @throws ServletException wrapping any failure to start the scheduler
     */
    protected Scheduler startScheduler() throws ServletException
    {
        try
        {
            Scheduler result = new ScheduledExecutorScheduler();
            result.start();
            return result;
        }
        catch (Exception x)
        {
            throw new ServletException(x);
        }
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException
    {
        doFilter((HttpServletRequest)request, (HttpServletResponse)response, filterChain);
    }

    /**
     * Applies rate limiting to the request: requests under the limit pass
     * straight through; requests over the limit are rejected, delayed or
     * throttled according to {@link #getDelayMs()}.
     *
     * @param request the current request
     * @param response the current response
     * @param filterChain the remaining filter chain
     * @throws IOException if sending an error or invoking the chain fails
     * @throws ServletException if invoking the chain fails
     */
    protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException
    {
        if (!isEnabled())
        {
            filterChain.doFilter(request, response);
            return;
        }

        // Look for the rate tracker for this request.
        RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
        if (tracker == null)
        {
            // This is the first time we have seen this request.
            if (LOG.isDebugEnabled())
                LOG.debug("Filtering {}", request);

            // Get a rate tracker associated with this request, and record one hit.
            tracker = getRateTracker(request);

            // Calculate the rate and check it is over the allowed limit
            final boolean overRateLimit = tracker.isRateExceeded(System.currentTimeMillis());

            // Pass it through if we are not currently over the rate limit.
            if (!overRateLimit)
            {
                if (LOG.isDebugEnabled())
                    LOG.debug("Allowing {}", request);
                doFilterChain(filterChain, request, response);
                return;
            }

            // We are over the limit.
            // So either reject it, delay it or throttle it.
            // Compare the long directly: casting to int (as a switch would
            // require) truncates large delays and could alias them to -1 or 0.
            long delayMs = getDelayMs();
            boolean insertHeaders = isInsertHeaders();
            if (delayMs < 0)
            {
                // Reject this request.
                LOG.warn("DOS ALERT: Request rejected ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
                if (insertHeaders)
                    response.addHeader("DoSFilter", "unavailable");
                response.sendError(getTooManyCode());
                return;
            }
            else if (delayMs == 0)
            {
                // Fall through to throttle the request.
                LOG.warn("DOS ALERT: Request throttled ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
                request.setAttribute(__TRACKER, tracker);
            }
            else
            {
                // Insert a delay before throttling the request,
                // using the suspend+timeout mechanism of AsyncContext.
                // On timeout the request is re-dispatched and, because the
                // tracker attribute is now set, goes straight to throttling.
                LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, session={}, user={}", delayMs, request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
                if (insertHeaders)
                    response.addHeader("DoSFilter", "delayed");
                request.setAttribute(__TRACKER, tracker);
                AsyncContext asyncContext = request.startAsync();
                asyncContext.setTimeout(delayMs);
                asyncContext.addListener(new DoSTimeoutAsyncListener());
                return;
            }
        }

        if (LOG.isDebugEnabled())
            LOG.debug("Throttling {}", request);

        // Throttle the request.
        boolean accepted = false;
        try
        {
            // Check if we can afford to accept another request at this time.
            accepted = _passes.tryAcquire(getMaxWaitMs(), TimeUnit.MILLISECONDS);

            if (!accepted)
            {
                // We were not accepted, so either we suspend to wait,
                // or if we were woken up we insist or we fail.
                Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
                long throttleMs = getThrottleMs();
                if (!Boolean.TRUE.equals(throttled) && throttleMs > 0)
                {
                    int priority = getPriority(request, tracker);
                    request.setAttribute(__THROTTLED, Boolean.TRUE);
                    if (isInsertHeaders())
                        response.addHeader("DoSFilter", "throttled");
                    AsyncContext asyncContext = request.startAsync();
                    request.setAttribute(_suspended, Boolean.TRUE);
                    asyncContext.setTimeout(throttleMs);
                    asyncContext.addListener(_listeners[priority]);
                    _queues[priority].add(asyncContext);
                    if (LOG.isDebugEnabled())
                        LOG.debug("Throttled {}, {}ms", request, throttleMs);
                    return;
                }

                Boolean resumed = (Boolean)request.getAttribute(_resumed);
                if (Boolean.TRUE.equals(resumed))
                {
                    // We were resumed, we wait for the next pass.
                    _passes.acquire();
                    accepted = true;
                }
            }

            // If we were accepted (either immediately or after throttle)...
            if (accepted)
            {
                // ...call the chain.
                if (LOG.isDebugEnabled())
                    LOG.debug("Allowing {}", request);
                doFilterChain(filterChain, request, response);
            }
            else
            {
                // ...otherwise fail the request.
                if (LOG.isDebugEnabled())
                    LOG.debug("Rejecting {}", request);
                if (isInsertHeaders())
                    response.addHeader("DoSFilter", "unavailable");
                response.sendError(getTooManyCode());
            }
        }
        catch (InterruptedException e)
        {
            LOG.ignore(e);
            response.sendError(getTooManyCode());
        }
        finally
        {
            if (accepted)
            {
                try
                {
                    // Wake up the next highest priority request.
                    for (int p = _queues.length - 1; p >= 0; --p)
                    {
                        AsyncContext asyncContext = _queues[p].poll();
                        if (asyncContext != null)
                        {
                            ServletRequest candidate = asyncContext.getRequest();
                            Boolean suspended = (Boolean)candidate.getAttribute(_suspended);
                            if (Boolean.TRUE.equals(suspended))
                            {
                                if (LOG.isDebugEnabled())
                                    LOG.debug("Resuming {}", candidate);
                                candidate.setAttribute(_resumed, Boolean.TRUE);
                                asyncContext.dispatch();
                                break;
                            }
                        }
                    }
                }
                finally
                {
                    // Always read the field here (not a cached local): if the
                    // semaphore was replaced by setThrottledRequests() while this
                    // request held a permit, the permit must go to the new one.
                    _passes.release();
                }
            }
        }
    }

    /**
     * Invokes the chain, scheduling {@link #closeConnection(HttpServletRequest, HttpServletResponse, Thread)}
     * to fire if the request takes longer than {@link #getMaxRequestMs()}.
     *
     * @param chain the filter chain
     * @param request the current request
     * @param response the current response
     * @throws IOException if the chain fails
     * @throws ServletException if the chain fails
     */
    protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) throws IOException, ServletException
    {
        final Thread thread = Thread.currentThread();
        Runnable requestTimeout = new Runnable()
        {
            @Override
            public void run()
            {
                closeConnection(request, response, thread);
            }
        };
        Scheduler.Task task = _scheduler.schedule(requestTimeout, getMaxRequestMs(), TimeUnit.MILLISECONDS);
        try
        {
            chain.doFilter(request, response);
        }
        finally
        {
            task.cancel();
        }
    }

    /**
     * Invoked when the request handling exceeds {@link #getMaxRequestMs()}.
     * <p>
     * By default, a HTTP 503 response is returned and the handling thread is interrupted.
     *
     * @param request the current request
     * @param response the current response
     * @param handlingThread the handling thread
     */
    protected void onRequestTimeout(HttpServletRequest request, HttpServletResponse response, Thread handlingThread)
    {
        try
        {
            if (LOG.isDebugEnabled())
                LOG.debug("Timing out {}", request);
            response.sendError(HttpStatus.SERVICE_UNAVAILABLE_503);
        }
        catch (Throwable x)
        {
            LOG.info(x);
        }

        handlingThread.interrupt();
    }

    /**
     * @param request the current request
     * @param response the current response
     * @param thread the handling thread
     * @deprecated use {@link #onRequestTimeout(HttpServletRequest, HttpServletResponse, Thread)} instead
     */
    @Deprecated
    protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
    {
        onRequestTimeout(request, response, thread);
    }

    /**
     * Get priority for this request, based on user type
     *
     * @param request the current request
     * @param tracker the rate tracker for this request
     * @return the priority for this request: the authenticated-user priority if
     * {@link #extractUserId(ServletRequest)} returns non-null, otherwise the tracker's
     * type, or the lowest priority if there is no tracker
     */
    protected int getPriority(HttpServletRequest request, RateTracker tracker)
    {
        if (extractUserId(request) != null)
            return USER_AUTH;
        if (tracker != null)
            return tracker.getType();
        return USER_UNKNOWN;
    }

    /**
     * @return the maximum priority that we can assign to a request
     */
    protected int getMaxPriority()
    {
        return USER_AUTH;
    }

    /**
     * Return a request rate tracker associated with this connection; keeps
     * track of this connection's request rate. If this is not the first request
     * from this connection, return the existing object with the stored stats.
     * If it is the first request, then create a new request tracker.
     * <p>
     * Assumes that each connection has an identifying characteristic, and goes
     * through them in order, taking the first that matches: user id (logged
     * in), session id, client IP address. Unidentifiable connections are lumped
     * into one.
     * <p>
     * A tracker stored in an HttpSession is removed when the session unbinds it;
     * any other tracker is expired by the scheduler once it has been idle.
     *
     * @param request the current request
     * @return the request rate tracker for the current connection
     */
    public RateTracker getRateTracker(ServletRequest request)
    {
        HttpSession session = ((HttpServletRequest)request).getSession(false);

        String loadId = extractUserId(request);
        final int type;
        if (loadId != null)
        {
            type = USER_AUTH;
        }
        else if (isTrackSessions() && session != null && !session.isNew())
        {
            loadId = session.getId();
            type = USER_SESSION;
        }
        else
        {
            loadId = isRemotePort() ? (request.getRemoteAddr() + request.getRemotePort()) : request.getRemoteAddr();
            type = USER_IP;
        }

        RateTracker tracker = _rateTrackers.get(loadId);
        if (tracker != null)
            return tracker;

        boolean allowed = checkWhitelist(request.getRemoteAddr());
        int maxRequestsPerSec = getMaxRequestsPerSec();
        RateTracker created = allowed ? new FixedRateTracker(loadId, type, maxRequestsPerSec)
            : new RateTracker(loadId, type, maxRequestsPerSec);

        RateTracker existing = _rateTrackers.putIfAbsent(loadId, created);
        if (existing != null)
        {
            // Another thread registered a tracker for this id first and has
            // already arranged its expiry; scheduling or binding it again here
            // would create duplicate expiry tasks.
            return existing;
        }

        if (type != USER_IP && session != null)
        {
            // Expiration from _rateTrackers is handled by the HttpSessionBindingListener.
            session.setAttribute(__TRACKER, created);
        }
        else
        {
            // Not bound to any session (IP-tracked, or an authenticated user without
            // a session): expiration from _rateTrackers is handled by the _scheduler,
            // otherwise the tracker would never be removed.
            _scheduler.schedule(created, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
        }
        return created;
    }

    /**
     * Checks whether the given address is covered by the configured whitelist.
     *
     * @param candidate the remote address to check
     * @return true if the address matches a whitelisted address or CIDR range
     */
    protected boolean checkWhitelist(String candidate)
    {
        return isWhitelisted(_whitelist, candidate);
    }

    /**
     * @param whitelist the whitelist entries to check against
     * @param candidate the remote address to check
     * @return true if the address matches one of the entries
     * @deprecated use {@link #checkWhitelist(String)} instead
     */
    @Deprecated
    protected boolean checkWhitelist(List<String> whitelist, String candidate)
    {
        return isWhitelisted(whitelist, candidate);
    }

    // Shared matching loop: entries containing '/' are CIDR ranges, others exact addresses.
    private boolean isWhitelisted(List<String> whitelist, String candidate)
    {
        for (String address : whitelist)
        {
            boolean matches = address.contains("/") ? subnetMatch(address, candidate) : address.equals(candidate);
            if (matches)
                return true;
        }
        return false;
    }

    /**
     * Checks whether an address falls inside a CIDR range.
     *
     * @param subnetAddress the range in A.B.C.D/M (or full 8-group IPv6/M) form
     * @param address the address to test
     * @return true if the address is in the range; false if it is not, if the
     * families differ, or if either value is malformed
     */
    protected boolean subnetMatch(String subnetAddress, String address)
    {
        Matcher cidrMatcher = CIDR_PATTERN.matcher(subnetAddress);
        if (!cidrMatcher.matches())
            return false;

        String subnet = cidrMatcher.group(1);
        int prefix;
        try
        {
            prefix = Integer.parseInt(cidrMatcher.group(2));
        }
        catch (NumberFormatException x)
        {
            LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
            return false;
        }

        byte[] subnetBytes = addressToBytes(subnet);
        if (subnetBytes == null)
        {
            LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
            return false;
        }

        // A prefix wider than the address (e.g. /40 for IPv4) is malformed;
        // without this check prefixToBytes would index past the mask array.
        if (prefix > subnetBytes.length * 8)
        {
            LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
            return false;
        }

        byte[] addressBytes = addressToBytes(address);
        if (addressBytes == null)
        {
            LOG.info("Ignoring malformed remote address {}", address);
            return false;
        }

        // Comparing IPv4 with IPv6 ?
        int length = subnetBytes.length;
        if (length != addressBytes.length)
            return false;

        byte[] mask = prefixToBytes(prefix, length);

        for (int i = 0; i < length; ++i)
        {
            if ((subnetBytes[i] & mask[i]) != (addressBytes[i] & mask[i]))
                return false;
        }

        return true;
    }

    // Parses a dotted-quad IPv4 or full (uncompressed) IPv6 address; returns null if neither.
    private byte[] addressToBytes(String address)
    {
        Matcher ipv4Matcher = IPv4_PATTERN.matcher(address);
        if (ipv4Matcher.matches())
        {
            byte[] result = new byte[4];
            for (int i = 0; i < result.length; ++i)
            {
                int octet = Integer.parseInt(ipv4Matcher.group(i + 1));
                // The pattern allows up to 3 digits; values above 255 are not valid octets.
                if (octet > 255)
                    return null;
                result[i] = (byte)octet;
            }
            return result;
        }

        Matcher ipv6Matcher = IPv6_PATTERN.matcher(address);
        if (ipv6Matcher.matches())
        {
            byte[] result = new byte[16];
            for (int i = 0; i < result.length; i += 2)
            {
                int word = Integer.parseInt(ipv6Matcher.group(i / 2 + 1), 16);
                result[i] = (byte)((word & 0xFF00) >>> 8);
                result[i + 1] = (byte)(word & 0xFF);
            }
            return result;
        }

        return null;
    }

    // Builds a network mask of 'length' bytes with the 'prefix' most significant bits set.
    private byte[] prefixToBytes(int prefix, int length)
    {
        byte[] result = new byte[length];
        int index = 0;
        while (prefix >= 8 && index < length)
        {
            result[index] = -1;
            prefix -= 8;
            ++index;
        }

        // Sets the remaining _prefix_ most significant bits of the next byte to 1
        if (index < length && prefix > 0)
            result[index] = (byte)~((1 << (8 - prefix)) - 1);
        return result;
    }

    @Override
    public void destroy()
    {
        if (LOG.isDebugEnabled())
            LOG.debug("Destroy {}", this);
        stopScheduler();
        _rateTrackers.clear();
        _whitelist.clear();
    }

    /**
     * Stops the scheduler created by {@link #startScheduler()}; failures are ignored.
     */
    protected void stopScheduler()
    {
        try
        {
            _scheduler.stop();
        }
        catch (Exception x)
        {
            LOG.ignore(x);
        }
    }

    /**
     * Returns the user id, used to track this connection.
     * This SHOULD be overridden by subclasses.
     *
     * @param request the current request
     * @return a unique user id, if logged in; otherwise null.
     */
    protected String extractUserId(ServletRequest request)
    {
        return null;
    }

    /**
     * Get maximum number of requests from a connection per
     * second. Requests in excess of this are first delayed,
     * then throttled.
     *
     * @return maximum number of requests
     */
    @ManagedAttribute("maximum number of requests allowed from a connection per second")
    public int getMaxRequestsPerSec()
    {
        return _maxRequestsPerSec;
    }

    /**
     * Set maximum number of requests from a connection per
     * second. Requests in excess of this are first delayed,
     * then throttled. The value only applies to trackers created
     * after the change; a value of 0 or less treats every request as
     * over the limit.
     *
     * @param value maximum number of requests
     */
    public void setMaxRequestsPerSec(int value)
    {
        _maxRequestsPerSec = value;
    }

    /**
     * Get delay (in milliseconds) that is applied to all requests
     * over the rate limit, before they are considered at all.
     *
     * @return the delay in milliseconds
     */
    @ManagedAttribute("delay applied to all requests over the rate limit (in ms)")
    public long getDelayMs()
    {
        return _delayMs;
    }

    /**
     * Set delay (in milliseconds) that is applied to all requests
     * over the rate limit, before they are considered at all.
     *
     * @param value delay (in milliseconds), 0 - no delay, -1 (or any negative value) - reject request
     */
    public void setDelayMs(long value)
    {
        _delayMs = value;
    }

    /**
     * Get maximum amount of time (in milliseconds) the filter will
     * block waiting for the throttle semaphore.
     *
     * @return maximum wait time
     */
    @ManagedAttribute("maximum time the filter will block waiting for the throttle semaphore (in ms)")
    public long getMaxWaitMs()
    {
        return _maxWaitMs;
    }

    /**
     * Set maximum amount of time (in milliseconds) the filter will
     * block waiting for the throttle semaphore.
     *
     * @param value maximum wait time
     */
    public void setMaxWaitMs(long value)
    {
        _maxWaitMs = value;
    }

    /**
     * Get number of requests over the rate limit able to be
     * considered at once.
     *
     * @return number of requests
     */
    @ManagedAttribute("number of requests over rate limit")
    public int getThrottledRequests()
    {
        return _throttledRequests;
    }

    /**
     * Set number of requests over the rate limit able to be
     * considered at once.
     * <p>
     * The throttle semaphore is replaced with one whose initial permits are
     * {@code value - oldValue + currentlyAvailable}; permits still held by
     * in-flight requests are released onto the new semaphore, bringing the
     * total back to {@code value}.
     *
     * @param value number of requests
     */
    public void setThrottledRequests(int value)
    {
        int permits = _passes == null ? 0 : _passes.availablePermits();
        _passes = new Semaphore((value - _throttledRequests + permits), true);
        _throttledRequests = value;
    }

    /**
     * Get amount of time (in milliseconds) to async wait for semaphore.
     *
     * @return wait time
     */
    @ManagedAttribute("amount of time to async wait for semaphore")
    public long getThrottleMs()
    {
        return _throttleMs;
    }

    /**
     * Set amount of time (in milliseconds) to async wait for semaphore.
     *
     * @param value wait time
     */
    public void setThrottleMs(long value)
    {
        _throttleMs = value;
    }

    /**
     * Get maximum amount of time (in milliseconds) to allow
     * the request to process.
     *
     * @return maximum processing time
     */
    @ManagedAttribute("maximum time to allow requests to process (in ms)")
    public long getMaxRequestMs()
    {
        return _maxRequestMs;
    }

    /**
     * Set maximum amount of time (in milliseconds) to allow
     * the request to process.
     *
     * @param value maximum processing time
     */
    public void setMaxRequestMs(long value)
    {
        _maxRequestMs = value;
    }

    /**
     * Get maximum amount of time (in milliseconds) to keep track
     * of request rates for a connection, before deciding that
     * the user has gone away, and discarding it.
     *
     * @return maximum tracking time
     */
    @ManagedAttribute("maximum time to track request rates for a connection before discarding")
    public long getMaxIdleTrackerMs()
    {
        return _maxIdleTrackerMs;
    }

    /**
     * Set maximum amount of time (in milliseconds) to keep track
     * of request rates for a connection, before deciding that
     * the user has gone away, and discarding it.
     *
     * @param value maximum tracking time
     */
    public void setMaxIdleTrackerMs(long value)
    {
        _maxIdleTrackerMs = value;
    }

    /**
     * Check flag to insert the DoSFilter headers into the response.
     *
     * @return value of the flag
     */
    @ManagedAttribute("insert DoSFilter headers in response")
    public boolean isInsertHeaders()
    {
        return _insertHeaders;
    }

    /**
     * Set flag to insert the DoSFilter headers into the response.
     *
     * @param value value of the flag
     */
    public void setInsertHeaders(boolean value)
    {
        _insertHeaders = value;
    }

    /**
     * Get flag to have usage rate tracked by session if a session exists.
     *
     * @return value of the flag
     */
    @ManagedAttribute("usage rate is tracked by session if one exists")
    public boolean isTrackSessions()
    {
        return _trackSessions;
    }

    /**
     * Set flag to have usage rate tracked by session if a session exists.
     *
     * @param value value of the flag
     */
    public void setTrackSessions(boolean value)
    {
        _trackSessions = value;
    }

    /**
     * Get flag to have usage rate tracked by IP+port (effectively connection)
     * if session tracking is not used.
     *
     * @return value of the flag
     */
    @ManagedAttribute("usage rate is tracked by IP+port if session tracking not used")
    public boolean isRemotePort()
    {
        return _remotePort;
    }

    /**
     * Set flag to have usage rate tracked by IP+port (effectively connection)
     * if session tracking is not used.
     *
     * @param value value of the flag
     */
    public void setRemotePort(boolean value)
    {
        _remotePort = value;
    }

    /**
     * @return whether this filter is enabled
     */
    @ManagedAttribute("whether this filter is enabled")
    public boolean isEnabled()
    {
        return _enabled;
    }

    /**
     * @param enabled whether this filter is enabled
     */
    public void setEnabled(boolean enabled)
    {
        _enabled = enabled;
    }

    /**
     * @return the HTTP status code sent when a request is rejected for exceeding the limits
     */
    public int getTooManyCode()
    {
        return _tooManyCode;
    }

    /**
     * @param tooManyCode the HTTP status code to send when a request is rejected
     * for exceeding the limits (e.g. 429 or 503)
     */
    public void setTooManyCode(int tooManyCode)
    {
        _tooManyCode = tooManyCode;
    }

    /**
     * Get a list of IP addresses that will not be rate limited.
     *
     * @return comma-separated whitelist
     */
    @ManagedAttribute("list of IPs that will not be rate limited")
    public String getWhitelist()
    {
        StringBuilder result = new StringBuilder();
        for (Iterator<String> iterator = _whitelist.iterator(); iterator.hasNext();)
        {
            String address = iterator.next();
            result.append(address);
            if (iterator.hasNext())
                result.append(",");
        }
        return result.toString();
    }

    /**
     * Set a list of IP addresses that will not be rate limited.
     *
     * @param commaSeparatedList comma-separated whitelist
     */
    public void setWhitelist(String commaSeparatedList)
    {
        List<String> result = new ArrayList<>();
        for (String address : StringUtil.csvSplit(commaSeparatedList))
            addWhitelistAddress(result, address);
        clearWhitelist();
        _whitelist.addAll(result);
        LOG.debug("Whitelisted IP addresses: {}", result);
    }

    /**
     * Clears the list of whitelisted IP addresses
     */
    @ManagedOperation("clears the list of IP addresses that will not be rate limited")
    public void clearWhitelist()
    {
        _whitelist.clear();
    }

    /**
     * Adds the given IP address, either in the form of a dotted decimal notation A.B.C.D
     * or in the CIDR notation A.B.C.D/M, to the list of whitelisted IP addresses.
     *
     * @param address the address to add
     * @return whether the address was added to the list
     * @see #removeWhitelistAddress(String)
     */
    @ManagedOperation("adds an IP address that will not be rate limited")
    public boolean addWhitelistAddress(@Name("address") String address)
    {
        return addWhitelistAddress(_whitelist, address);
    }

    // Adds the trimmed address to the list, ignoring blank entries.
    private boolean addWhitelistAddress(List<String> list, String address)
    {
        address = address.trim();
        return address.length() > 0 && list.add(address);
    }

    /**
     * Removes the given address from the list of whitelisted IP addresses.
     *
     * @param address the address to remove
     * @return whether the address was removed from the list
     * @see #addWhitelistAddress(String)
     */
    @ManagedOperation("removes an IP address that will not be rate limited")
    public boolean removeWhitelistAddress(@Name("address") String address)
    {
        return _whitelist.remove(address);
    }

    /**
     * A RateTracker is associated with a connection, and stores request rate
     * data: a ring buffer of the timestamps of the last N requests, where N is
     * the maximum number of requests per second.
     * <p>
     * NOTE(review): this is a non-static inner class, so Java serialization
     * would also try to serialize the enclosing DoSFilter, which does not
     * implement Serializable; persisting sessions that hold this attribute
     * may therefore fail — confirm before relying on session passivation.
     */
    class RateTracker implements Runnable, HttpSessionBindingListener, HttpSessionActivationListener, Serializable
    {
        private static final long serialVersionUID = 3534663738034577872L;

        protected final String _id;
        protected final int _type;
        protected final long[] _timestamps;

        protected int _next;

        public RateTracker(String id, int type, int maxRequestsPerSecond)
        {
            _id = id;
            _type = type;
            // A non-positive limit yields an empty buffer (every request exceeds the rate)
            // instead of a NegativeArraySizeException.
            _timestamps = new long[Math.max(0, maxRequestsPerSecond)];
            _next = 0;
        }

        /**
         * Records a request at {@code now} and checks the rate.
         *
         * @param now the time now (in milliseconds)
         * @return true if the oldest of the last N recorded requests happened less than
         * one second before {@code now} (i.e. more than N requests in the last second);
         * always true when the limit is 0 or less
         */
        public boolean isRateExceeded(long now)
        {
            // No requests at all are allowed; also avoids indexing an empty array and % 0.
            if (_timestamps.length == 0)
                return true;

            final long last;
            synchronized (this)
            {
                last = _timestamps[_next];
                _timestamps[_next] = now;
                _next = (_next + 1) % _timestamps.length;
            }

            return last != 0 && (now - last) < 1000L;
        }

        public String getId()
        {
            return _id;
        }

        public int getType()
        {
            return _type;
        }

        @Override
        public void valueBound(HttpSessionBindingEvent event)
        {
            if (LOG.isDebugEnabled())
                LOG.debug("Value bound: {}", getId());
        }

        @Override
        public void valueUnbound(HttpSessionBindingEvent event)
        {
            // Take the tracker out of the list of trackers, but only if it is still
            // this instance (a newer tracker may have been registered under the same id).
            _rateTrackers.remove(_id, this);
            if (LOG.isDebugEnabled())
                LOG.debug("Tracker removed: {}", getId());
        }

        @Override
        public void sessionWillPassivate(HttpSessionEvent se)
        {
            // Take this tracker out of the list of trackers (if it's still there).
            _rateTrackers.remove(_id, this);
        }

        @Override
        public void sessionDidActivate(HttpSessionEvent se)
        {
            RateTracker tracker = (RateTracker)se.getSession().getAttribute(__TRACKER);
            if (tracker != null)
                _rateTrackers.put(tracker.getId(), tracker);
        }

        /**
         * Scheduled expiry check: reschedules itself if there was a request in the
         * last second, otherwise removes this tracker from the tracker map.
         */
        @Override
        public void run()
        {
            final long last;
            // Read under the same lock isRateExceeded() writes under.
            synchronized (this)
            {
                if (_timestamps.length == 0)
                {
                    last = 0;
                }
                else
                {
                    int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1);
                    last = _timestamps[latestIndex];
                }
            }
            boolean hasRecentRequest = last != 0 && (System.currentTimeMillis() - last) < 1000L;

            if (hasRecentRequest)
                _scheduler.schedule(this, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
            else
                _rateTrackers.remove(_id, this);
        }

        @Override
        public String toString()
        {
            return "RateTracker/" + _id + "/" + _type;
        }
    }

    /**
     * Tracker used for whitelisted addresses: it records activity (so idle
     * expiry still works) but never reports the rate as exceeded.
     */
    class FixedRateTracker extends RateTracker
    {
        public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
        {
            super(id, type, numRecentRequestsTracked);
        }

        @Override
        public boolean isRateExceeded(long now)
        {
            // rate limit is never exceeded, but we keep track of the request timestamps
            // so that we know whether there was recent activity on this tracker
            // and whether it should be expired
            if (_timestamps.length > 0)
            {
                synchronized (this)
                {
                    _timestamps[_next] = now;
                    _next = (_next + 1) % _timestamps.length;
                }
            }

            return false;
        }

        @Override
        public String toString()
        {
            return "Fixed" + super.toString();
        }
    }

    /**
     * Re-dispatches a suspended request when its async timeout expires.
     */
    private class DoSTimeoutAsyncListener implements AsyncListener
    {
        @Override
        public void onStartAsync(AsyncEvent event) throws IOException
        {
        }

        @Override
        public void onComplete(AsyncEvent event) throws IOException
        {
        }

        @Override
        public void onTimeout(AsyncEvent event) throws IOException
        {
            event.getAsyncContext().dispatch();
        }

        @Override
        public void onError(AsyncEvent event) throws IOException
        {
        }
    }

    /**
     * Timeout listener for throttled requests: removes the request from its
     * priority queue before re-dispatching it.
     */
    private class DoSAsyncListener extends DoSTimeoutAsyncListener
    {
        private final int priority;

        public DoSAsyncListener(int priority)
        {
            this.priority = priority;
        }

        @Override
        public void onTimeout(AsyncEvent event) throws IOException
        {
            _queues[priority].remove(event.getAsyncContext());
            super.onTimeout(event);
        }
    }
}