summaryrefslogtreecommitdiff
path: root/qpid/java/broker-core/src/main/java/org/apache/qpid/server/transport/NonBlockingConnection.java
diff options
context:
space:
mode:
Diffstat (limited to 'qpid/java/broker-core/src/main/java/org/apache/qpid/server/transport/NonBlockingConnection.java')
-rw-r--r--qpid/java/broker-core/src/main/java/org/apache/qpid/server/transport/NonBlockingConnection.java642
1 files changed, 642 insertions, 0 deletions
diff --git a/qpid/java/broker-core/src/main/java/org/apache/qpid/server/transport/NonBlockingConnection.java b/qpid/java/broker-core/src/main/java/org/apache/qpid/server/transport/NonBlockingConnection.java
new file mode 100644
index 0000000000..ae5816a0d1
--- /dev/null
+++ b/qpid/java/broker-core/src/main/java/org/apache/qpid/server/transport/NonBlockingConnection.java
@@ -0,0 +1,642 @@
+/*
+*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+package org.apache.qpid.server.transport;
+
+import java.io.IOException;
+import java.net.SocketAddress;
+import java.nio.ByteBuffer;
+import java.nio.channels.SocketChannel;
+import java.security.Principal;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+import java.util.ListIterator;
+import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLEngineResult;
+import javax.net.ssl.SSLPeerUnverifiedException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.qpid.server.protocol.ServerProtocolEngine;
+import org.apache.qpid.server.util.Action;
+import org.apache.qpid.transport.ByteBufferSender;
+import org.apache.qpid.transport.SenderException;
+import org.apache.qpid.transport.network.NetworkConnection;
+import org.apache.qpid.transport.network.Ticker;
+import org.apache.qpid.transport.network.TransportEncryption;
+import org.apache.qpid.transport.network.security.ssl.SSLUtil;
+import org.apache.qpid.util.SystemUtils;
+
+public class NonBlockingConnection implements NetworkConnection, ByteBufferSender
+{
+ private static final Logger LOGGER = LoggerFactory.getLogger(NonBlockingConnection.class);
+ private final SocketChannel _socketChannel;
+ private final long _timeout;
+ private final Ticker _ticker;
+ private final SelectorThread _selector;
+ private int _maxReadIdle;
+ private int _maxWriteIdle;
+ private Principal _principal;
+ private boolean _principalChecked;
+ private final Object _lock = new Object();
+
+ public static final int NUMBER_OF_BYTES_FOR_TLS_CHECK = 6;
+
+ private final ConcurrentLinkedQueue<ByteBuffer> _buffers = new ConcurrentLinkedQueue<>();
+ private final List<ByteBuffer> _encryptedOutput = new ArrayList<>();
+
+ private final String _remoteSocketAddress;
+ private final AtomicBoolean _closed = new AtomicBoolean(false);
+ private final ServerProtocolEngine _protocolEngine;
+ private final int _receiveBufSize;
+ private final Set<TransportEncryption> _encryptionSet;
+ private final SSLContext _sslContext;
+ private final Runnable _onTransportEncryptionAction;
+ private ByteBuffer _netInputBuffer;
+ private SSLEngine _sslEngine;
+
+ private ByteBuffer _currentBuffer;
+
+ private TransportEncryption _transportEncryption;
+ private SSLEngineResult _status;
+ private volatile boolean _fullyWritten = true;
+ private boolean _workDone;
+
+
+ public NonBlockingConnection(SocketChannel socketChannel,
+ ServerProtocolEngine delegate,
+ int sendBufferSize,
+ int receiveBufferSize,
+ long timeout,
+ Ticker ticker,
+ final Set<TransportEncryption> encryptionSet,
+ final SSLContext sslContext,
+ final boolean wantClientAuth,
+ final boolean needClientAuth,
+ final Collection<String> enabledCipherSuites,
+ final Collection<String> disabledCipherSuites,
+ final Runnable onTransportEncryptionAction,
+ final SelectorThread selectorThread)
+ {
+ _socketChannel = socketChannel;
+ _timeout = timeout;
+ _ticker = ticker;
+ _selector = selectorThread;
+
+ _protocolEngine = delegate;
+ _receiveBufSize = receiveBufferSize;
+ _encryptionSet = encryptionSet;
+ _sslContext = sslContext;
+ _onTransportEncryptionAction = onTransportEncryptionAction;
+
+ delegate.setWorkListener(new Action<ServerProtocolEngine>()
+ {
+ @Override
+ public void performAction(final ServerProtocolEngine object)
+ {
+ _selector.wakeup();
+ }
+ });
+
+ if(encryptionSet.size() == 1)
+ {
+ _transportEncryption = _encryptionSet.iterator().next();
+ if (_transportEncryption == TransportEncryption.TLS)
+ {
+ onTransportEncryptionAction.run();
+ }
+ }
+
+ if(encryptionSet.contains(TransportEncryption.TLS))
+ {
+ _sslEngine = _sslContext.createSSLEngine();
+ _sslEngine.setUseClientMode(false);
+ SSLUtil.removeSSLv3Support(_sslEngine);
+ SSLUtil.updateEnabledCipherSuites(_sslEngine, enabledCipherSuites, disabledCipherSuites);
+
+ if(needClientAuth)
+ {
+ _sslEngine.setNeedClientAuth(true);
+ }
+ else if(wantClientAuth)
+ {
+ _sslEngine.setWantClientAuth(true);
+ }
+ _netInputBuffer = ByteBuffer.allocate(Math.max(_sslEngine.getSession().getPacketBufferSize(), _receiveBufSize * 2));
+ }
+
+ try
+ {
+ _remoteSocketAddress = _socketChannel.getRemoteAddress().toString();
+ _socketChannel.configureBlocking(false);
+ }
+ catch (IOException e)
+ {
+ throw new SenderException("Unable to prepare the channel for non-blocking IO", e);
+ }
+
+
+ }
+
+
+ public Ticker getTicker()
+ {
+ return _ticker;
+ }
+
+ public SocketChannel getSocketChannel()
+ {
+ return _socketChannel;
+ }
+
+ public void start()
+ {
+ }
+
+ public ByteBufferSender getSender()
+ {
+ return this;
+ }
+
+ public void close()
+ {
+ LOGGER.debug("Closing " + _remoteSocketAddress);
+ if(_closed.compareAndSet(false,true))
+ {
+ _protocolEngine.notifyWork();
+ getSelector().wakeup();
+ }
+ }
+
+ public SocketAddress getRemoteAddress()
+ {
+ return _socketChannel.socket().getRemoteSocketAddress();
+ }
+
+ public SocketAddress getLocalAddress()
+ {
+ return _socketChannel.socket().getLocalSocketAddress();
+ }
+
+ public void setMaxWriteIdle(int sec)
+ {
+ _maxWriteIdle = sec;
+ }
+
+ public void setMaxReadIdle(int sec)
+ {
+ _maxReadIdle = sec;
+ }
+
+ @Override
+ public Principal getPeerPrincipal()
+ {
+ synchronized (_lock)
+ {
+ if(!_principalChecked)
+ {
+ if (_sslEngine != null)
+ {
+ try
+ {
+ _principal = _sslEngine.getSession().getPeerPrincipal();
+ }
+ catch (SSLPeerUnverifiedException e)
+ {
+ return null;
+ }
+ }
+
+ _principalChecked = true;
+ }
+
+ return _principal;
+ }
+ }
+
+ @Override
+ public int getMaxReadIdle()
+ {
+ return _maxReadIdle;
+ }
+
+ @Override
+ public int getMaxWriteIdle()
+ {
+ return _maxWriteIdle;
+ }
+
+ public boolean canRead()
+ {
+ return true;
+ }
+
+ public boolean waitingForWrite()
+ {
+ return !_fullyWritten;
+ }
+
+ public boolean isStateChanged()
+ {
+
+ return _protocolEngine.hasWork();
+ }
+
+ public boolean doWork()
+ {
+ _protocolEngine.clearWork();
+ final boolean closed = _closed.get();
+ if (!closed)
+ {
+ try
+ {
+ _workDone = false;
+
+ long currentTime = System.currentTimeMillis();
+ int tick = _ticker.getTimeToNextTick(currentTime);
+ if (tick <= 0)
+ {
+ _ticker.tick(currentTime);
+ }
+
+ _protocolEngine.setMessageAssignmentSuspended(true);
+
+ _protocolEngine.processPending();
+
+ _protocolEngine.setTransportBlockedForWriting(!doWrite());
+ boolean dataRead = doRead();
+ _fullyWritten = doWrite();
+ _protocolEngine.setTransportBlockedForWriting(!_fullyWritten);
+
+ if(dataRead || (_workDone && _netInputBuffer != null && _netInputBuffer.position() != 0))
+ {
+ _protocolEngine.notifyWork();
+ }
+
+ // tell all consumer targets that it is okay to accept more
+ _protocolEngine.setMessageAssignmentSuspended(false);
+ }
+ catch (IOException e)
+ {
+ LOGGER.info("Exception performing I/O for thread '" + _remoteSocketAddress + "': " + e);
+ LOGGER.debug("Closing " + _remoteSocketAddress);
+ if(_closed.compareAndSet(false,true))
+ {
+ _protocolEngine.notifyWork();
+ }
+ }
+ }
+ else
+ {
+
+ if(!SystemUtils.isWindows())
+ {
+ try
+ {
+ _socketChannel.shutdownInput();
+ }
+ catch (IOException e)
+ {
+ LOGGER.info("Exception shutting down input for thread '" + _remoteSocketAddress + "': " + e);
+
+ }
+ }
+ try
+ {
+ while(!doWrite())
+ {
+ }
+ }
+ catch (IOException e)
+ {
+ LOGGER.info("Exception performing final write/close for thread '" + _remoteSocketAddress + "': " + e);
+
+ }
+ LOGGER.debug("Closing receiver");
+ _protocolEngine.closed();
+
+ try
+ {
+ if(!SystemUtils.isWindows())
+ {
+ _socketChannel.shutdownOutput();
+ }
+
+ _socketChannel.close();
+ }
+ catch (IOException e)
+ {
+ LOGGER.info("Exception closing socket thread '" + _remoteSocketAddress + "': " + e);
+ }
+ }
+
+ return closed;
+
+ }
+
+ public SelectorThread getSelector()
+ {
+ return _selector;
+ }
+
+ public boolean looksLikeSSLv2ClientHello(final byte[] headerBytes)
+ {
+ return headerBytes[0] == -128 &&
+ headerBytes[3] == 3 && // SSL 3.0 / TLS 1.x
+ (headerBytes[4] == 0 || // SSL 3.0
+ headerBytes[4] == 1 || // TLS 1.0
+ headerBytes[4] == 2 || // TLS 1.1
+ headerBytes[4] == 3);
+ }
+
+ public boolean doRead() throws IOException
+ {
+ boolean readData = false;
+ if(_transportEncryption == TransportEncryption.NONE)
+ {
+ int remaining = 0;
+ while (remaining == 0 && !_closed.get())
+ {
+ if (_currentBuffer == null || _currentBuffer.remaining() == 0)
+ {
+ _currentBuffer = ByteBuffer.allocate(_receiveBufSize);
+ }
+ int read = _socketChannel.read(_currentBuffer);
+ if(read > 0)
+ {
+ readData = true;
+ }
+ if (read == -1)
+ {
+ _closed.set(true);
+ }
+ remaining = _currentBuffer.remaining();
+ if (LOGGER.isDebugEnabled())
+ {
+ LOGGER.debug("Read " + read + " byte(s)");
+ }
+ ByteBuffer dup = _currentBuffer.duplicate();
+ dup.flip();
+ _currentBuffer = _currentBuffer.slice();
+ _protocolEngine.received(dup);
+ }
+ }
+ else if(_transportEncryption == TransportEncryption.TLS)
+ {
+ int read = 1;
+ while(!_closed.get() && read > 0 && _sslEngine.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NEED_WRAP && (_status == null || _status.getStatus() != SSLEngineResult.Status.CLOSED))
+ {
+ read = _socketChannel.read(_netInputBuffer);
+ if (read == -1)
+ {
+ _closed.set(true);
+ }
+ else if(read > 0)
+ {
+ readData = true;
+ }
+ if (LOGGER.isDebugEnabled())
+ {
+ LOGGER.debug("Read " + read + " encrypted bytes ");
+ }
+
+ _netInputBuffer.flip();
+
+
+ int unwrapped = 0;
+ boolean tasksRun;
+ do
+ {
+ ByteBuffer appInputBuffer =
+ ByteBuffer.allocate(_sslEngine.getSession().getApplicationBufferSize() + 50);
+
+ _status = _sslEngine.unwrap(_netInputBuffer, appInputBuffer);
+ tasksRun = runSSLEngineTasks(_status);
+
+ appInputBuffer.flip();
+ unwrapped = appInputBuffer.remaining();
+ if(unwrapped > 0)
+ {
+ readData = true;
+ }
+ _protocolEngine.received(appInputBuffer);
+ }
+ while(unwrapped > 0 || tasksRun);
+
+ _netInputBuffer.compact();
+
+ }
+ }
+ else
+ {
+ int read = 1;
+ while (!_closed.get() && read > 0)
+ {
+
+ read = _socketChannel.read(_netInputBuffer);
+ if (read == -1)
+ {
+ _closed.set(true);
+ }
+
+ if (LOGGER.isDebugEnabled())
+ {
+ LOGGER.debug("Read " + read + " possibly encrypted bytes " + _netInputBuffer);
+ }
+
+ if (_netInputBuffer.position() >= NUMBER_OF_BYTES_FOR_TLS_CHECK)
+ {
+ _netInputBuffer.flip();
+ final byte[] headerBytes = new byte[NUMBER_OF_BYTES_FOR_TLS_CHECK];
+ ByteBuffer dup = _netInputBuffer.duplicate();
+ dup.get(headerBytes);
+
+ _transportEncryption = looksLikeSSL(headerBytes) ? TransportEncryption.TLS : TransportEncryption.NONE;
+ LOGGER.debug("Identified transport encryption as " + _transportEncryption);
+
+ if (_transportEncryption == TransportEncryption.NONE)
+ {
+ _protocolEngine.received(_netInputBuffer);
+ }
+ else
+ {
+ _onTransportEncryptionAction.run();
+ _netInputBuffer.compact();
+ readData = doRead();
+ }
+ break;
+ }
+ }
+ }
+ return readData;
+ }
+
+ public boolean doWrite() throws IOException
+ {
+
+ ByteBuffer[] bufArray = new ByteBuffer[_buffers.size()];
+ Iterator<ByteBuffer> bufferIterator = _buffers.iterator();
+ for (int i = 0; i < bufArray.length; i++)
+ {
+ bufArray[i] = bufferIterator.next();
+ }
+
+ int byteBuffersWritten = 0;
+
+ if(_transportEncryption == TransportEncryption.NONE)
+ {
+
+
+ long written = _socketChannel.write(bufArray);
+ if (LOGGER.isDebugEnabled())
+ {
+ LOGGER.debug("Written " + written + " bytes");
+ }
+
+ for (ByteBuffer buf : bufArray)
+ {
+ if (buf.remaining() == 0)
+ {
+ byteBuffersWritten++;
+ _buffers.poll();
+ }
+ }
+
+
+ return bufArray.length == byteBuffersWritten;
+ }
+ else if(_transportEncryption == TransportEncryption.TLS)
+ {
+ int remaining = 0;
+ do
+ {
+ if(_sslEngine.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NEED_UNWRAP)
+ {
+ _workDone = true;
+ final ByteBuffer netBuffer = ByteBuffer.allocate(_sslEngine.getSession().getPacketBufferSize());
+ _status = _sslEngine.wrap(bufArray, netBuffer);
+ runSSLEngineTasks(_status);
+
+ netBuffer.flip();
+ remaining = netBuffer.remaining();
+ if (remaining != 0)
+ {
+ _encryptedOutput.add(netBuffer);
+ }
+ for (ByteBuffer buf : bufArray)
+ {
+ if (buf.remaining() == 0)
+ {
+ byteBuffersWritten++;
+ _buffers.poll();
+ }
+ }
+ }
+
+ }
+ while(remaining != 0 && _sslEngine.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NEED_UNWRAP);
+ ByteBuffer[] encryptedBuffers = _encryptedOutput.toArray(new ByteBuffer[_encryptedOutput.size()]);
+ long written = _socketChannel.write(encryptedBuffers);
+ if (LOGGER.isDebugEnabled())
+ {
+ LOGGER.debug("Written " + written + " encrypted bytes");
+ }
+ ListIterator<ByteBuffer> iter = _encryptedOutput.listIterator();
+ while(iter.hasNext())
+ {
+ ByteBuffer buf = iter.next();
+ if(buf.remaining() == 0)
+ {
+ iter.remove();
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ return bufArray.length == byteBuffersWritten;
+
+ }
+ else
+ {
+ return true;
+ }
+ }
+
+ public boolean looksLikeSSLv3ClientHello(final byte[] headerBytes)
+ {
+ return headerBytes[0] == 22 && // SSL Handshake
+ (headerBytes[1] == 3 && // SSL 3.0 / TLS 1.x
+ (headerBytes[2] == 0 || // SSL 3.0
+ headerBytes[2] == 1 || // TLS 1.0
+ headerBytes[2] == 2 || // TLS 1.1
+ headerBytes[2] == 3)) && // TLS1.2
+ (headerBytes[5] == 1); // client_hello
+ }
+
+ public boolean runSSLEngineTasks(final SSLEngineResult status)
+ {
+ if(status.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK)
+ {
+ Runnable task;
+ while((task = _sslEngine.getDelegatedTask()) != null)
+ {
+ task.run();
+ }
+ return true;
+ }
+ return false;
+ }
+
+ public boolean looksLikeSSL(final byte[] headerBytes)
+ {
+ return looksLikeSSLv3ClientHello(headerBytes) || looksLikeSSLv2ClientHello(headerBytes);
+ }
+
+ @Override
+ public void send(final ByteBuffer msg)
+ {
+ assert Thread.currentThread().getName().startsWith(SelectorThread.IO_THREAD_NAME_PREFIX) : "Send called by unexpected thread " + Thread.currentThread().getName();
+
+ if (_closed.get())
+ {
+ LOGGER.warn("Send ignored as the connection is already closed");
+ }
+ else
+ {
+ _buffers.add(msg);
+ _protocolEngine.notifyWork();
+ }
+ }
+
+ @Override
+ public void flush()
+ {
+ }
+}