summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTodd Lipcon <todd@apache.org>2010-09-28 00:11:01 +0000
committerTodd Lipcon <todd@apache.org>2010-09-28 00:11:01 +0000
commitb1a283f11e50650acc1b0730200b17bf8c5fac07 (patch)
tree946a409d029cb0735b2a280e9c8baa2cdd2d9fc7
parent84a7c2a901ee11433ca755edad1c278172ba7644 (diff)
downloadthrift-b1a283f11e50650acc1b0730200b17bf8c5fac07.tar.gz
THRIFT-912. java: Fix some bugs in SASL implementation, update protocol spec slightly
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@1001973 13f79535-47bb-0310-9956-ffa450edef68
-rw-r--r--doc/thrift-sasl-spec.txt33
-rw-r--r--lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java26
-rw-r--r--lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java34
-rw-r--r--lib/java/src/org/apache/thrift/transport/TSaslTransport.java136
-rw-r--r--lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java311
5 files changed, 397 insertions, 143 deletions
diff --git a/doc/thrift-sasl-spec.txt b/doc/thrift-sasl-spec.txt
index 59bfcf98b..02cf79e93 100644
--- a/doc/thrift-sasl-spec.txt
+++ b/doc/thrift-sasl-spec.txt
@@ -1,6 +1,5 @@
-A Thrift SASL message shall be a byte array of one of the following forms:
+A Thrift SASL message shall be a byte array of the following form:
-| 1-byte START status code | 1-byte mechanism name length | variable length mechanism name | 4-byte payload length | variable-length payload |
| 1-byte status code | 4-byte payload length | variable-length payload |
The length fields shall be interpreted as integers, with the high byte sent
@@ -24,15 +23,10 @@ underlying SASL security mechanism that it supports.
name -> mechanism options.
3. At connection time, the client will initiate communication by sending the
-server a START byte, followed by a 1-byte field indicating the length in bytes
-of the underlying security mechanism name that the client would like to use.
+server a START message. The payload of this message will be the name of the
+underlying security mechanism that the client would like to use.
This mechanism name shall be 1-20 characters in length, and follow the
-specifications for SASL mechanism names specified in RFC 2222. This mechanism
-name shall be followed by a 4-byte, potentially zero-value message length word,
-followed by a potentially zero-length payload. The payload is determined by the
-output byte array of the underlying actual security mechanism, and will be
-empty except for those underlying security protocols which implement the
-optional SASL initial response.
+specifications for SASL mechanism names specified in RFC 2222.
4. The server receives this message and, if the mechanism name provided is
among the set of mechanisms this server transport is configured to accept,
@@ -44,18 +38,25 @@ status code or message indicating failure. No further communication may take
place via this transport. If the mechanism name is one which the server
supports, then proceed to step 5.
-5. The server then provides the byte array of the payload received to its
+5. Following the START message, the client must send another message containing
+the "initial response" of the chosen SASL implementation. The client may send
+this message piggy-backed on the "START" message of step 3. The message type
+of this message must be either "OK" or "COMPLETE", depending on whether the
+SASL implementation indicates that this side of the authentication has been
+satisfied.
+
+6. The server then provides the byte array of the payload received to its
underlying security mechanism. A challenge is generated by the underlying
security mechanism on the server, and this is used as the payload for a message
sent to the client. This message shall consist of an OK byte, followed by the
non-zero message length word, followed by the payload.
-6. The client receives this message from the server and passes the payload to
+7. The client receives this message from the server and passes the payload to
its underlying security mechanism to generate a response. The client then sends
the server an OK byte, followed by the non-zero-value length of the response,
followed by the bytes of the response as the payload.
-7. Steps 5 and 6 are repeated until both security mechanisms are satisfied with
+8. Steps 6 and 7 are repeated until both security mechanisms are satisfied with
the challenge/response exchange. When either side has completed its security
protocol, its next message shall be the COMPLETE byte, followed by a 4-byte
potentially zero-value length word, followed by a potentially zero-length
@@ -78,10 +79,10 @@ be passed to the protocol above the thrift transport by whatever mechanism is
appropriate and idiomatic for the particular language these thrift bindings are
for.
-If step 7 completes successfully, then the communication is considered
+If step 8 completes successfully, then the communication is considered
authenticated and subsequent communication may commence.
-If step 7 fails to complete successfully, then no further communication may
+If step 8 fails to complete successfully, then no further communication may
take place via this transport.
8. All writes to the underlying transport must be prefixed by the 4-byte length
@@ -89,7 +90,7 @@ of the payload data, followed by the payload. All reads from this transport
should read the 4-byte length word, then read the full quantity of bytes
specified by this length word.
-If no SASL QOP (quality of protection) is negotiated during steps 5 and 6, then
+If no SASL QOP (quality of protection) is negotiated during steps 6 and 7, then
all subsequent writes to/reads from this transport are written/read unaltered,
save for the length prefix, to the underlying transport.
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
index fc8a3ea29..8c1d0e5af 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
@@ -75,6 +75,12 @@ public class TSaslClientTransport extends TSaslTransport {
this.mechanism = mechanism;
}
+
+ @Override
+ protected SaslRole getRole() {
+ return SaslRole.CLIENT;
+ }
+
/**
* Performs the client side of the initial portion of the Thrift SASL
* protocol. Generates and sends the initial response to the server, including
@@ -88,21 +94,15 @@ public class TSaslClientTransport extends TSaslTransport {
if (saslClient.hasInitialResponse())
initialResponse = saslClient.evaluateChallenge(initialResponse);
- byte[] mechanismBytes = mechanism.getBytes();
- byte[] messageHeader = new byte[STATUS_BYTES + MECHANISM_NAME_BYTES + mechanismBytes.length
- + PAYLOAD_LENGTH_BYTES];
-
- messageHeader[0] = START;
- messageHeader[1] = (byte) (0xff & mechanismBytes.length);
- System.arraycopy(mechanismBytes, 0, messageHeader, STATUS_BYTES + MECHANISM_NAME_BYTES,
- mechanismBytes.length);
- EncodingUtils.encodeBigEndian(initialResponse.length, messageHeader, STATUS_BYTES
- + MECHANISM_NAME_BYTES + mechanismBytes.length);
-
LOGGER.debug("Sending mechanism name {} and initial response of length {}", mechanism,
initialResponse.length);
- underlyingTransport.write(messageHeader);
- underlyingTransport.write(initialResponse);
+
+ byte[] mechanismBytes = mechanism.getBytes();
+ sendSaslMessage(NegotiationStatus.START,
+ mechanismBytes);
+ // Send initial response
+ sendSaslMessage(saslClient.isComplete() ? NegotiationStatus.COMPLETE : NegotiationStatus.OK,
+ initialResponse);
underlyingTransport.flush();
}
}
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
index b07e59727..8abcf360e 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
@@ -108,6 +108,11 @@ public class TSaslServerTransport extends TSaslTransport {
props, cbh));
}
+ @Override
+ protected SaslRole getRole() {
+ return SaslRole.SERVER;
+ }
+
/**
* Performs the server side of the initial portion of the Thrift SASL protocol.
* Receives the initial response from the client, creates a SASL server using
@@ -116,35 +121,24 @@ public class TSaslServerTransport extends TSaslTransport {
*/
@Override
protected void handleSaslStartMessage() throws TTransportException, SaslException {
- // Get the status byte and length of the mechanism name.
- byte[] messageHeader = new byte[STATUS_BYTES + MECHANISM_NAME_BYTES];
- underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
- LOGGER.debug("Received status {} and mechanism name length {}", messageHeader[0],
- messageHeader[1]);
- if (messageHeader[0] != START) {
- sendAndThrowMessage(ERROR, "Expecting START status, received " + messageHeader[0]);
+ SaslResponse message = receiveSaslMessage();
+
+ LOGGER.debug("Received start message with status {}", message.status);
+ if (message.status != NegotiationStatus.START) {
+ sendAndThrowMessage(NegotiationStatus.ERROR, "Expecting START status, received " + message.status);
}
// Get the mechanism name.
- byte[] mechanismBytes = new byte[messageHeader[1]];
- underlyingTransport.readAll(mechanismBytes, 0, mechanismBytes.length);
-
- String mechanismName = new String(mechanismBytes);
- TSaslServerDefinition serverDefinition = serverDefinitionMap.get(new String(mechanismBytes));
+ String mechanismName = new String(message.payload);
+ TSaslServerDefinition serverDefinition = serverDefinitionMap.get(mechanismName);
LOGGER.debug("Received mechanism name '{}'", mechanismName);
if (serverDefinition == null) {
- sendAndThrowMessage(BAD, "Unsupported mechanism type " + mechanismName);
+ sendAndThrowMessage(NegotiationStatus.BAD, "Unsupported mechanism type " + mechanismName);
}
SaslServer saslServer = Sasl.createSaslServer(serverDefinition.mechanism,
serverDefinition.protocol, serverDefinition.serverName, serverDefinition.props,
serverDefinition.cbh);
-
- // Evaluate the initial response and send the first challenge.
- byte[] initialResponse = new byte[readLength()];
- sendSaslMessage(saslServer.isComplete() ? COMPLETE : OK, saslServer
- .evaluateResponse(initialResponse));
-
setSaslServer(saslServer);
}
@@ -221,7 +215,7 @@ public class TSaslServerTransport extends TSaslTransport {
ret.open();
} catch (TTransportException e) {
LOGGER.debug("failed to open server transport", e);
- return null;
+ throw new RuntimeException(e);
}
transportMap.put(base, ret);
} else {
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
index b5eadb74f..24470d9dc 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
@@ -21,6 +21,8 @@ package org.apache.thrift.transport;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
+import java.util.Map;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
@@ -48,16 +50,42 @@ abstract class TSaslTransport extends TTransport {
protected static final int STATUS_BYTES = 1;
protected static final int PAYLOAD_LENGTH_BYTES = 4;
+ protected static enum SaslRole {
+ SERVER, CLIENT;
+ }
+
/**
* Status bytes used during the initial Thrift SASL handshake.
*/
- protected static final byte START = 0x01;
- protected static final byte OK = 0x02;
- protected static final byte BAD = 0x03;
- protected static final byte ERROR = 0x04;
- protected static final byte COMPLETE = 0x05;
+ protected static enum NegotiationStatus {
+ START((byte)0x01),
+ OK((byte)0x02),
+ BAD((byte)0x03),
+ ERROR((byte)0x04),
+ COMPLETE((byte)0x05);
+
+ private final byte value;
+
+ private static final Map<Byte, NegotiationStatus> reverseMap =
+ new HashMap<Byte, NegotiationStatus>();
+ static {
+ for (NegotiationStatus s : NegotiationStatus.class.getEnumConstants()) {
+ reverseMap.put(s.getValue(), s);
+ }
+ }
- protected static final Set<Byte> VALID_STATUSES = new HashSet<Byte>(Arrays.asList(START, OK, BAD, ERROR, COMPLETE));
+ private NegotiationStatus(byte val) {
+ this.value = val;
+ }
+
+ public byte getValue() {
+ return value;
+ }
+
+ public static NegotiationStatus byValue(byte val) {
+ return reverseMap.get(val);
+ }
+ }
/**
* Transport underlying this one.
@@ -126,14 +154,16 @@ abstract class TSaslTransport extends TTransport {
* The data to send as the payload of this message.
* @throws TTransportException
*/
- protected void sendSaslMessage(byte status, byte[] payload) throws TTransportException {
+ protected void sendSaslMessage(NegotiationStatus status, byte[] payload) throws TTransportException {
if (payload == null)
payload = new byte[0];
- messageHeader[0] = status;
+ messageHeader[0] = status.getValue();
EncodingUtils.encodeBigEndian(payload.length, messageHeader, STATUS_BYTES);
- LOGGER.debug("Writing message with status {} and payload length {}", status, payload.length);
+ if (LOGGER.isDebugEnabled())
+ LOGGER.debug(getRole() + ": Writing message with status {} and payload length {}",
+ status, payload.length);
underlyingTransport.write(messageHeader);
underlyingTransport.write(payload);
underlyingTransport.flush();
@@ -150,21 +180,25 @@ abstract class TSaslTransport extends TTransport {
protected SaslResponse receiveSaslMessage() throws TTransportException {
underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
- byte status = messageHeader[0];
+ byte statusByte = messageHeader[0];
byte[] payload = new byte[EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES)];
underlyingTransport.readAll(payload, 0, payload.length);
- if (!VALID_STATUSES.contains(status))
- sendAndThrowMessage(ERROR, "Invalid status " + status);
- else if (status == BAD || status == ERROR) {
+ NegotiationStatus status = NegotiationStatus.byValue(statusByte);
+ if (status == null) {
+ sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
+ } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
try {
- throw new TTransportException(new String(payload, "UTF-8"));
+ String remoteMessage = new String(payload, "UTF-8");
+ throw new TTransportException("Peer indicated failure: " + remoteMessage);
} catch (UnsupportedEncodingException e) {
throw new TTransportException(e);
}
}
- LOGGER.debug("Received message with status {} and payload length {}", status, payload.length);
+ if (LOGGER.isDebugEnabled())
+ LOGGER.debug(getRole() + ": Received message with status {} and payload length {}",
+ status, payload.length);
return new SaslResponse(status, payload);
}
@@ -180,8 +214,13 @@ abstract class TSaslTransport extends TTransport {
* @throws TTransportException
* Always thrown with the message provided.
*/
- protected void sendAndThrowMessage(byte status, String message) throws TTransportException {
- sendSaslMessage(status, message.getBytes());
+ protected void sendAndThrowMessage(NegotiationStatus status, String message) throws TTransportException {
+ try {
+ sendSaslMessage(status, message.getBytes());
+ } catch (Exception e) {
+ LOGGER.warn("Could not send failure response", e);
+ message += "\nAlso, could not send response: " + e.toString();
+ }
throw new TTransportException(message);
}
@@ -195,6 +234,8 @@ abstract class TSaslTransport extends TTransport {
*/
abstract protected void handleSaslStartMessage() throws TTransportException, SaslException;
+ protected abstract SaslRole getRole();
+
/**
* Opens the underlying transport if it's not already open and then performs
* SASL negotiation. If a QOP is negoiated during this SASL handshake, it used
@@ -210,24 +251,55 @@ abstract class TSaslTransport extends TTransport {
underlyingTransport.open();
try {
+ // Negotiate a SASL mechanism. The client also sends its
+ // initial response, or an empty one.
handleSaslStartMessage();
+ LOGGER.debug("{}: Start message handled", getRole());
- SaslResponse message;
- do {
+ SaslResponse message = null;
+ while (!sasl.isComplete()) {
message = receiveSaslMessage();
- if (message.status != COMPLETE && message.status != OK) {
+ if (message.status != NegotiationStatus.COMPLETE &&
+ message.status != NegotiationStatus.OK) {
throw new TTransportException("Expected COMPLETE or OK, got " + message.status);
}
- if (sasl.isComplete() && message.status == COMPLETE)
+ byte[] challenge = sasl.evaluateChallengeOrResponse(message.payload);
+
+ // If we are the client, and the server indicates COMPLETE, we don't need to
+ // send back any further response.
+ if (message.status == NegotiationStatus.COMPLETE &&
+ getRole() == SaslRole.CLIENT) {
+ LOGGER.debug("{}: All done!", getRole());
break;
+ }
- byte[] challenge = sasl.evaluateChallengeOrResponse(message.payload);
- sendSaslMessage(sasl.isComplete() ? COMPLETE : OK, challenge);
- } while (!(sasl.isComplete() && message.status == COMPLETE));
+ sendSaslMessage(sasl.isComplete() ? NegotiationStatus.COMPLETE : NegotiationStatus.OK,
+ challenge);
+ }
+ LOGGER.debug("{}: Main negotiation loop complete", getRole());
+
+ assert sasl.isComplete();
+
+ // If we're the client, and we're complete, but the server isn't
+ // complete yet, we need to wait for its response. This will occur
+ // with ANONYMOUS auth, for example, where we send an initial response
+ // and are immediately complete.
+ if (getRole() == SaslRole.CLIENT &&
+ (message == null || message.status == NegotiationStatus.OK)) {
+ LOGGER.debug("{}: SASL Client receiving last message", getRole());
+ message = receiveSaslMessage();
+ if (message.status != NegotiationStatus.COMPLETE) {
+ throw new TTransportException(
+ "Expected SASL COMPLETE, but got " + message.status);
+ }
+ }
} catch (SaslException e) {
- underlyingTransport.close();
- sendAndThrowMessage(BAD, e.getMessage());
+ try {
+ sendAndThrowMessage(NegotiationStatus.BAD, e.getMessage());
+ } finally {
+ underlyingTransport.close();
+ }
}
String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP);
@@ -241,7 +313,7 @@ abstract class TSaslTransport extends TTransport {
* @return The <code>SaslClient</code>, or <code>null</code> if this transport
* is backed by a <code>SaslServer</code>.
*/
- protected SaslClient getSaslClient() {
+ public SaslClient getSaslClient() {
return sasl.saslClient;
}
@@ -251,7 +323,7 @@ abstract class TSaslTransport extends TTransport {
* @return The <code>SaslServer</code>, or <code>null</code> if this transport
* is backed by a <code>SaslClient</code>.
*/
- protected SaslServer getSaslServer() {
+ public SaslServer getSaslServer() {
return sasl.saslServer;
}
@@ -348,7 +420,7 @@ abstract class TSaslTransport extends TTransport {
throw new TTransportException("Read a negative frame size (" + dataLength + ")!");
byte[] buff = new byte[dataLength];
- LOGGER.debug("reading data length: {}", dataLength);
+ LOGGER.debug("{}: reading data length: {}", getRole(), dataLength);
underlyingTransport.readAll(buff, 0, dataLength);
if (shouldWrap) {
buff = sasl.unwrap(buff, 0, buff.length);
@@ -396,11 +468,11 @@ abstract class TSaslTransport extends TTransport {
/**
* Used exclusively by readSaslMessage to return both a status and data.
*/
- private static class SaslResponse {
- public byte status;
+ protected static class SaslResponse {
+ public NegotiationStatus status;
public byte[] payload;
- public SaslResponse(byte status, byte[] payload) {
+ public SaslResponse(NegotiationStatus status, byte[] payload) {
this.status = status;
this.payload = payload;
}
diff --git a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
index 812028d1c..ca121c1b3 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
@@ -31,6 +31,10 @@ import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslClientFactory;
+import javax.security.sasl.SaslServer;
+import javax.security.sasl.SaslServerFactory;
import javax.security.sasl.SaslException;
import org.apache.thrift.TProcessor;
@@ -75,13 +79,19 @@ public class TestTSaslTransports extends TestCase {
private static class TestSaslCallbackHandler implements CallbackHandler {
+ private final String password;
+
+ public TestSaslCallbackHandler(String password) {
+ this.password = password;
+ }
+
@Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
for (Callback c : callbacks) {
if (c instanceof NameCallback) {
((NameCallback) c).setName(PRINCIPAL);
} else if (c instanceof PasswordCallback) {
- ((PasswordCallback) c).setPassword(PASSWORD.toCharArray());
+ ((PasswordCallback) c).setPassword(password.toCharArray());
} else if (c instanceof AuthorizeCallback) {
((AuthorizeCallback) c).setAuthorized(true);
} else if (c instanceof RealmCallback) {
@@ -93,39 +103,63 @@ public class TestTSaslTransports extends TestCase {
}
}
- private void testSaslOpen(final String mechanism, final Map<String, String> props)
- throws SaslException, TTransportException {
- Thread serverThread = new Thread() {
- public void run() {
- try {
- TServerSocket serverSocket = new TServerSocket(ServerTestBase.PORT);
- TTransport serverTransport = serverSocket.accept();
- TTransport saslServerTransport = new TSaslServerTransport(mechanism, SERVICE, HOST,
- props, new TestSaslCallbackHandler(), serverTransport);
-
- saslServerTransport.open();
-
- byte[] inBuf = new byte[testMessage1.getBytes().length];
- // Deliberately read less than the full buffer to ensure
- // that TSaslTransport is correctly buffering reads. This
- // will fail for the WRAPPED test, if it doesn't work.
- saslServerTransport.readAll(inBuf, 0, 5);
- saslServerTransport.readAll(inBuf, 5, 10);
- saslServerTransport.readAll(inBuf, 15, inBuf.length - 15);
- LOGGER.debug("server got: {}", new String(inBuf));
- assertEquals(new String(inBuf), testMessage1);
-
- LOGGER.debug("server writing: {}", testMessage2);
- saslServerTransport.write(testMessage2.getBytes());
- saslServerTransport.flush();
-
- serverSocket.close();
- saslServerTransport.close();
- } catch (TTransportException e) {
- fail(e.toString());
- }
+ private class ServerThread extends Thread {
+ final String mechanism;
+ final Map<String, String> props;
+ volatile Throwable thrown;
+
+ public ServerThread(String mechanism, Map<String, String> props) {
+ this.mechanism = mechanism;
+ this.props = props;
+ }
+
+ public void run() {
+ try {
+ internalRun();
+ } catch (Throwable t) {
+ thrown = t;
+ }
+ }
+
+ private void internalRun() throws Exception {
+ TServerSocket serverSocket = new TServerSocket(ServerTestBase.PORT);
+ try {
+ acceptAndWrite(serverSocket);
+ } finally {
+ serverSocket.close();
}
- };
+ }
+
+ private void acceptAndWrite(TServerSocket serverSocket)
+ throws Exception {
+ TTransport serverTransport = serverSocket.accept();
+ TTransport saslServerTransport = new TSaslServerTransport(
+ mechanism, SERVICE, HOST,
+ props, new TestSaslCallbackHandler(PASSWORD), serverTransport);
+
+ saslServerTransport.open();
+
+ byte[] inBuf = new byte[testMessage1.getBytes().length];
+ // Deliberately read less than the full buffer to ensure
+ // that TSaslTransport is correctly buffering reads. This
+ // will fail for the WRAPPED test, if it doesn't work.
+ saslServerTransport.readAll(inBuf, 0, 5);
+ saslServerTransport.readAll(inBuf, 5, 10);
+ saslServerTransport.readAll(inBuf, 15, inBuf.length - 15);
+ LOGGER.debug("server got: {}", new String(inBuf));
+ assertEquals(new String(inBuf), testMessage1);
+
+ LOGGER.debug("server writing: {}", testMessage2);
+ saslServerTransport.write(testMessage2.getBytes());
+ saslServerTransport.flush();
+
+ saslServerTransport.close();
+ }
+ }
+
+ private void testSaslOpen(final String mechanism, final Map<String, String> props)
+ throws Exception {
+ ServerThread serverThread = new ServerThread(mechanism, props);
serverThread.start();
try {
@@ -134,44 +168,95 @@ public class TestTSaslTransports extends TestCase {
// Ah well.
}
- TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
- TTransport saslClientTransport = new TSaslClientTransport(mechanism,
- PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(), clientSocket);
- saslClientTransport.open();
- LOGGER.debug("client writing: {}", testMessage1);
- saslClientTransport.write(testMessage1.getBytes());
- saslClientTransport.flush();
-
- byte[] inBuf = new byte[testMessage2.getBytes().length];
- saslClientTransport.readAll(inBuf, 0, inBuf.length);
- LOGGER.debug("client got: {}", new String(inBuf));
- assertEquals(new String(inBuf), testMessage2);
-
- TTransportException expectedException = null;
try {
+ TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
+ TTransport saslClientTransport = new TSaslClientTransport(mechanism,
+ PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(PASSWORD), clientSocket);
saslClientTransport.open();
- } catch (TTransportException e) {
- expectedException = e;
- }
- assertNotNull(expectedException);
+ LOGGER.debug("client writing: {}", testMessage1);
+ saslClientTransport.write(testMessage1.getBytes());
+ saslClientTransport.flush();
- saslClientTransport.close();
+ byte[] inBuf = new byte[testMessage2.getBytes().length];
+ saslClientTransport.readAll(inBuf, 0, inBuf.length);
+ LOGGER.debug("client got: {}", new String(inBuf));
+ assertEquals(new String(inBuf), testMessage2);
- try {
- serverThread.join();
- } catch (InterruptedException e) {
- // Ah well.
+ TTransportException expectedException = null;
+ try {
+ saslClientTransport.open();
+ } catch (TTransportException e) {
+ expectedException = e;
+ }
+ assertNotNull(expectedException);
+
+ saslClientTransport.close();
+ } catch (Exception e) {
+ LOGGER.warn("Exception caught", e);
+ throw e;
+ } finally {
+ serverThread.interrupt();
+ try {
+ serverThread.join();
+ } catch (InterruptedException e) {
+ // Ah well.
+ }
+ assertNull(serverThread.thrown);
}
}
- public void testUnwrappedOpen() throws SaslException, TTransportException {
+ public void testUnwrappedOpen() throws Exception {
testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
}
- public void testWrappedOpen() throws SaslException, TTransportException {
+ public void testWrappedOpen() throws Exception {
testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS);
}
+ public void testAnonymousOpen() throws Exception {
+ testSaslOpen("ANONYMOUS", null);
+ }
+
+ /**
+ * Test that we get the proper exceptions thrown back the server when
+ * the client provides invalid password.
+ */
+ public void testBadPassword() throws Exception {
+ ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
+ serverThread.start();
+
+ try {
+ Thread.sleep(1000);
+ } catch (InterruptedException e) {
+ // Ah well.
+ }
+
+ boolean clientSidePassed = true;
+
+ try {
+ TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
+ TTransport saslClientTransport = new TSaslClientTransport(
+ UNWRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, UNWRAPPED_PROPS,
+ new TestSaslCallbackHandler("NOT THE PASSWORD"), clientSocket);
+ saslClientTransport.open();
+ clientSidePassed = false;
+ fail("Was able to open transport with bad password");
+ } catch (TTransportException tte) {
+ LOGGER.error("Exception for bad password", tte);
+ assertNotNull(tte.getMessage());
+ assertTrue(tte.getMessage().contains("Invalid response"));
+
+ } finally {
+ serverThread.interrupt();
+ serverThread.join();
+
+ if (clientSidePassed) {
+ assertNotNull(serverThread.thrown);
+ assertTrue(serverThread.thrown.getMessage().contains("Invalid response"));
+ }
+ }
+ }
+
public void testWithServer() throws Exception {
new TestTSaslTransportsWithServer().testIt();
}
@@ -183,8 +268,9 @@ public class TestTSaslTransports extends TestCase {
@Override
public TTransport getClientTransport(TTransport underlyingTransport) throws Exception {
- return new TSaslClientTransport(WRAPPED_MECHANISM,
- PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS, new TestSaslCallbackHandler(), underlyingTransport);
+ return new TSaslClientTransport(
+ WRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS,
+ new TestSaslCallbackHandler(PASSWORD), underlyingTransport);
}
@Override
@@ -195,8 +281,9 @@ public class TestTSaslTransports extends TestCase {
// Transport
TServerSocket socket = new TServerSocket(PORT);
- TTransportFactory factory = new TSaslServerTransport.Factory(WRAPPED_MECHANISM,
- SERVICE, HOST, WRAPPED_PROPS, new TestSaslCallbackHandler());
+ TTransportFactory factory = new TSaslServerTransport.Factory(
+ WRAPPED_MECHANISM, SERVICE, HOST, WRAPPED_PROPS,
+ new TestSaslCallbackHandler(PASSWORD));
server = new TSimpleServer(processor, socket, factory, protoFactory);
// Run it
@@ -222,4 +309,104 @@ public class TestTSaslTransports extends TestCase {
}
+
+ /**
+ * Implementation of SASL ANONYMOUS, used for testing client-side
+ * intial responses.
+ */
+ private static class AnonymousClient implements SaslClient {
+ private final String username;
+ private boolean hasProvidedInitialResponse;
+
+ public AnonymousClient(String username) {
+ this.username = username;
+ }
+
+ public String getMechanismName() { return "ANONYMOUS"; }
+ public boolean hasInitialResponse() { return true; }
+ public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
+ if (hasProvidedInitialResponse) {
+ throw new SaslException("Already complete!");
+ }
+
+ try {
+ hasProvidedInitialResponse = true;
+ return username.getBytes("UTF-8");
+ } catch (IOException e) {
+ throw new SaslException(e.toString());
+ }
+ }
+ public boolean isComplete() { return hasProvidedInitialResponse; }
+ public byte[] unwrap(byte[] incoming, int offset, int len) {
+ throw new UnsupportedOperationException();
+ }
+ public byte[] wrap(byte[] outgoing, int offset, int len) {
+ throw new UnsupportedOperationException();
+ }
+ public Object getNegotiatedProperty(String propName) { return null; }
+ public void dispose() {}
+ }
+
+ private static class AnonymousServer implements SaslServer {
+ private String user;
+ public String getMechanismName() { return "ANONYMOUS"; }
+ public byte[] evaluateResponse(byte[] response) throws SaslException {
+ try {
+ this.user = new String(response, "UTF-8");
+ } catch (IOException e) {
+ throw new SaslException(e.toString());
+ }
+ return null;
+ }
+ public boolean isComplete() { return user != null; }
+ public String getAuthorizationID() { return user; }
+ public byte[] unwrap(byte[] incoming, int offset, int len) {
+ throw new UnsupportedOperationException();
+ }
+ public byte[] wrap(byte[] outgoing, int offset, int len) {
+ throw new UnsupportedOperationException();
+ }
+ public Object getNegotiatedProperty(String propName) { return null; }
+ public void dispose() {}
+
+ }
+
+ public static class SaslAnonymousFactory
+ implements SaslClientFactory, SaslServerFactory {
+
+ public SaslClient createSaslClient(
+ String[] mechanisms, String authorizationId, String protocol,
+ String serverName, Map<String,?> props, CallbackHandler cbh)
+ {
+ for (String mech : mechanisms) {
+ if ("ANONYMOUS".equals(mech)) {
+ return new AnonymousClient(authorizationId);
+ }
+ }
+ return null;
+ }
+
+ public SaslServer createSaslServer(
+ String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)
+ {
+ if ("ANONYMOUS".equals(mechanism)) {
+ return new AnonymousServer();
+ }
+ return null;
+ }
+ public String[] getMechanismNames(Map<String, ?> props) {
+ return new String[] { "ANONYMOUS" };
+ }
+ }
+
+ static {
+ java.security.Security.addProvider(new SaslAnonymousProvider());
+ }
+ public static class SaslAnonymousProvider extends java.security.Provider {
+ public SaslAnonymousProvider() {
+ super("ThriftSaslAnonymous", 1.0, "Thrift Anonymous SASL provider");
+ put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
+ put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
+ }
+ }
}