diff options
Diffstat (limited to 'kafka/conn.py')
-rw-r--r-- | kafka/conn.py | 83 |
1 files changed, 5 insertions, 78 deletions
diff --git a/kafka/conn.py b/kafka/conn.py index e4938c7..dfb8d78 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -1,16 +1,11 @@ from __future__ import absolute_import, division -import base64 import copy import errno -import hashlib -import hmac import io import logging from random import shuffle, uniform -from uuid import uuid4 - # selectors in stdlib as of py3.4 try: import selectors # pylint: disable=import-error @@ -34,6 +29,7 @@ from kafka.protocol.commit import OffsetFetchRequest from kafka.protocol.metadata import MetadataRequest from kafka.protocol.parser import KafkaProtocol from kafka.protocol.types import Int32, Int8 +from kafka.scram import ScramClient from kafka.version import __version__ @@ -42,12 +38,6 @@ if six.PY2: TimeoutError = socket.error BlockingIOError = Exception - def xor_bytes(left, right): - return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right)) -else: - def xor_bytes(left, right): - return bytes(lb ^ rb for lb, rb in zip(left, right)) - log = logging.getLogger(__name__) DEFAULT_KAFKA_PORT = 9092 @@ -107,69 +97,6 @@ class ConnectionStates(object): AUTHENTICATING = '<authenticating>' -class ScramClient: - MECHANISMS = { - 'SCRAM-SHA-256': hashlib.sha256, - 'SCRAM-SHA-512': hashlib.sha512 - } - - def __init__(self, user, password, mechanism): - self.nonce = str(uuid4()).replace('-', '') - self.auth_message = '' - self.salted_password = None - self.user = user - self.password = password.encode() - self.hashfunc = self.MECHANISMS[mechanism] - self.hashname = ''.join(mechanism.lower().split('-')[1:3]) - self.stored_key = None - self.client_key = None - self.client_signature = None - self.client_proof = None - self.server_key = None - self.server_signature = None - - def first_message(self): - client_first_bare = 'n={},r={}'.format(self.user, self.nonce) - self.auth_message += client_first_bare - return 'n,,' + client_first_bare - - def process_server_first_message(self, server_first_message): - self.auth_message += ',' + server_first_message - params = dict(pair.split('=', 1) for pair in server_first_message.split(',')) - server_nonce = params['r'] - if not server_nonce.startswith(self.nonce): - raise ValueError("Server nonce, did not start with client nonce!") - self.nonce = server_nonce - self.auth_message += ',c=biws,r=' + self.nonce - - salt = base64.b64decode(params['s'].encode()) - iterations = int(params['i']) - self.create_salted_password(salt, iterations) - - self.client_key = self.hmac(self.salted_password, b'Client Key') - self.stored_key = self.hashfunc(self.client_key).digest() - self.client_signature = self.hmac(self.stored_key, self.auth_message.encode()) - self.client_proof = xor_bytes(self.client_key, self.client_signature) - self.server_key = self.hmac(self.salted_password, b'Server Key') - self.server_signature = self.hmac(self.server_key, self.auth_message.encode()) - - def hmac(self, key, msg): - return hmac.new(key, msg, digestmod=self.hashfunc).digest() - - def create_salted_password(self, salt, iterations): - self.salted_password = hashlib.pbkdf2_hmac( - self.hashname, self.password, salt, iterations - ) - - def final_message(self): - client_final_no_proof = 'c=biws,r=' + self.nonce - return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode()) - - def process_server_final_message(self, server_final_message): - params = dict(pair.split('=', 1) for pair in server_final_message.split(',')) - if self.server_signature != base64.b64decode(params['v'].encode()): - raise ValueError("Server sent wrong signature!") - class BrokerConnection(object): """Initialize a Kafka broker connection @@ -747,20 +674,20 @@ class BrokerConnection(object): close = False else: try: - client_first = scram_client.first_message().encode() + client_first = scram_client.first_message().encode('utf-8') size = Int32.encode(len(client_first)) self._send_bytes_blocking(size + client_first) (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) - server_first = self._recv_bytes_blocking(data_len).decode() + server_first = self._recv_bytes_blocking(data_len).decode('utf-8') scram_client.process_server_first_message(server_first) - client_final = scram_client.final_message().encode() + client_final = scram_client.final_message().encode('utf-8') size = Int32.encode(len(client_final)) self._send_bytes_blocking(size + client_final) (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) - server_final = self._recv_bytes_blocking(data_len).decode() + server_final = self._recv_bytes_blocking(data_len).decode('utf-8') scram_client.process_server_final_message(server_final) except (ConnectionError, TimeoutError) as e: |