// blob: 2e5dd35738d2f2d0b54d3e63ae15fd45bdb3a59a
//
// ========================================================================
// Copyright (c) 1995-2017 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.proxy;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import javax.servlet.AsyncContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpHeaderValue;
import org.eclipse.jetty.http.HttpMethod;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.io.ManagedSelector;
import org.eclipse.jetty.io.MappedByteBufferPool;
import org.eclipse.jetty.io.SelectChannelEndPoint;
import org.eclipse.jetty.io.SelectorManager;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.HttpConnection;
import org.eclipse.jetty.server.HttpTransport;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.handler.HandlerWrapper;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.HostPort;
import org.eclipse.jetty.util.Promise;
import org.eclipse.jetty.util.TypeUtil;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
import org.eclipse.jetty.util.thread.Scheduler;
/**
 * <p>Implementation of a {@link Handler} that supports HTTP CONNECT.</p>
 * <p>A CONNECT request is first authenticated via
 * {@link #handleAuthentication(HttpServletRequest, HttpServletResponse, String)} and checked against the
 * white/black lists via {@link #validateDestination(String, int)}. A non-blocking connection is then opened
 * to the remote server. Once it is open, the client's connection is upgraded to a {@link DownstreamConnection}
 * that is paired with the {@link UpstreamConnection} to the server.</p>
 * <p>Requests with any other method are forwarded to the wrapped handler.</p>
 */
public class ConnectHandler extends HandlerWrapper
{
    protected static final Logger LOG = Log.getLogger(ConnectHandler.class);

    // Entries are compared against "host:port" strings in validateDestination().
    private final Set<String> whiteList = new HashSet<>();
    private final Set<String> blackList = new HashSet<>();
    private Executor executor;
    private Scheduler scheduler;
    private ByteBufferPool bufferPool;
    private SelectorManager selector;
    private long connectTimeout = 15000;
    private long idleTimeout = 30000;
    private int bufferSize = 4096;

    /**
     * Creates a handler with no wrapped handler.
     */
    public ConnectHandler()
    {
        this(null);
    }

    /**
     * @param handler the handler to forward non-CONNECT requests to (may be null)
     */
    public ConnectHandler(Handler handler)
    {
        setHandler(handler);
    }

    /**
     * @return the executor used by the selector manager and the tunnel connections;
     * when not set, {@link #doStart()} uses the server's thread pool
     */
    public Executor getExecutor()
    {
        return executor;
    }

    /**
     * @param executor the executor used by the selector manager and the tunnel connections
     */
    public void setExecutor(Executor executor)
    {
        this.executor = executor;
    }

    /**
     * @return the scheduler used by the selector manager and the upstream endpoints;
     * when not set, {@link #doStart()} creates one
     */
    public Scheduler getScheduler()
    {
        return scheduler;
    }

    /**
     * @param scheduler the scheduler used by the selector manager and the upstream endpoints
     */
    public void setScheduler(Scheduler scheduler)
    {
        this.scheduler = scheduler;
    }

    /**
     * @return the buffer pool used by the tunnel connections;
     * when not set, {@link #doStart()} creates one
     */
    public ByteBufferPool getByteBufferPool()
    {
        return bufferPool;
    }

    /**
     * @param bufferPool the buffer pool used by the tunnel connections
     */
    public void setByteBufferPool(ByteBufferPool bufferPool)
    {
        this.bufferPool = bufferPool;
    }

    /**
     * @return the timeout, in milliseconds, to connect to the remote server
     */
    public long getConnectTimeout()
    {
        return connectTimeout;
    }

    /**
     * <p>The value is applied to the selector manager in {@link #doStart()}.</p>
     *
     * @param connectTimeout the timeout, in milliseconds, to connect to the remote server
     */
    public void setConnectTimeout(long connectTimeout)
    {
        this.connectTimeout = connectTimeout;
    }

    /**
     * @return the idle timeout, in milliseconds, of the endpoints connected to the remote server
     */
    public long getIdleTimeout()
    {
        return idleTimeout;
    }

    /**
     * @param idleTimeout the idle timeout, in milliseconds, of the endpoints connected to the remote server
     */
    public void setIdleTimeout(long idleTimeout)
    {
        this.idleTimeout = idleTimeout;
    }

    /**
     * @return the input buffer size of the upstream and downstream tunnel connections
     */
    public int getBufferSize()
    {
        return bufferSize;
    }

    /**
     * @param bufferSize the input buffer size of the upstream and downstream tunnel connections
     */
    public void setBufferSize(int bufferSize)
    {
        this.bufferSize = bufferSize;
    }

    /**
     * <p>Fills in the components that were not configured and creates the selector manager.</p>
     * <p>The executor defaults to the server's thread pool. A scheduler and a buffer pool are created and
     * added as beans if none were set. The selector manager is created via {@link #newSelectorManager()},
     * added as a bean and given the configured connect timeout.</p>
     */
    @Override
    protected void doStart() throws Exception
    {
        if (executor == null)
            executor = getServer().getThreadPool();

        if (scheduler == null)
            addBean(scheduler = new ScheduledExecutorScheduler());

        if (bufferPool == null)
            addBean(bufferPool = new MappedByteBufferPool());

        addBean(selector = newSelectorManager());
        selector.setConnectTimeout(getConnectTimeout());

        super.doStart();
    }

    /**
     * @return the selector manager that connects to remote servers; by default a
     * {@link ConnectManager} with a single selector
     */
    protected SelectorManager newSelectorManager()
    {
        return new ConnectManager(getExecutor(), getScheduler(), 1);
    }

    /**
     * <p>Handles CONNECT requests via
     * {@link #handleConnect(Request, HttpServletRequest, HttpServletResponse, String)}, using the request URI
     * as the server address. All other requests are passed to the wrapped handler.</p>
     */
    @Override
    public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
    {
        if (HttpMethod.CONNECT.is(request.getMethod()))
        {
            String serverAddress = request.getRequestURI();
            if (LOG.isDebugEnabled())
                LOG.debug("CONNECT request for {}", serverAddress);

            handleConnect(baseRequest, request, response, serverAddress);
        }
        else
        {
            super.handle(target, baseRequest, request, response);
        }
    }

    /**
     * <p>Handles a CONNECT request.</p>
     * <p>CONNECT requests may have authentication headers such as {@code Proxy-Authorization}
     * that authenticate the client with the proxy.</p>
     * <p>The request is rejected with 407 if {@link #handleAuthentication(HttpServletRequest, HttpServletResponse, String)}
     * returns false. It is rejected with 403 if the destination is not allowed, or if the request did not
     * arrive over an {@link HttpConnection}. Otherwise the request is put in async mode with no timeout and
     * a connection to the remote server is started.</p>
     *
     * @param baseRequest Jetty-specific http request
     * @param request the http request
     * @param response the http response
     * @param serverAddress the remote server address in the form {@code host:port}
     */
    protected void handleConnect(Request baseRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress)
    {
        baseRequest.setHandled(true);
        // Remembers the async context once it has been started. A failure thrown after startAsync()
        // must still complete it; the timeout is set to 0 (no timeout), so otherwise the request
        // would stay suspended forever.
        AsyncContext startedAsyncContext = null;
        try
        {
            boolean proceed = handleAuthentication(request, response, serverAddress);
            if (!proceed)
            {
                if (LOG.isDebugEnabled())
                    LOG.debug("Missing proxy authentication");
                sendConnectResponse(request, response, HttpServletResponse.SC_PROXY_AUTHENTICATION_REQUIRED);
                return;
            }

            HostPort hostPort = new HostPort(serverAddress);
            String host = hostPort.getHost();
            int port = hostPort.getPort(80);

            if (!validateDestination(host, port))
            {
                if (LOG.isDebugEnabled())
                    LOG.debug("Destination {}:{} forbidden", host, port);
                sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
                return;
            }

            HttpTransport transport = baseRequest.getHttpChannel().getHttpTransport();
            // TODO Handle CONNECT over HTTP2!
            if (!(transport instanceof HttpConnection))
            {
                if (LOG.isDebugEnabled())
                    LOG.debug("CONNECT not supported for {}", transport);
                sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
                return;
            }

            AsyncContext asyncContext = request.startAsync();
            startedAsyncContext = asyncContext;
            asyncContext.setTimeout(0);

            if (LOG.isDebugEnabled())
                LOG.debug("Connecting to {}:{}", host, port);

            connectToServer(request, host, port, new Promise<SocketChannel>()
            {
                @Override
                public void succeeded(SocketChannel channel)
                {
                    ConnectContext connectContext = new ConnectContext(request, response, asyncContext, (HttpConnection)transport);
                    // A channel that connected immediately is registered as-is; otherwise the
                    // selector finishes the pending connect.
                    if (channel.isConnected())
                        selector.accept(channel, connectContext);
                    else
                        selector.connect(channel, connectContext);
                }

                @Override
                public void failed(Throwable x)
                {
                    onConnectFailure(request, response, asyncContext, x);
                }
            });
        }
        catch (Exception x)
        {
            onConnectFailure(request, response, startedAsyncContext, x);
        }
    }

    /**
     * <p>Opens a non-blocking {@link SocketChannel} with {@code TCP_NODELAY} enabled and starts connecting
     * it to the address returned by {@link #newConnectAddress(String, int)}.</p>
     * <p>On success the promise is succeeded with the channel, which may still be connecting. On any failure
     * the channel (if opened) is closed and the promise is failed.</p>
     *
     * @param request the CONNECT request
     * @param host the remote host
     * @param port the remote port
     * @param promise the promise notified with the channel or the failure
     */
    protected void connectToServer(HttpServletRequest request, String host, int port, Promise<SocketChannel> promise)
    {
        SocketChannel channel = null;
        try
        {
            channel = SocketChannel.open();
            channel.socket().setTcpNoDelay(true);
            channel.configureBlocking(false);

            InetSocketAddress address = newConnectAddress(host, port);
            channel.connect(address);
            promise.succeeded(channel);
        }
        catch (Throwable x)
        {
            close(channel);
            promise.failed(x);
        }
    }

    /**
     * Closes the given resource, ignoring (but logging via {@link Logger#ignore(Throwable)}) any failure.
     *
     * @param closeable the resource to close, may be null
     */
    private void close(Closeable closeable)
    {
        try
        {
            if (closeable != null)
                closeable.close();
        }
        catch (Throwable x)
        {
            LOG.ignore(x);
        }
    }

    /**
     * Creates the server address to connect to.
     *
     * @param host The host from the CONNECT request
     * @param port The port from the CONNECT request
     * @return The InetSocketAddress to connect to.
     */
    protected InetSocketAddress newConnectAddress(String host, int port)
    {
        return new InetSocketAddress(host, port);
    }

    /**
     * <p>Completes the tunnel setup once the connection to the remote server is open.</p>
     * <p>Creates a {@link DownstreamConnection} over the client's endpoint and pairs it with the given upstream
     * connection. It then sends a 200 response, upgrades the client connection to the downstream connection
     * and completes the async context.</p>
     *
     * @param connectContext the state of the CONNECT request
     * @param upstreamConnection the connection to the remote server
     */
    protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection)
    {
        ConcurrentMap<String, Object> context = connectContext.getContext();
        HttpServletRequest request = connectContext.getRequest();
        prepareContext(request, context);

        HttpConnection httpConnection = connectContext.getHttpConnection();
        EndPoint downstreamEndPoint = httpConnection.getEndPoint();
        // The deprecated overload is called on purpose so that subclasses overriding it are still honored;
        // by default it delegates to newDownstreamConnection(EndPoint, ConcurrentMap).
        DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, BufferUtil.EMPTY_BUFFER);
        downstreamConnection.setInputBufferSize(getBufferSize());

        upstreamConnection.setConnection(downstreamConnection);
        downstreamConnection.setConnection(upstreamConnection);
        if (LOG.isDebugEnabled())
            LOG.debug("Connection setup completed: {}<->{}", downstreamConnection, upstreamConnection);

        HttpServletResponse response = connectContext.getResponse();
        sendConnectResponse(request, response, HttpServletResponse.SC_OK);

        upgradeConnection(request, response, downstreamConnection);
        connectContext.getAsyncContext().complete();
    }

    /**
     * <p>Handles a failure to set up the tunnel by sending a 500 response.</p>
     *
     * @param request the CONNECT request
     * @param response the CONNECT response
     * @param asyncContext the async context to complete, or null if none has to be completed
     * @param failure the cause of the failure
     */
    protected void onConnectFailure(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, Throwable failure)
    {
        if (LOG.isDebugEnabled())
            LOG.debug("CONNECT failed", failure);
        sendConnectResponse(request, response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
        if (asyncContext != null)
            asyncContext.complete();
    }

    /**
     * <p>Sends an empty response with the given status code.</p>
     * <p>Responses other than 200 also carry {@code Connection: close}. An {@link IOException} while sending
     * is only logged at debug level.</p>
     */
    private void sendConnectResponse(HttpServletRequest request, HttpServletResponse response, int statusCode)
    {
        try
        {
            response.setStatus(statusCode);
            response.setContentLength(0);
            if (statusCode != HttpServletResponse.SC_OK)
                response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
            response.getOutputStream().close();
            if (LOG.isDebugEnabled())
                LOG.debug("CONNECT response sent {} {}", request.getProtocol(), response.getStatus());
        }
        catch (IOException x)
        {
            if (LOG.isDebugEnabled())
                LOG.debug("Could not send CONNECT response", x);
        }
    }

    /**
     * <p>Handles the authentication before setting up the tunnel to the remote server.</p>
     * <p>The default implementation returns true.</p>
     *
     * @param request the HTTP request
     * @param response the HTTP response
     * @param address the address of the remote server in the form {@code host:port}.
     * @return true to allow to connect to the remote host, false otherwise
     */
    protected boolean handleAuthentication(HttpServletRequest request, HttpServletResponse response, String address)
    {
        return true;
    }

    /**
     * @param endPoint the client endpoint
     * @param context the per-tunnel context
     * @param buffer ignored
     * @return the result of {@link #newDownstreamConnection(EndPoint, ConcurrentMap)}
     * @deprecated use {@link #newDownstreamConnection(EndPoint, ConcurrentMap)} instead
     */
    @Deprecated
    protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer)
    {
        return newDownstreamConnection(endPoint, context);
    }

    /**
     * @param endPoint the client endpoint
     * @param context the per-tunnel context
     * @return a new connection that relays the client's data to the remote server
     */
    protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context)
    {
        return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context);
    }

    /**
     * @param endPoint the endpoint connected to the remote server
     * @param connectContext the state of the CONNECT request
     * @return a new connection that relays the remote server's data to the client
     */
    protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext)
    {
        return new UpstreamConnection(endPoint, getExecutor(), getByteBufferPool(), connectContext);
    }

    /**
     * <p>Hook to populate the per-tunnel context before the connections are paired.</p>
     * <p>The default implementation does nothing.</p>
     *
     * @param request the CONNECT request
     * @param context the per-tunnel context
     */
    protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context)
    {
    }

    private void upgradeConnection(HttpServletRequest request, HttpServletResponse response, Connection connection)
    {
        // Set the new connection as request attribute and change the status to 101
        // so that Jetty understands that it has to upgrade the connection
        request.setAttribute(HttpConnection.UPGRADE_CONNECTION_ATTRIBUTE, connection);
        response.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS);
        if (LOG.isDebugEnabled())
            LOG.debug("Upgraded connection to {}", connection);
    }

    /**
     * <p>Reads (with non-blocking semantic) into the given {@code buffer} from the given {@code endPoint}.</p>
     *
     * @param endPoint the endPoint to read from
     * @param buffer the buffer to read data into
     * @param context the context information related to the connection
     * @return the number of bytes read (possibly 0 since the read is non-blocking)
     * or -1 if the channel has been closed remotely
     * @throws IOException if the endPoint cannot be read
     */
    protected int read(EndPoint endPoint, ByteBuffer buffer, ConcurrentMap<String, Object> context) throws IOException
    {
        int read = read(endPoint, buffer);
        if (LOG.isDebugEnabled())
            LOG.debug("{} read {} bytes", this, read);
        return read;
    }

    /**
     * @deprecated override {@link #read(EndPoint, ByteBuffer, ConcurrentMap)} instead
     */
    @Deprecated
    protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
    {
        return endPoint.fill(buffer);
    }

    /**
     * <p>Writes (with non-blocking semantic) the given buffer of data onto the given endPoint.</p>
     *
     * @param endPoint the endPoint to write to
     * @param buffer the buffer to write
     * @param callback the completion callback to invoke
     * @param context the context information related to the connection
     */
    protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback, ConcurrentMap<String, Object> context)
    {
        if (LOG.isDebugEnabled())
            LOG.debug("{} writing {} bytes", this, buffer.remaining());
        write(endPoint, buffer, callback);
    }

    /**
     * @deprecated override {@link #write(EndPoint, ByteBuffer, Callback, ConcurrentMap)} instead
     */
    @Deprecated
    protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
    {
        endPoint.write(callback, buffer);
    }

    /**
     * @return the live, mutable set of allowed {@code host:port} destinations;
     * when non-empty, only these destinations are allowed
     */
    public Set<String> getWhiteListHosts()
    {
        return whiteList;
    }

    /**
     * @return the live, mutable set of forbidden {@code host:port} destinations
     */
    public Set<String> getBlackListHosts()
    {
        return blackList;
    }

    /**
     * Checks the given {@code host} and {@code port} against whitelist and blacklist.
     *
     * @param host the host to check
     * @param port the port to check
     * @return true if it is allowed to connect to the given host and port
     */
    public boolean validateDestination(String host, int port)
    {
        String hostPort = host + ":" + port;
        if (!whiteList.isEmpty())
        {
            if (!whiteList.contains(hostPort))
            {
                if (LOG.isDebugEnabled())
                    LOG.debug("Host {}:{} not whitelisted", host, port);
                return false;
            }
        }
        if (!blackList.isEmpty())
        {
            if (blackList.contains(hostPort))
            {
                if (LOG.isDebugEnabled())
                    LOG.debug("Host {}:{} blacklisted", host, port);
                return false;
            }
        }
        return true;
    }

    @Override
    public void dump(Appendable out, String indent) throws IOException
    {
        dumpThis(out);
        dump(out, indent, getBeans(), TypeUtil.asList(getHandlers()));
    }

    /**
     * <p>Selector manager for the channels opened by
     * {@link #connectToServer(HttpServletRequest, String, int, Promise)}.</p>
     * <p>Each channel gets a {@link SelectChannelEndPoint} with the handler's idle timeout and an
     * {@link UpstreamConnection}.</p>
     */
    protected class ConnectManager extends SelectorManager
    {
        protected ConnectManager(Executor executor, Scheduler scheduler, int selectors)
        {
            super(executor, scheduler, selectors);
        }

        @Override
        protected EndPoint newEndPoint(SocketChannel channel, ManagedSelector selector, SelectionKey selectionKey) throws IOException
        {
            return new SelectChannelEndPoint(channel, selector, selectionKey, getScheduler(), getIdleTimeout());
        }

        @Override
        public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) throws IOException
        {
            if (ConnectHandler.LOG.isDebugEnabled())
                ConnectHandler.LOG.debug("Connected to {}", channel.getRemoteAddress());
            ConnectContext connectContext = (ConnectContext)attachment;
            UpstreamConnection connection = newUpstreamConnection(endpoint, connectContext);
            connection.setInputBufferSize(getBufferSize());
            return connection;
        }

        /**
         * Closes the channel and reports the failure via
         * {@link #onConnectFailure(HttpServletRequest, HttpServletResponse, AsyncContext, Throwable)}.
         */
        @Override
        protected void connectionFailed(SocketChannel channel, final Throwable ex, final Object attachment)
        {
            close(channel);
            ConnectContext connectContext = (ConnectContext)attachment;
            onConnectFailure(connectContext.getRequest(), connectContext.getResponse(), connectContext.getAsyncContext(), ex);
        }
    }

    /**
     * <p>State of a CONNECT request whose connection to the remote server is in progress.</p>
     * <p>It holds the request, the response, the async context, the client's {@link HttpConnection}, and a
     * per-tunnel context map that is shared by the upstream and downstream connections.</p>
     */
    protected static class ConnectContext
    {
        private final ConcurrentMap<String, Object> context = new ConcurrentHashMap<>();
        private final HttpServletRequest request;
        private final HttpServletResponse response;
        private final AsyncContext asyncContext;
        private final HttpConnection httpConnection;

        public ConnectContext(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, HttpConnection httpConnection)
        {
            this.request = request;
            this.response = response;
            this.asyncContext = asyncContext;
            this.httpConnection = httpConnection;
        }

        public ConcurrentMap<String, Object> getContext()
        {
            return context;
        }

        public HttpServletRequest getRequest()
        {
            return request;
        }

        public HttpServletResponse getResponse()
        {
            return response;
        }

        public AsyncContext getAsyncContext()
        {
            return asyncContext;
        }

        public HttpConnection getHttpConnection()
        {
            return httpConnection;
        }
    }

    /**
     * <p>Connection to the remote server.</p>
     * <p>When opened, it completes the tunnel setup via {@link #onConnectSuccess(ConnectContext, UpstreamConnection)}
     * and then registers interest in reading.</p>
     */
    public class UpstreamConnection extends ProxyConnection
    {
        private final ConnectContext connectContext;

        public UpstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConnectContext connectContext)
        {
            super(endPoint, executor, bufferPool, connectContext.getContext());
            this.connectContext = connectContext;
        }

        @Override
        public void onOpen()
        {
            super.onOpen();
            onConnectSuccess(connectContext, UpstreamConnection.this);
            fillInterested();
        }

        @Override
        protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
        {
            return ConnectHandler.this.read(endPoint, buffer, getContext());
        }

        @Override
        protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
        {
            ConnectHandler.this.write(endPoint, buffer, callback, getContext());
        }
    }

    /**
     * <p>Connection to the client, installed by upgrading the client's HTTP connection.</p>
     * <p>When opened, it first writes the bytes received through {@link #onUpgradeTo(ByteBuffer)} to the paired
     * connection's endpoint and then registers interest in reading.</p>
     */
    public class DownstreamConnection extends ProxyConnection implements Connection.UpgradeTo
    {
        // Bytes handed over via onUpgradeTo(); null until that method is called.
        private ByteBuffer buffer;

        public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context)
        {
            super(endPoint, executor, bufferPool, context);
        }

        /**
         * @deprecated use {@link #DownstreamConnection(EndPoint, Executor, ByteBufferPool, ConcurrentMap)} instead
         */
        @Deprecated
        public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
        {
            this(endPoint, executor, bufferPool, context);
        }

        @Override
        public void onUpgradeTo(ByteBuffer buffer)
        {
            this.buffer = buffer == null ? BufferUtil.EMPTY_BUFFER : buffer;
        }

        @Override
        public void onOpen()
        {
            super.onOpen();
            // If onUpgradeTo() was not invoked before onOpen(), the field is still null:
            // write an empty buffer instead of failing with a NullPointerException.
            final ByteBuffer initial = buffer == null ? BufferUtil.EMPTY_BUFFER : buffer;
            final int remaining = initial.remaining();
            write(getConnection().getEndPoint(), initial, new Callback()
            {
                @Override
                public void succeeded()
                {
                    if (LOG.isDebugEnabled())
                        LOG.debug("{} wrote initial {} bytes to server", DownstreamConnection.this, remaining);
                    fillInterested();
                }

                @Override
                public void failed(Throwable x)
                {
                    // Log the connection, not this anonymous callback, consistently with succeeded().
                    if (LOG.isDebugEnabled())
                        LOG.debug(DownstreamConnection.this + " failed to write initial " + remaining + " bytes to server", x);
                    close();
                    getConnection().close();
                }
            });
        }

        @Override
        protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
        {
            return ConnectHandler.this.read(endPoint, buffer, getContext());
        }

        @Override
        protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
        {
            ConnectHandler.this.write(endPoint, buffer, callback, getContext());
        }
    }
}