| // |
| // ======================================================================== |
| // Copyright (c) 1995-2017 Mort Bay Consulting Pty. Ltd. |
| // ------------------------------------------------------------------------ |
| // All rights reserved. This program and the accompanying materials |
| // are made available under the terms of the Eclipse Public License v1.0 |
| // and Apache License v2.0 which accompanies this distribution. |
| // |
| // The Eclipse Public License is available at |
| // http://www.eclipse.org/legal/epl-v10.html |
| // |
| // The Apache License v2.0 is available at |
| // http://www.opensource.org/licenses/apache2.0.php |
| // |
| // You may elect to redistribute this code under either of these licenses. |
| // ======================================================================== |
| // |
| |
| package org.eclipse.jetty.io; |
| |
| import java.io.BufferedReader; |
| import java.io.File; |
| import java.io.IOException; |
| import java.io.InputStreamReader; |
| import java.net.Socket; |
| import java.nio.ByteBuffer; |
| import java.nio.channels.SelectionKey; |
| import java.nio.channels.ServerSocketChannel; |
| import java.nio.channels.SocketChannel; |
| import java.nio.charset.StandardCharsets; |
| import java.util.concurrent.CountDownLatch; |
| import java.util.concurrent.TimeUnit; |
| import java.util.concurrent.atomic.AtomicInteger; |
| |
| import javax.net.ssl.SSLEngine; |
| import javax.net.ssl.SSLException; |
| import javax.net.ssl.SSLSocket; |
| |
| import org.eclipse.jetty.io.ssl.SslConnection; |
| import org.eclipse.jetty.toolchain.test.MavenTestingUtils; |
| import org.eclipse.jetty.util.BufferUtil; |
| import org.eclipse.jetty.util.FutureCallback; |
| import org.eclipse.jetty.util.ssl.SslContextFactory; |
| import org.eclipse.jetty.util.thread.QueuedThreadPool; |
| import org.eclipse.jetty.util.thread.Scheduler; |
| import org.eclipse.jetty.util.thread.TimerScheduler; |
| import org.junit.After; |
| import org.junit.Assert; |
| import org.junit.Before; |
| import org.junit.BeforeClass; |
| import org.junit.Test; |
| |
| |
| public class SslConnectionTest |
| { |
| private static SslContextFactory __sslCtxFactory=new SslContextFactory(); |
| private static ByteBufferPool __byteBufferPool = new LeakTrackingByteBufferPool(new MappedByteBufferPool.Tagged()); |
| |
| protected volatile EndPoint _lastEndp; |
| private volatile boolean _testFill=true; |
| private volatile FutureCallback _writeCallback; |
| protected ServerSocketChannel _connector; |
| final AtomicInteger _dispatches = new AtomicInteger(); |
| protected QueuedThreadPool _threadPool = new QueuedThreadPool() |
| { |
| @Override |
| public void execute(Runnable job) |
| { |
| _dispatches.incrementAndGet(); |
| super.execute(job); |
| } |
| |
| }; |
| protected Scheduler _scheduler = new TimerScheduler(); |
| protected SelectorManager _manager = new SelectorManager(_threadPool, _scheduler) |
| { |
| @Override |
| public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) |
| { |
| SSLEngine engine = __sslCtxFactory.newSSLEngine(); |
| engine.setUseClientMode(false); |
| SslConnection sslConnection = new SslConnection(__byteBufferPool, getExecutor(), endpoint, engine); |
| sslConnection.setRenegotiationAllowed(__sslCtxFactory.isRenegotiationAllowed()); |
| sslConnection.setRenegotiationLimit(__sslCtxFactory.getRenegotiationLimit()); |
| Connection appConnection = new TestConnection(sslConnection.getDecryptedEndPoint()); |
| sslConnection.getDecryptedEndPoint().setConnection(appConnection); |
| return sslConnection; |
| } |
| |
| @Override |
| protected SelectChannelEndPoint newEndPoint(SocketChannel channel, ManagedSelector selectSet, SelectionKey selectionKey) throws IOException |
| { |
| SelectChannelEndPoint endp = new TestEP(channel,selectSet, selectionKey, getScheduler(), 60000); |
| _lastEndp=endp; |
| return endp; |
| } |
| }; |
| |
| static final AtomicInteger __startBlocking = new AtomicInteger(); |
| static final AtomicInteger __blockFor = new AtomicInteger(); |
| private static class TestEP extends SelectChannelEndPoint |
| { |
| |
| public TestEP(SocketChannel channel, ManagedSelector selector, SelectionKey key, Scheduler scheduler, long idleTimeout) |
| { |
| super(channel,selector,key,scheduler,idleTimeout); |
| } |
| |
| @Override |
| protected void onIncompleteFlush() |
| { |
| super.onIncompleteFlush(); |
| } |
| |
| |
| @Override |
| public boolean flush(ByteBuffer... buffers) throws IOException |
| { |
| if (__startBlocking.get()==0 || __startBlocking.decrementAndGet()==0) |
| { |
| if (__blockFor.get()>0 && __blockFor.getAndDecrement()>0) |
| { |
| return false; |
| } |
| } |
| String s=BufferUtil.toDetailString(buffers[0]); |
| boolean flushed=super.flush(buffers); |
| return flushed; |
| } |
| } |
| |
| |
| @BeforeClass |
| public static void initSslEngine() throws Exception |
| { |
| File keystore = MavenTestingUtils.getTestResourceFile("keystore"); |
| __sslCtxFactory.setKeyStorePath(keystore.getAbsolutePath()); |
| __sslCtxFactory.setKeyStorePassword("storepwd"); |
| __sslCtxFactory.setKeyManagerPassword("keypwd"); |
| __sslCtxFactory.start(); |
| } |
| |
| @Before |
| public void startManager() throws Exception |
| { |
| _testFill=true; |
| _writeCallback=null; |
| _lastEndp=null; |
| _connector = ServerSocketChannel.open(); |
| _connector.socket().bind(null); |
| _threadPool.start(); |
| _scheduler.start(); |
| _manager.start(); |
| __sslCtxFactory.setRenegotiationAllowed(true); |
| __sslCtxFactory.setRenegotiationLimit(-1); |
| |
| } |
| |
| @After |
| public void stopManager() throws Exception |
| { |
| if (_lastEndp.isOpen()) |
| _lastEndp.close(); |
| _manager.stop(); |
| _scheduler.stop(); |
| _threadPool.stop(); |
| _connector.close(); |
| } |
| |
| public class TestConnection extends AbstractConnection |
| { |
| ByteBuffer _in = BufferUtil.allocate(8*1024); |
| |
| public TestConnection(EndPoint endp) |
| { |
| super(endp, _threadPool); |
| } |
| |
| @Override |
| public void onOpen() |
| { |
| super.onOpen(); |
| if (_testFill) |
| fillInterested(); |
| else |
| { |
| getExecutor().execute(new Runnable() |
| { |
| |
| @Override |
| public void run() |
| { |
| getEndPoint().write(_writeCallback,BufferUtil.toBuffer("Hello Client")); |
| } |
| }); |
| } |
| } |
| |
| @Override |
| public void onClose() |
| { |
| super.onClose(); |
| } |
| |
| @Override |
| public synchronized void onFillable() |
| { |
| EndPoint endp = getEndPoint(); |
| try |
| { |
| boolean progress=true; |
| while(progress) |
| { |
| progress=false; |
| |
| // Fill the input buffer with everything available |
| int filled=endp.fill(_in); |
| while (filled>0) |
| { |
| progress=true; |
| filled=endp.fill(_in); |
| } |
| |
| // Write everything |
| int l=_in.remaining(); |
| if (l>0) |
| { |
| FutureCallback blockingWrite= new FutureCallback(); |
| endp.write(blockingWrite,_in); |
| blockingWrite.get(); |
| } |
| |
| // are we done? |
| if (endp.isInputShutdown()) |
| { |
| endp.shutdownOutput(); |
| } |
| } |
| } |
| catch(InterruptedException|EofException e) |
| { |
| SelectChannelEndPoint.LOG.ignore(e); |
| } |
| catch(Exception e) |
| { |
| SelectChannelEndPoint.LOG.warn(e); |
| } |
| finally |
| { |
| if (endp.isOpen()) |
| fillInterested(); |
| } |
| } |
| } |
| protected SSLSocket newClient() throws IOException |
| { |
| SSLSocket socket = __sslCtxFactory.newSslSocket(); |
| socket.connect(_connector.socket().getLocalSocketAddress()); |
| return socket; |
| } |
| |
| @Test |
| public void testHelloWorld() throws Exception |
| { |
| Socket client = newClient(); |
| client.setSoTimeout(60000); |
| |
| SocketChannel server = _connector.accept(); |
| server.configureBlocking(false); |
| _manager.accept(server); |
| |
| client.getOutputStream().write("Hello".getBytes(StandardCharsets.UTF_8)); |
| byte[] buffer = new byte[1024]; |
| int len=client.getInputStream().read(buffer); |
| Assert.assertEquals(5, len); |
| Assert.assertEquals("Hello",new String(buffer,0,len,StandardCharsets.UTF_8)); |
| |
| _dispatches.set(0); |
| client.getOutputStream().write("World".getBytes(StandardCharsets.UTF_8)); |
| len=5; |
| while(len>0) |
| len-=client.getInputStream().read(buffer); |
| |
| client.close(); |
| } |
| |
| @Test |
| public void testRenegotiate() throws Exception |
| { |
| SSLSocket client = newClient(); |
| client.setSoTimeout(60000); |
| |
| SocketChannel server = _connector.accept(); |
| server.configureBlocking(false); |
| _manager.accept(server); |
| |
| client.getOutputStream().write("Hello".getBytes(StandardCharsets.UTF_8)); |
| byte[] buffer = new byte[1024]; |
| int len=client.getInputStream().read(buffer); |
| Assert.assertEquals(5, len); |
| Assert.assertEquals("Hello",new String(buffer,0,len,StandardCharsets.UTF_8)); |
| |
| client.startHandshake(); |
| |
| client.getOutputStream().write("World".getBytes(StandardCharsets.UTF_8)); |
| len=client.getInputStream().read(buffer); |
| Assert.assertEquals(5, len); |
| Assert.assertEquals("World",new String(buffer,0,len,StandardCharsets.UTF_8)); |
| |
| client.close(); |
| } |
| |
| @Test |
| public void testRenegotiateNotAllowed() throws Exception |
| { |
| __sslCtxFactory.setRenegotiationAllowed(false); |
| |
| SSLSocket client = newClient(); |
| client.setSoTimeout(60000); |
| |
| SocketChannel server = _connector.accept(); |
| server.configureBlocking(false); |
| _manager.accept(server); |
| |
| client.getOutputStream().write("Hello".getBytes(StandardCharsets.UTF_8)); |
| byte[] buffer = new byte[1024]; |
| int len=client.getInputStream().read(buffer); |
| Assert.assertEquals(5, len); |
| Assert.assertEquals("Hello",new String(buffer,0,len,StandardCharsets.UTF_8)); |
| |
| client.startHandshake(); |
| |
| client.getOutputStream().write("World".getBytes(StandardCharsets.UTF_8)); |
| try |
| { |
| client.getInputStream().read(buffer); |
| Assert.fail(); |
| } |
| catch(SSLException e) |
| { |
| // expected |
| } |
| } |
| |
| @Test |
| public void testRenegotiateLimit() throws Exception |
| { |
| __sslCtxFactory.setRenegotiationAllowed(true); |
| __sslCtxFactory.setRenegotiationLimit(2); |
| |
| SSLSocket client = newClient(); |
| client.setSoTimeout(60000); |
| |
| SocketChannel server = _connector.accept(); |
| server.configureBlocking(false); |
| _manager.accept(server); |
| |
| client.getOutputStream().write("Good".getBytes(StandardCharsets.UTF_8)); |
| byte[] buffer = new byte[1024]; |
| int len=client.getInputStream().read(buffer); |
| Assert.assertEquals(4, len); |
| Assert.assertEquals("Good",new String(buffer,0,len,StandardCharsets.UTF_8)); |
| |
| client.startHandshake(); |
| |
| client.getOutputStream().write("Bye".getBytes(StandardCharsets.UTF_8)); |
| len=client.getInputStream().read(buffer); |
| Assert.assertEquals(3, len); |
| Assert.assertEquals("Bye",new String(buffer,0,len,StandardCharsets.UTF_8)); |
| |
| client.startHandshake(); |
| |
| client.getOutputStream().write("Cruel".getBytes(StandardCharsets.UTF_8)); |
| len=client.getInputStream().read(buffer); |
| Assert.assertEquals(5, len); |
| Assert.assertEquals("Cruel",new String(buffer,0,len,StandardCharsets.UTF_8)); |
| |
| client.startHandshake(); |
| |
| client.getOutputStream().write("World".getBytes(StandardCharsets.UTF_8)); |
| try |
| { |
| client.getInputStream().read(buffer); |
| Assert.fail(); |
| } |
| catch(SSLException e) |
| { |
| // expected |
| } |
| } |
| |
| |
| |
| @Test |
| public void testWriteOnConnect() throws Exception |
| { |
| _testFill=false; |
| |
| _writeCallback = new FutureCallback(); |
| Socket client = newClient(); |
| client.setSoTimeout(10000); |
| |
| SocketChannel server = _connector.accept(); |
| server.configureBlocking(false); |
| _manager.accept(server); |
| |
| byte[] buffer = new byte[1024]; |
| int len=client.getInputStream().read(buffer); |
| Assert.assertEquals("Hello Client",new String(buffer,0,len,StandardCharsets.UTF_8)); |
| Assert.assertEquals(null,_writeCallback.get(100,TimeUnit.MILLISECONDS)); |
| client.close(); |
| } |
| |
| |
| |
| @Test |
| public void testBlockedWrite() throws Exception |
| { |
| Socket client = newClient(); |
| client.setSoTimeout(5000); |
| |
| SocketChannel server = _connector.accept(); |
| server.configureBlocking(false); |
| _manager.accept(server); |
| |
| __startBlocking.set(5); |
| __blockFor.set(3); |
| |
| client.getOutputStream().write("Hello".getBytes(StandardCharsets.UTF_8)); |
| byte[] buffer = new byte[1024]; |
| int len=client.getInputStream().read(buffer); |
| Assert.assertEquals(5, len); |
| Assert.assertEquals("Hello",new String(buffer,0,len,StandardCharsets.UTF_8)); |
| |
| _dispatches.set(0); |
| client.getOutputStream().write("World".getBytes(StandardCharsets.UTF_8)); |
| len=5; |
| while(len>0) |
| len-=client.getInputStream().read(buffer); |
| Assert.assertEquals(0, len); |
| client.close(); |
| } |
| |
| @Test |
| public void testManyLines() throws Exception |
| { |
| final Socket client = newClient(); |
| client.setSoTimeout(10000); |
| |
| SocketChannel server = _connector.accept(); |
| server.configureBlocking(false); |
| _manager.accept(server); |
| |
| final int LINES=20; |
| final CountDownLatch count=new CountDownLatch(LINES); |
| |
| |
| new Thread() |
| { |
| @Override |
| public void run() |
| { |
| try |
| { |
| BufferedReader in = new BufferedReader(new InputStreamReader(client.getInputStream(),StandardCharsets.UTF_8)); |
| while(count.getCount()>0) |
| { |
| String line=in.readLine(); |
| if (line==null) |
| break; |
| // System.err.println(line); |
| count.countDown(); |
| } |
| } |
| catch(IOException e) |
| { |
| e.printStackTrace(); |
| } |
| } |
| }.start(); |
| |
| for (int i=0;i<LINES;i++) |
| { |
| client.getOutputStream().write(("HelloWorld "+i+"\n").getBytes(StandardCharsets.UTF_8)); |
| // System.err.println("wrote"); |
| if (i%1000==0) |
| { |
| client.getOutputStream().flush(); |
| Thread.sleep(10); |
| } |
| } |
| |
| Assert.assertTrue(count.await(20,TimeUnit.SECONDS)); |
| client.close(); |
| |
| } |
| |
| |
| } |