| // |
| // ======================================================================== |
| // Copyright (c) 1995-2017 Mort Bay Consulting Pty. Ltd. |
| // ------------------------------------------------------------------------ |
| // All rights reserved. This program and the accompanying materials |
| // are made available under the terms of the Eclipse Public License v1.0 |
| // and Apache License v2.0 which accompanies this distribution. |
| // |
| // The Eclipse Public License is available at |
| // http://www.eclipse.org/legal/epl-v10.html |
| // |
| // The Apache License v2.0 is available at |
| // http://www.opensource.org/licenses/apache2.0.php |
| // |
| // You may elect to redistribute this code under either of these licenses. |
| // ======================================================================== |
| // |
| |
| package org.eclipse.jetty.proxy; |
| |
| import java.io.Closeable; |
| import java.io.IOException; |
| import java.net.InetSocketAddress; |
| import java.nio.ByteBuffer; |
| import java.nio.channels.SelectionKey; |
| import java.nio.channels.SocketChannel; |
| import java.util.HashSet; |
| import java.util.Set; |
| import java.util.concurrent.ConcurrentHashMap; |
| import java.util.concurrent.ConcurrentMap; |
| import java.util.concurrent.Executor; |
| |
| import javax.servlet.AsyncContext; |
| import javax.servlet.ServletException; |
| import javax.servlet.http.HttpServletRequest; |
| import javax.servlet.http.HttpServletResponse; |
| |
| import org.eclipse.jetty.http.HttpHeader; |
| import org.eclipse.jetty.http.HttpHeaderValue; |
| import org.eclipse.jetty.http.HttpMethod; |
| import org.eclipse.jetty.io.ByteBufferPool; |
| import org.eclipse.jetty.io.Connection; |
| import org.eclipse.jetty.io.EndPoint; |
| import org.eclipse.jetty.io.ManagedSelector; |
| import org.eclipse.jetty.io.MappedByteBufferPool; |
| import org.eclipse.jetty.io.SelectChannelEndPoint; |
| import org.eclipse.jetty.io.SelectorManager; |
| import org.eclipse.jetty.server.Handler; |
| import org.eclipse.jetty.server.HttpConnection; |
| import org.eclipse.jetty.server.HttpTransport; |
| import org.eclipse.jetty.server.Request; |
| import org.eclipse.jetty.server.handler.HandlerWrapper; |
| import org.eclipse.jetty.util.BufferUtil; |
| import org.eclipse.jetty.util.Callback; |
| import org.eclipse.jetty.util.HostPort; |
| import org.eclipse.jetty.util.Promise; |
| import org.eclipse.jetty.util.TypeUtil; |
| import org.eclipse.jetty.util.log.Log; |
| import org.eclipse.jetty.util.log.Logger; |
| import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler; |
| import org.eclipse.jetty.util.thread.Scheduler; |
| |
| /** |
| * <p>Implementation of a {@link Handler} that supports HTTP CONNECT.</p> |
| */ |
| public class ConnectHandler extends HandlerWrapper |
| { |
| protected static final Logger LOG = Log.getLogger(ConnectHandler.class); |
| |
| private final Set<String> whiteList = new HashSet<>(); |
| private final Set<String> blackList = new HashSet<>(); |
| private Executor executor; |
| private Scheduler scheduler; |
| private ByteBufferPool bufferPool; |
| private SelectorManager selector; |
| private long connectTimeout = 15000; |
| private long idleTimeout = 30000; |
| private int bufferSize = 4096; |
| |
| public ConnectHandler() |
| { |
| this(null); |
| } |
| |
| public ConnectHandler(Handler handler) |
| { |
| setHandler(handler); |
| } |
| |
| public Executor getExecutor() |
| { |
| return executor; |
| } |
| |
| public void setExecutor(Executor executor) |
| { |
| this.executor = executor; |
| } |
| |
| public Scheduler getScheduler() |
| { |
| return scheduler; |
| } |
| |
| public void setScheduler(Scheduler scheduler) |
| { |
| this.scheduler = scheduler; |
| } |
| |
| public ByteBufferPool getByteBufferPool() |
| { |
| return bufferPool; |
| } |
| |
| public void setByteBufferPool(ByteBufferPool bufferPool) |
| { |
| this.bufferPool = bufferPool; |
| } |
| |
| /** |
| * @return the timeout, in milliseconds, to connect to the remote server |
| */ |
| public long getConnectTimeout() |
| { |
| return connectTimeout; |
| } |
| |
| /** |
| * @param connectTimeout the timeout, in milliseconds, to connect to the remote server |
| */ |
| public void setConnectTimeout(long connectTimeout) |
| { |
| this.connectTimeout = connectTimeout; |
| } |
| |
| /** |
| * @return the idle timeout, in milliseconds |
| */ |
| public long getIdleTimeout() |
| { |
| return idleTimeout; |
| } |
| |
| /** |
| * @param idleTimeout the idle timeout, in milliseconds |
| */ |
| public void setIdleTimeout(long idleTimeout) |
| { |
| this.idleTimeout = idleTimeout; |
| } |
| |
| public int getBufferSize() |
| { |
| return bufferSize; |
| } |
| |
| public void setBufferSize(int bufferSize) |
| { |
| this.bufferSize = bufferSize; |
| } |
| |
| @Override |
| protected void doStart() throws Exception |
| { |
| if (executor == null) |
| executor = getServer().getThreadPool(); |
| |
| if (scheduler == null) |
| addBean(scheduler = new ScheduledExecutorScheduler()); |
| |
| if (bufferPool == null) |
| addBean(bufferPool = new MappedByteBufferPool()); |
| |
| addBean(selector = newSelectorManager()); |
| selector.setConnectTimeout(getConnectTimeout()); |
| |
| super.doStart(); |
| } |
| |
| protected SelectorManager newSelectorManager() |
| { |
| return new ConnectManager(getExecutor(), getScheduler(), 1); |
| } |
| |
| @Override |
| public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException |
| { |
| if (HttpMethod.CONNECT.is(request.getMethod())) |
| { |
| String serverAddress = request.getRequestURI(); |
| if (LOG.isDebugEnabled()) |
| LOG.debug("CONNECT request for {}", serverAddress); |
| |
| handleConnect(baseRequest, request, response, serverAddress); |
| } |
| else |
| { |
| super.handle(target, baseRequest, request, response); |
| } |
| } |
| |
| /** |
| * <p>Handles a CONNECT request.</p> |
| * <p>CONNECT requests may have authentication headers such as {@code Proxy-Authorization} |
| * that authenticate the client with the proxy.</p> |
| * |
| * @param baseRequest Jetty-specific http request |
| * @param request the http request |
| * @param response the http response |
| * @param serverAddress the remote server address in the form {@code host:port} |
| */ |
| protected void handleConnect(Request baseRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress) |
| { |
| baseRequest.setHandled(true); |
| try |
| { |
| boolean proceed = handleAuthentication(request, response, serverAddress); |
| if (!proceed) |
| { |
| if (LOG.isDebugEnabled()) |
| LOG.debug("Missing proxy authentication"); |
| sendConnectResponse(request, response, HttpServletResponse.SC_PROXY_AUTHENTICATION_REQUIRED); |
| return; |
| } |
| |
| HostPort hostPort = new HostPort(serverAddress); |
| String host = hostPort.getHost(); |
| int port = hostPort.getPort(80); |
| |
| if (!validateDestination(host, port)) |
| { |
| if (LOG.isDebugEnabled()) |
| LOG.debug("Destination {}:{} forbidden", host, port); |
| sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN); |
| return; |
| } |
| |
| HttpTransport transport = baseRequest.getHttpChannel().getHttpTransport(); |
| // TODO Handle CONNECT over HTTP2! |
| if (!(transport instanceof HttpConnection)) |
| { |
| if (LOG.isDebugEnabled()) |
| LOG.debug("CONNECT not supported for {}", transport); |
| sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN); |
| return; |
| } |
| |
| AsyncContext asyncContext = request.startAsync(); |
| asyncContext.setTimeout(0); |
| |
| if (LOG.isDebugEnabled()) |
| LOG.debug("Connecting to {}:{}", host, port); |
| |
| connectToServer(request, host, port, new Promise<SocketChannel>() |
| { |
| @Override |
| public void succeeded(SocketChannel channel) |
| { |
| ConnectContext connectContext = new ConnectContext(request, response, asyncContext, (HttpConnection)transport); |
| if (channel.isConnected()) |
| selector.accept(channel, connectContext); |
| else |
| selector.connect(channel, connectContext); |
| } |
| |
| @Override |
| public void failed(Throwable x) |
| { |
| onConnectFailure(request, response, asyncContext, x); |
| } |
| }); |
| } |
| catch (Exception x) |
| { |
| onConnectFailure(request, response, null, x); |
| } |
| } |
| |
| protected void connectToServer(HttpServletRequest request, String host, int port, Promise<SocketChannel> promise) |
| { |
| SocketChannel channel = null; |
| try |
| { |
| channel = SocketChannel.open(); |
| channel.socket().setTcpNoDelay(true); |
| channel.configureBlocking(false); |
| InetSocketAddress address = newConnectAddress(host, port); |
| channel.connect(address); |
| promise.succeeded(channel); |
| } |
| catch (Throwable x) |
| { |
| close(channel); |
| promise.failed(x); |
| } |
| } |
| |
| private void close(Closeable closeable) |
| { |
| try |
| { |
| if (closeable != null) |
| closeable.close(); |
| } |
| catch (Throwable x) |
| { |
| LOG.ignore(x); |
| } |
| } |
| |
| /** |
| * Creates the server address to connect to. |
| * |
| * @param host The host from the CONNECT request |
| * @param port The port from the CONNECT request |
| * @return The InetSocketAddress to connect to. |
| */ |
| protected InetSocketAddress newConnectAddress(String host, int port) |
| { |
| return new InetSocketAddress(host, port); |
| } |
| |
| protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection) |
| { |
| ConcurrentMap<String, Object> context = connectContext.getContext(); |
| HttpServletRequest request = connectContext.getRequest(); |
| prepareContext(request, context); |
| |
| HttpConnection httpConnection = connectContext.getHttpConnection(); |
| EndPoint downstreamEndPoint = httpConnection.getEndPoint(); |
| DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, BufferUtil.EMPTY_BUFFER); |
| downstreamConnection.setInputBufferSize(getBufferSize()); |
| |
| upstreamConnection.setConnection(downstreamConnection); |
| downstreamConnection.setConnection(upstreamConnection); |
| if (LOG.isDebugEnabled()) |
| LOG.debug("Connection setup completed: {}<->{}", downstreamConnection, upstreamConnection); |
| |
| HttpServletResponse response = connectContext.getResponse(); |
| sendConnectResponse(request, response, HttpServletResponse.SC_OK); |
| |
| upgradeConnection(request, response, downstreamConnection); |
| |
| connectContext.getAsyncContext().complete(); |
| } |
| |
| protected void onConnectFailure(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, Throwable failure) |
| { |
| if (LOG.isDebugEnabled()) |
| LOG.debug("CONNECT failed", failure); |
| sendConnectResponse(request, response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR); |
| if (asyncContext != null) |
| asyncContext.complete(); |
| } |
| |
| private void sendConnectResponse(HttpServletRequest request, HttpServletResponse response, int statusCode) |
| { |
| try |
| { |
| response.setStatus(statusCode); |
| response.setContentLength(0); |
| if (statusCode != HttpServletResponse.SC_OK) |
| response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString()); |
| response.getOutputStream().close(); |
| if (LOG.isDebugEnabled()) |
| LOG.debug("CONNECT response sent {} {}", request.getProtocol(), response.getStatus()); |
| } |
| catch (IOException x) |
| { |
| if (LOG.isDebugEnabled()) |
| LOG.debug("Could not send CONNECT response", x); |
| } |
| } |
| |
| /** |
| * <p>Handles the authentication before setting up the tunnel to the remote server.</p> |
| * <p>The default implementation returns true.</p> |
| * |
| * @param request the HTTP request |
| * @param response the HTTP response |
| * @param address the address of the remote server in the form {@code host:port}. |
| * @return true to allow to connect to the remote host, false otherwise |
| */ |
| protected boolean handleAuthentication(HttpServletRequest request, HttpServletResponse response, String address) |
| { |
| return true; |
| } |
| |
| /** |
| * @deprecated use {@link #newDownstreamConnection(EndPoint, ConcurrentMap)} instead |
| */ |
| @Deprecated |
| protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer) |
| { |
| return newDownstreamConnection(endPoint, context); |
| } |
| |
| protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context) |
| { |
| return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context); |
| } |
| |
| protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext) |
| { |
| return new UpstreamConnection(endPoint, getExecutor(), getByteBufferPool(), connectContext); |
| } |
| |
| protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context) |
| { |
| } |
| |
| private void upgradeConnection(HttpServletRequest request, HttpServletResponse response, Connection connection) |
| { |
| // Set the new connection as request attribute and change the status to 101 |
| // so that Jetty understands that it has to upgrade the connection |
| request.setAttribute(HttpConnection.UPGRADE_CONNECTION_ATTRIBUTE, connection); |
| response.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS); |
| if (LOG.isDebugEnabled()) |
| LOG.debug("Upgraded connection to {}", connection); |
| } |
| |
| /** |
| * <p>Reads (with non-blocking semantic) into the given {@code buffer} from the given {@code endPoint}.</p> |
| * |
| * @param endPoint the endPoint to read from |
| * @param buffer the buffer to read data into |
| * @param context the context information related to the connection |
| * @return the number of bytes read (possibly 0 since the read is non-blocking) |
| * or -1 if the channel has been closed remotely |
| * @throws IOException if the endPoint cannot be read |
| */ |
| protected int read(EndPoint endPoint, ByteBuffer buffer, ConcurrentMap<String, Object> context) throws IOException |
| { |
| int read = read(endPoint, buffer); |
| if (LOG.isDebugEnabled()) |
| LOG.debug("{} read {} bytes", this, read); |
| return read; |
| } |
| |
| /** |
| * @deprecated override {@link #read(EndPoint, ByteBuffer, ConcurrentMap)} instead |
| */ |
| @Deprecated |
| protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException |
| { |
| return endPoint.fill(buffer); |
| } |
| |
| /** |
| * <p>Writes (with non-blocking semantic) the given buffer of data onto the given endPoint.</p> |
| * |
| * @param endPoint the endPoint to write to |
| * @param buffer the buffer to write |
| * @param callback the completion callback to invoke |
| * @param context the context information related to the connection |
| */ |
| protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback, ConcurrentMap<String, Object> context) |
| { |
| if (LOG.isDebugEnabled()) |
| LOG.debug("{} writing {} bytes", this, buffer.remaining()); |
| write(endPoint, buffer, callback); |
| } |
| |
| /** |
| * @deprecated override {@link #write(EndPoint, ByteBuffer, Callback, ConcurrentMap)} instead |
| */ |
| @Deprecated |
| protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback) |
| { |
| endPoint.write(callback, buffer); |
| } |
| |
| public Set<String> getWhiteListHosts() |
| { |
| return whiteList; |
| } |
| |
| public Set<String> getBlackListHosts() |
| { |
| return blackList; |
| } |
| |
| /** |
| * Checks the given {@code host} and {@code port} against whitelist and blacklist. |
| * |
| * @param host the host to check |
| * @param port the port to check |
| * @return true if it is allowed to connect to the given host and port |
| */ |
| public boolean validateDestination(String host, int port) |
| { |
| String hostPort = host + ":" + port; |
| if (!whiteList.isEmpty()) |
| { |
| if (!whiteList.contains(hostPort)) |
| { |
| if (LOG.isDebugEnabled()) |
| LOG.debug("Host {}:{} not whitelisted", host, port); |
| return false; |
| } |
| } |
| if (!blackList.isEmpty()) |
| { |
| if (blackList.contains(hostPort)) |
| { |
| if (LOG.isDebugEnabled()) |
| LOG.debug("Host {}:{} blacklisted", host, port); |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| @Override |
| public void dump(Appendable out, String indent) throws IOException |
| { |
| dumpThis(out); |
| dump(out, indent, getBeans(), TypeUtil.asList(getHandlers())); |
| } |
| |
| protected class ConnectManager extends SelectorManager |
| { |
| protected ConnectManager(Executor executor, Scheduler scheduler, int selectors) |
| { |
| super(executor, scheduler, selectors); |
| } |
| |
| @Override |
| protected EndPoint newEndPoint(SocketChannel channel, ManagedSelector selector, SelectionKey selectionKey) throws IOException |
| { |
| return new SelectChannelEndPoint(channel, selector, selectionKey, getScheduler(), getIdleTimeout()); |
| } |
| |
| @Override |
| public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) throws IOException |
| { |
| if (ConnectHandler.LOG.isDebugEnabled()) |
| ConnectHandler.LOG.debug("Connected to {}", channel.getRemoteAddress()); |
| ConnectContext connectContext = (ConnectContext)attachment; |
| UpstreamConnection connection = newUpstreamConnection(endpoint, connectContext); |
| connection.setInputBufferSize(getBufferSize()); |
| return connection; |
| } |
| |
| @Override |
| protected void connectionFailed(SocketChannel channel, final Throwable ex, final Object attachment) |
| { |
| close(channel); |
| ConnectContext connectContext = (ConnectContext)attachment; |
| onConnectFailure(connectContext.request, connectContext.response, connectContext.asyncContext, ex); |
| } |
| } |
| |
| protected static class ConnectContext |
| { |
| private final ConcurrentMap<String, Object> context = new ConcurrentHashMap<>(); |
| private final HttpServletRequest request; |
| private final HttpServletResponse response; |
| private final AsyncContext asyncContext; |
| private final HttpConnection httpConnection; |
| |
| public ConnectContext(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, HttpConnection httpConnection) |
| { |
| this.request = request; |
| this.response = response; |
| this.asyncContext = asyncContext; |
| this.httpConnection = httpConnection; |
| } |
| |
| public ConcurrentMap<String, Object> getContext() |
| { |
| return context; |
| } |
| |
| public HttpServletRequest getRequest() |
| { |
| return request; |
| } |
| |
| public HttpServletResponse getResponse() |
| { |
| return response; |
| } |
| |
| public AsyncContext getAsyncContext() |
| { |
| return asyncContext; |
| } |
| |
| public HttpConnection getHttpConnection() |
| { |
| return httpConnection; |
| } |
| } |
| |
| public class UpstreamConnection extends ProxyConnection |
| { |
| private ConnectContext connectContext; |
| |
| public UpstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConnectContext connectContext) |
| { |
| super(endPoint, executor, bufferPool, connectContext.getContext()); |
| this.connectContext = connectContext; |
| } |
| |
| @Override |
| public void onOpen() |
| { |
| super.onOpen(); |
| onConnectSuccess(connectContext, UpstreamConnection.this); |
| fillInterested(); |
| } |
| |
| @Override |
| protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException |
| { |
| return ConnectHandler.this.read(endPoint, buffer, getContext()); |
| } |
| |
| @Override |
| protected void write(EndPoint endPoint, ByteBuffer buffer,Callback callback) |
| { |
| ConnectHandler.this.write(endPoint, buffer, callback, getContext()); |
| } |
| } |
| |
| public class DownstreamConnection extends ProxyConnection implements Connection.UpgradeTo |
| { |
| private ByteBuffer buffer; |
| |
| public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context) |
| { |
| super(endPoint, executor, bufferPool, context); |
| } |
| |
| /** |
| * @deprecated use {@link #DownstreamConnection(EndPoint, Executor, ByteBufferPool, ConcurrentMap)} instead |
| */ |
| @Deprecated |
| public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer) |
| { |
| this(endPoint, executor, bufferPool, context); |
| } |
| |
| @Override |
| public void onUpgradeTo(ByteBuffer buffer) |
| { |
| this.buffer = buffer == null ? BufferUtil.EMPTY_BUFFER : buffer; |
| } |
| |
| @Override |
| public void onOpen() |
| { |
| super.onOpen(); |
| final int remaining = buffer.remaining(); |
| write(getConnection().getEndPoint(), buffer, new Callback() |
| { |
| @Override |
| public void succeeded() |
| { |
| if (LOG.isDebugEnabled()) |
| LOG.debug("{} wrote initial {} bytes to server", DownstreamConnection.this, remaining); |
| fillInterested(); |
| } |
| |
| @Override |
| public void failed(Throwable x) |
| { |
| if (LOG.isDebugEnabled()) |
| LOG.debug(this + " failed to write initial " + remaining + " bytes to server", x); |
| close(); |
| getConnection().close(); |
| } |
| }); |
| } |
| |
| @Override |
| protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException |
| { |
| return ConnectHandler.this.read(endPoint, buffer, getContext()); |
| } |
| |
| @Override |
| protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback) |
| { |
| ConnectHandler.this.write(endPoint, buffer, callback, getContext()); |
| } |
| } |
| } |