On Tue, Oct 30, 2012 at 10:46:30PM -0700, Michael Vrable wrote:
> Attached is a first patch (still needs to be tested) at adding better
> transport-layer encryption to Xpra--it adds message authentication to each of
> the packets to prevent any tampering of the data stream. Please don't commit
> it, as it isn't ready for that yet.
Does the mailing list strip attachments? I'm not sure it went through, so here it is again inline. --Michael Vrable commit 1f7e729bbb1e641506ca36ee82bd78c6ff009595 Author: Michael Vrable <[email protected]> Proof-of-concept implementation of a more secure transport layer for Xpra. This uses AES encryption of data packets (in CTR mode to avoid the need for padding), and a truncated HMAC-SHA-256 to provide authentication of the data stream. This assumes that both sides have run some type of key-agreement protocol to establish a shared session secret. I'm working on the key exchange part in a separate patch which will follow. This code isn't yet tested, but should give a basic idea. diff --git a/src/xpra/protocol.py b/src/xpra/protocol.py index 2fdec82..95da1ad 100644 --- a/src/xpra/protocol.py +++ b/src/xpra/protocol.py @@ -11,12 +11,16 @@ from wimpiggy.gobject_compat import import_gobject gobject = import_gobject() gobject.threads_init() +import hashlib +import hmac import sys import socket # for socket.error import zlib import struct import time import os +from Crypto.Cipher import AES +from Crypto.Util import Counter NOYIELD = os.environ.get("XPRA_YIELD") is None @@ -69,6 +73,168 @@ def zlib_compress(datatype, data, level=5): return ZLibCompressed(datatype, cdata, level) +class CryptoError(ValueError): + """Error raised when decryption fails for any reason.""" + pass + + +class TransportCrypto: + """Interface for transport-level encryption layers. + + This class defines the methods that must be implemented to define a + transport-layer encryption/authentication method. The crypto layer is + generally set up after both parties have performed mutual authentication + and established a shared secret value, which is used to key the + encryption/MAC primitives. + + This class does not provide an implementation and so should not be + instantiated directly. Use one of NullCrypto or AESCrypto. 
+ """ + + def name(self): + """Returns the name of the current crypto layer.""" + raise NotImplementedError + + def overhead_bytes(self, packet_len): + """Returns the number of bytes added to a packet of the given size. + + This allows a crypto transport to add extra data for an IV, padding, or + MAC. + """ + raise NotImplementedError + + def encrypt(self, header, payload): + """Encrypt/MAC the specified data payload. + + Returns the modified payload, which will be of length + len(payload) + overhead_bytes(len(payload)) + The encrypted payload does not include the header, but the computed MAC + may cover the values in the header. + """ + raise NotImplementedError + + def decrypt(self, header, payload): + """Decrypt the specified data payload. + + Returns the decrypted payload data, or raises a CryptoError exception + on any failures (bad MAC, incorrect padding if padding is used, etc.). + """ + raise NotImplementedError + + +class NullCrypto(TransportCrypto): + """A transparent transport which performs no encryption/authentication.""" + + def name(self): + return "null" + + def overhead_bytes(self, packet_len): + return 0 + + def encrypt(self, header, payload): + return payload + + def decrypt(self, header, payload): + return payload + + +class AESCrypto(TransportCrypto): + """A crypto layer using AES256 in CTR mode and HMAC-SHA-256. + + For each communication direction, two keys are derived from the shared + session secret: one encryption key and one message authentication key. + Data payloads are encrypted with AES in CTR mode, starting with a counter + value all zeroes. No padding is needed. + + A message authentication code (MAC) is appended to each packet; the MAC is + computed using HMAC-SHA-256 (keyed with the authentication key). The + digest is computed over the concatenation of a 64-bit, big-endian packet + counter (to prevent packet replay/reorder attacks), the unencrypted packet + header, and the encrypted packet data. 
The hash is truncated to mac_bytes + in length (default: 12 bytes = 96 bits) then appended after the packet + data. The length field encoded in the packet header is the length of the + data payload, not including the header and MAC. + """ + + def __init__(self, session_secret, context, mac_bytes=12): + """Initializes the crypto layer for one direction of a transport. + + Args: + session_secret: A secret value negotiated by both sides of the + connection. The value of session_secret must never be re-used + in a different connection, or security may suffer. This may be + any string value. + context: A string used to distinguish multiple related + instantiations of AESCrypto. For example, a client and server + may compute a single shared session_secret, and use a context + of "server" for data sent by the server and "client" for data + sent by the client. This ensures that separate keys are used + for each data direction. + mac_bytes: Number of bytes to include in the message authentication + code for each packet. The MAC is a truncated HMAC-SHA-256; + mac_bytes can be up to 32 but smaller values reduce overhead at + the risk of allowing undetected errors if mac_bytes is too + small. + """ + self._session_secret = session_secret + self._context = context + + # Derived keys. There are two: + # - A 256-bit AES key for encryption + # - A key used for HMAC authentication + def derive_key(subtype): + return hmac.new(session_secret, "%s-%s" % (context, subtype), + digestmod=hashlib.sha256).digest() + key_enc = derive_key("aes") + key_mac = derive_key("mac") + + # Start CTR mode counting from zero. This is safe as long as every + # session (and direction) uses a unique encryption key. + self._cipher = AES.new(key_enc, mode=AES.MODE_CTR, + counter=Counter.new(128, initial_value=0)) + + self._hmac_key = key_mac + self._mac_bytes = mac_bytes + + # The packet counter. This is incremented for each packet processed. 
+ # The packet counter is included in the MAC for each packet (to prevent + # replay/reordering attacks), but is not explicitly added to the output + # to reduce overhead. + self._packet_counter = 0 + + def name(self): + return "aes256-ctr/hmac256-%d" % (self._mac_bytes * 8) + + def overhead_bytes(self, packet_len): + # No padding is needed for CTR mode, so the only overhead is the MAC + # overhead, independent of packet size. + return self._mac_bytes + + def encrypt(self, header, payload): + payload = self._cipher.encrypt(payload) + mac = hmac.new(self._hmac_key, digestmod=hashlib.sha256) + mac.update(struct.pack("!Q", self._packet_counter)) + self._packet_counter += 1 + mac.update(header) + mac.update(payload) + mac_value = mac.digest()[0:self._mac_bytes] + return payload + mac_value + + def decrypt(self, header, payload): + if len(payload) < self._mac_bytes: + raise CryptoError("Bad decryption") + payload_data = payload[:-self._mac_bytes] + payload_mac = payload[-self._mac_bytes:] + mac = hmac.new(self._hmac_key, digestmod=hashlib.sha256) + mac.update(struct.pack("!Q", self._packet_counter)) + self._packet_counter += 1 + mac.update(header) + mac.update(payload_data) + if mac.digest()[0:self._mac_bytes] != payload_mac: + raise CryptoError("Bad decryption") + return self._cipher.decrypt(payload_data) + + class Protocol(object): CONNECTION_LOST = "connection-lost" GIBBERISH = "gibberish" @@ -97,12 +263,8 @@ class Protocol(object): self._encoder = self.bencode self._decompressor = zlib.decompressobj() self._compression_level = 0 - self.cipher_in = None - self.cipher_in_name = None - self.cipher_in_block_size = 0 - self.cipher_out = None - self.cipher_out_name = None - self.cipher_out_block_size = 0 + self.cipher_in = NullCrypto() + self.cipher_out = NullCrypto() def make_daemon_thread(target, name): daemon_thread = Thread(target=target, name=name) daemon_thread.setDaemon(True) @@ -112,33 +274,15 @@ class Protocol(object): self._read_thread = 
make_daemon_thread(self._read_thread_loop, "read_loop") self._read_parser_thread = make_daemon_thread(self._read_parse_thread_loop, "read_parse_loop") - def get_cipher(self, ciphername, iv, password, key_salt, iterations): - log("get_cipher_in(%s, %s, %s, %s, %s)", ciphername, iv, password, key_salt, iterations) - if not ciphername: - return None, 0 - assert iterations>=100 - assert ciphername=="AES" - assert password and iv - from Crypto.Cipher import AES - from Crypto.Protocol.KDF import PBKDF2 - #stretch the password: - block_size = 32 #fixme: can we derive this? - secret = PBKDF2(password, key_salt, dkLen=block_size, count=iterations) - #secret = (password+password+password+password+password+password+password+password)[:32] - log("get_cipher(%s, %s, %s) secret=%s, block_size=%s", ciphername, iv, password, secret.encode('hex'), block_size) - return AES.new(secret, AES.MODE_CBC, iv), block_size - - def set_cipher_in(self, ciphername, iv, password, key_salt, iterations): - if self.cipher_in_name!=ciphername: - log.info("receiving data using %s encryption", ciphername) - self.cipher_in_name = ciphername - self.cipher_in, self.cipher_in_block_size = self.get_cipher(ciphername, iv, password, key_salt, iterations) - - def set_cipher_out(self, ciphername, iv, password, key_salt, iterations): - if self.cipher_out_name!=ciphername: - log.info("sending data using %s encryption", ciphername) - self.cipher_out_name = ciphername - self.cipher_out, self.cipher_out_block_size = self.get_cipher(ciphername, iv, password, key_salt, iterations) + def set_cipher(self, direction, ciphername, session_secret, context): + cipher = AESCrypto(session_secret, context) + log.info("setting encryption from %s to %s", context, cipher.name()) + if direction == "in": + self.cipher_in = cipher + elif direction == "out": + self.cipher_out = cipher + else: + raise ValueError("Unknown cipher direction: " + direction) def __str__(self): ti = ["%s:%s" % (x.name, x.is_alive()) for x in 
self.get_threads()] @@ -289,29 +433,12 @@ class Protocol(object): #fire the end_send callback when the last packet (index==0) makes it out: if index==0: ecb = end_send_cb - if self.cipher_out: - proto_flags |= Protocol.FLAGS_CIPHER - #note: since we are padding: l!=len(data) - padding = (self.cipher_out_block_size - len(data) % self.cipher_out_block_size) * " " - if len(padding)==0: - padded = data - else: - padded = data+padding - actual_size = payload_size + len(padding) - assert len(padded)==actual_size - data = self.cipher_out.encrypt(padded) - assert len(data)==actual_size - log("sending %s bytes encrypted with %s padding", payload_size, len(padding)) - if actual_size<16384: - #'p' + protocol-flags + compression_level + packet_index + data_size - if type(data)==unicode: - data = str(data) - header_and_data = struct.pack('!BBBBL%ss' % actual_size, ord("P"), proto_flags, level, index, payload_size, data) - self._write_queue.put((header_and_data, scb, ecb)) - else: - header = struct.pack('!BBBBL', ord("P"), proto_flags, level, index, payload_size) - self._write_queue.put((header, scb, None)) - self._write_queue.put((data, None, ecb)) + header = struct.pack('!BBBBL', + ord("P"), proto_flags, level, + index, payload_size) + data = self.cipher_out.encrypt(header, data) + self._write_queue.put((header, scb, None)) + self._write_queue.put((data, None, ecb)) counter += 1 finally: self.output_packetcount += 1 @@ -437,20 +564,14 @@ class Protocol(object): break #packet still too small #packet format: struct.pack('cBBBL', ...) 
- 8 bytes try: - _, protocol_flags, compression_level, packet_index, data_size = struct.unpack_from('!cBBBL', read_buffer) + header = read_buffer[:8] + _, protocol_flags, compression_level, packet_index, data_size = struct.unpack('!cBBBL', header) except Exception, e: raise Exception("invalid packet header: %s" % list(read_buffer[:8]), e) read_buffer = read_buffer[8:] bl = len(read_buffer) - if protocol_flags & Protocol.FLAGS_CIPHER: - assert self.cipher_in_block_size>0, "received cipher block but we don't have a cipher do decrypt it with" - padding = (self.cipher_in_block_size - data_size % self.cipher_in_block_size) * " " - payload_size = data_size + len(padding) - else: - #no cipher, no padding: - padding = None - payload_size = data_size - assert payload_size>0 + payload_size = (data_size + + self.cipher_in.overhead_bytes(data_size)) if payload_size>self.max_packet_size: #this packet is seemingly too big, but check again from the main UI thread @@ -474,13 +595,11 @@ class Protocol(object): raw_string = read_buffer[:payload_size] read_buffer = read_buffer[payload_size:] #decrypt if needed: - data = raw_string - if self.cipher_in and protocol_flags & Protocol.FLAGS_CIPHER: - log("received %s encrypted bytes with %s padding", payload_size, len(padding)) - data = self.cipher_in.decrypt(raw_string) - if padding: - assert data.endswith(padding), "decryption failed: string does not end with '%s': %s (%s) -> %s (%s)" % (padding, list(bytearray(raw_string)), type(raw_string), list(bytearray(data)), type(data)) - data = data[:-len(padding)] + try: + data = self.cipher_in.decrypt(header, raw_string) + if len(data) != data_size: raise CryptoError() + except CryptoError: + return self._call_connection_lost("Decryption failed: %s" % repr_ellipsized(raw_string)) #uncompress if needed: if compression_level>0: if self.chunked_compression: @@ -490,9 +609,6 @@ class Protocol(object): if sys.version>='3': data = data.decode("latin1") - if self.cipher_in and not (protocol_flags & 
Protocol.FLAGS_CIPHER): - return self._call_connection_lost("unencrypted packet dropped: %s" % repr_ellipsized(data)) - if self._closed: return if packet_index>0: _______________________________________________ shifter-users mailing list [email protected] http://lists.devloop.org.uk/mailman/listinfo/shifter-users
