From 85783fe730f0539472ade0373128f427108d9e95 Mon Sep 17 00:00:00 2001
From: Dirkjan Bussink <d.bussink@gmail.com>
Date: Tue, 12 Feb 2019 08:58:10 +0000
Subject: [PATCH 6/6] Add implementation for Encrypt-then-MAC mode

This adds the OpenSSH HMACs that do encrypt then mac. This is a more
secure mode than the original HMAC. Newer AEAD ciphers like chacha20 and
AES-GCM are already encrypt-then-mac, but this also adds it for older
legacy clients that don't support those ciphers yet.

Reviewed-by: Jon Simons <jon@jonsimons.org>
---
 src/kex.c          |  8 ++---
 src/packet.c       | 96 ++++++++++++++++++++++++++++++++++++++----------------
 src/packet_crypt.c | 35 ++++++++++++++++----
 src/wrapper.c      | 21 +++++++-----
 4 files changed, 113 insertions(+), 47 deletions(-)

diff --git a/src/kex.c b/src/kex.c
index aa783f8f..03d66e70 100644
--- a/src/kex.c
+++ b/src/kex.c
@@ -146,8 +146,8 @@ static const char *default_methods[] = {
   PUBLIC_KEY_ALGORITHMS,
   AES BLOWFISH DES,
   AES BLOWFISH DES,
-  "hmac-sha2-256,hmac-sha2-512,hmac-sha1",
-  "hmac-sha2-256,hmac-sha2-512,hmac-sha1",
+  "hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com,hmac-sha1-etm@openssh.com,hmac-sha2-256,hmac-sha2-512,hmac-sha1",
+  "hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com,hmac-sha1-etm@openssh.com,hmac-sha2-256,hmac-sha2-512,hmac-sha1",
   "none",
   "none",
   "",
@@ -161,8 +161,8 @@ static const char *supported_methods[] = {
   PUBLIC_KEY_ALGORITHMS,
   CHACHA20 AES BLOWFISH DES_SUPPORTED,
   CHACHA20 AES BLOWFISH DES_SUPPORTED,
-  "hmac-sha2-256,hmac-sha2-512,hmac-sha1",
-  "hmac-sha2-256,hmac-sha2-512,hmac-sha1",
+  "hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com,hmac-sha1-etm@openssh.com,hmac-sha2-256,hmac-sha2-512,hmac-sha1",
+  "hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com,hmac-sha1-etm@openssh.com,hmac-sha2-256,hmac-sha2-512,hmac-sha1",
   ZLIB,
   ZLIB,
   "",
diff --git a/src/packet.c b/src/packet.c
index d0c5d60b..056d80c8 100644
--- a/src/packet.c
+++ b/src/packet.c
@@ -1045,6 +1045,8 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
     size_t processed = 0; /* number of byte processed from the callback */
     enum ssh_packet_filter_result_e filter_result;
     struct ssh_crypto_struct *crypto = NULL;
+    int etm = 0;
+    int etm_packet_offset = 0;
     bool ok;
 
     crypto = ssh_packet_get_current_crypto(session, SSH_DIRECTION_IN);
@@ -1052,9 +1054,17 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
         current_macsize = hmac_digest_len(crypto->in_hmac);
         blocksize = crypto->in_cipher->blocksize;
         lenfield_blocksize = crypto->in_cipher->lenfield_blocksize;
+        etm = crypto->in_hmac_etm;
     }
 
-    if (lenfield_blocksize == 0) {
+    if (etm) {
+        /* In EtM mode packet size is unencrypted. This means
+         * we need to use this offset and set the block size
+         * that is part of the encrypted part to 0.
+         */
+        etm_packet_offset = sizeof(uint32_t);
+        lenfield_blocksize = 0;
+    } else if (lenfield_blocksize == 0) {
         lenfield_blocksize = blocksize;
     }
     if (data == NULL) {
@@ -1077,10 +1087,10 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
 #endif
     switch(session->packet_state) {
         case PACKET_STATE_INIT:
-            if (receivedlen < lenfield_blocksize) {
+            if (receivedlen < lenfield_blocksize + etm_packet_offset) {
                 /*
-                 * We didn't receive enough data to read at least one
-                 * block size, give up
+                 * We didn't receive enough data to read either at least one
+                 * block size or the unencrypted length in EtM mode.
                  */
 #ifdef DEBUG_PACKET
                 SSH_LOG(SSH_LOG_PACKET,
@@ -1107,13 +1117,20 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
                 }
             }
 
-            ptr = ssh_buffer_allocate(session->in_buffer, lenfield_blocksize);
-            if (ptr == NULL) {
-                goto error;
+            if (!etm) {
+                ptr = ssh_buffer_allocate(session->in_buffer, lenfield_blocksize);
+                if (ptr == NULL) {
+                    goto error;
+                }
+                packet_len = ssh_packet_decrypt_len(session, ptr, (uint8_t *)data);
+                to_be_read = packet_len - lenfield_blocksize + sizeof(uint32_t);
+            } else {
+                /* Length is unencrypted in case of Encrypt-then-MAC */
+                packet_len = PULL_BE_U32(data, 0);
+                to_be_read = packet_len - etm_packet_offset;
             }
-            processed += lenfield_blocksize;
-            packet_len = ssh_packet_decrypt_len(session, ptr, (uint8_t *)data);
 
+            processed += lenfield_blocksize + etm_packet_offset;
             if (packet_len > MAX_PACKET_LEN) {
                 ssh_set_error(session,
                               SSH_FATAL,
@@ -1121,7 +1138,6 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
                               packet_len, packet_len);
                 goto error;
             }
-            to_be_read = packet_len - lenfield_blocksize + sizeof(uint32_t);
             if (to_be_read < 0) {
                 /* remote sshd sends invalid sizes? */
                 ssh_set_error(session,
@@ -1136,7 +1152,7 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
             FALL_THROUGH;
         case PACKET_STATE_SIZEREAD:
             packet_len = session->in_packet.len;
-            processed = lenfield_blocksize;
+            processed = lenfield_blocksize + etm_packet_offset;
             to_be_read = packet_len + sizeof(uint32_t) + current_macsize;
             /* if to_be_read is zero, the whole packet was blocksize bytes. */
             if (to_be_read != 0) {
@@ -1151,13 +1167,13 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
                     return 0;
                 }
 
-                packet_second_block = (uint8_t*)data + lenfield_blocksize;
+                packet_second_block = (uint8_t*)data + lenfield_blocksize + etm_packet_offset;
                 processed = to_be_read - current_macsize;
             }
 
             /* remaining encrypted bytes from the packet, MAC not included */
             packet_remaining =
-                packet_len - (lenfield_blocksize - sizeof(uint32_t));
+                packet_len - (lenfield_blocksize - sizeof(uint32_t) + etm_packet_offset);
             cleartext_packet = ssh_buffer_allocate(session->in_buffer,
                                                    packet_remaining);
             if (cleartext_packet == NULL) {
@@ -1166,6 +1182,19 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
 
             if (packet_second_block != NULL) {
                 if (crypto != NULL) {
+                    mac = packet_second_block + packet_remaining;
+
+                    if (etm) {
+                        rc = ssh_packet_hmac_verify(session,
+                                                    data,
+                                                    processed,
+                                                    mac,
+                                                    crypto->in_hmac);
+                        if (rc < 0) {
+                            ssh_set_error(session, SSH_FATAL, "HMAC error");
+                            goto error;
+                        }
+                    }
                     /*
                      * Decrypt the rest of the packet (lenfield_blocksize bytes
                      * already have been decrypted)
@@ -1174,8 +1203,8 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
                         rc = ssh_packet_decrypt(session,
                                                 cleartext_packet,
                                                 (uint8_t *)data,
-                                                lenfield_blocksize,
-                                                processed - lenfield_blocksize);
+                                                lenfield_blocksize + etm_packet_offset,
+                                                processed - (lenfield_blocksize + etm_packet_offset));
                         if (rc < 0) {
                             ssh_set_error(session,
                                           SSH_FATAL,
@@ -1183,16 +1212,17 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
                             goto error;
                         }
                     }
-                    mac = packet_second_block + packet_remaining;
 
-                    rc = ssh_packet_hmac_verify(session,
-                                                ssh_buffer_get(session->in_buffer),
-                                                ssh_buffer_get_len(session->in_buffer),
-                                                mac,
-                                                crypto->in_hmac);
-                    if (rc < 0) {
-                        ssh_set_error(session, SSH_FATAL, "HMAC error");
-                        goto error;
+                    if (!etm) {
+                        rc = ssh_packet_hmac_verify(session,
+                                                    ssh_buffer_get(session->in_buffer),
+                                                    ssh_buffer_get_len(session->in_buffer),
+                                                    mac,
+                                                    crypto->in_hmac);
+                        if (rc < 0) {
+                            ssh_set_error(session, SSH_FATAL, "HMAC error");
+                            goto error;
+                        }
                     }
                     processed += current_macsize;
                 } else {
@@ -1212,8 +1242,10 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
             }
 #endif
 
-            /* skip the size field which has been processed before */
-            ssh_buffer_pass_bytes(session->in_buffer, sizeof(uint32_t));
+            if (!etm) {
+                /* skip the size field which has been processed before */
+                ssh_buffer_pass_bytes(session->in_buffer, sizeof(uint32_t));
+            }
 
             rc = ssh_buffer_get_u8(session->in_buffer, &padding);
             if (rc == 0) {
@@ -1525,12 +1557,15 @@ static int packet_send2(ssh_session session)
     uint8_t header[5] = {0};
     uint8_t type, *payload;
     int rc = SSH_ERROR;
+    int etm = 0;
+    int etm_packet_offset = 0;
 
     crypto = ssh_packet_get_current_crypto(session, SSH_DIRECTION_OUT);
     if (crypto) {
         blocksize = crypto->out_cipher->blocksize;
         lenfield_blocksize = crypto->out_cipher->lenfield_blocksize;
         hmac_type = crypto->out_hmac;
+        etm = crypto->out_hmac_etm;
     } else {
         hmac_type = session->next_crypto->out_hmac;
     }
@@ -1539,6 +1574,11 @@ static int packet_send2(ssh_session session)
     type = payload[0]; /* type is the first byte of the packet now */
 
     payloadsize = currentlen;
+    if (etm) {
+        etm_packet_offset = sizeof(uint32_t);
+        lenfield_blocksize = 0;
+    }
+
 #ifdef WITH_ZLIB
     if (crypto != NULL && crypto->do_compress_out &&
         ssh_buffer_get_len(session->out_buffer) > 0) {
@@ -1552,7 +1592,7 @@ static int packet_send2(ssh_session session)
     compsize = currentlen;
     /* compressed payload + packet len (4) + padding_size len (1) */
     /* totallen - lenfield_blocksize must be equal to 0 (mod blocksize) */
-    padding_size = (blocksize - ((blocksize - lenfield_blocksize + currentlen + 5) % blocksize));
+    padding_size = (blocksize - ((blocksize - lenfield_blocksize - etm_packet_offset + currentlen + 5) % blocksize));
     if (padding_size < 4) {
         padding_size += blocksize;
     }
@@ -1567,7 +1607,7 @@ static int packet_send2(ssh_session session)
         }
     }
 
-    finallen = currentlen + padding_size + 1;
+    finallen = currentlen - etm_packet_offset + padding_size + 1;
 
     PUSH_BE_U32(header, 0, finallen);
     PUSH_BE_U8(header, 4, padding_size);
diff --git a/src/packet_crypt.c b/src/packet_crypt.c
index d146c634..205fde5f 100644
--- a/src/packet_crypt.c
+++ b/src/packet_crypt.c
@@ -42,6 +42,7 @@
 #include "libssh/wrapper.h"
 #include "libssh/crypto.h"
 #include "libssh/buffer.h"
+#include "libssh/bytearray.h"
 
 /** @internal
  * @brief decrypt the packet length from a raw encrypted packet, and store the first decrypted
@@ -132,9 +133,11 @@ unsigned char *ssh_packet_encrypt(ssh_session session, void *data, uint32_t len)
   struct ssh_cipher_struct *cipher = NULL;
   HMACCTX ctx = NULL;
   char *out = NULL;
+  int etm_packet_offset = 0;
   unsigned int finallen, blocksize;
   uint32_t seq, lenfield_blocksize;
   enum ssh_hmac_e type;
+  int etm;
 
   assert(len);
 
@@ -145,7 +148,15 @@ unsigned char *ssh_packet_encrypt(ssh_session session, void *data, uint32_t len)
 
   blocksize = crypto->out_cipher->blocksize;
   lenfield_blocksize = crypto->out_cipher->lenfield_blocksize;
-  if ((len - lenfield_blocksize) % blocksize != 0) {
+
+  type = crypto->out_hmac;
+  etm = crypto->out_hmac_etm;
+
+  if (etm) {
+      etm_packet_offset = sizeof(uint32_t);
+  }
+
+  if ((len - lenfield_blocksize - etm_packet_offset) % blocksize != 0) {
       ssh_set_error(session, SSH_FATAL, "Cryptographic functions must be set"
                     " on at least one blocksize (received %d)", len);
       return NULL;
@@ -155,23 +166,35 @@ unsigned char *ssh_packet_encrypt(ssh_session session, void *data, uint32_t len)
     return NULL;
   }
 
-  type = crypto->out_hmac;
   seq = ntohl(session->send_seq);
   cipher = crypto->out_cipher;
 
   if (cipher->aead_encrypt != NULL) {
       cipher->aead_encrypt(cipher, data, out, len,
             crypto->hmacbuf, session->send_seq);
+      memcpy(data, out, len);
   } else {
       ctx = hmac_init(crypto->encryptMAC, hmac_digest_len(type), type);
       if (ctx == NULL) {
         SAFE_FREE(out);
         return NULL;
       }
-      hmac_update(ctx,(unsigned char *)&seq,sizeof(uint32_t));
-      hmac_update(ctx,data,len);
-      hmac_final(ctx, crypto->hmacbuf, &finallen);
 
+      if (!etm) {
+          hmac_update(ctx, (unsigned char *)&seq, sizeof(uint32_t));
+          hmac_update(ctx, data, len);
+          hmac_final(ctx, crypto->hmacbuf, &finallen);
+      }
+
+      cipher->encrypt(cipher, (uint8_t*)data + etm_packet_offset, out, len - etm_packet_offset);
+      memcpy((uint8_t*)data + etm_packet_offset, out, len - etm_packet_offset);
+
+      if (etm) {
+          PUSH_BE_U32(data, 0, len - etm_packet_offset);
+          hmac_update(ctx, (unsigned char *)&seq, sizeof(uint32_t));
+          hmac_update(ctx, data, len);
+          hmac_final(ctx, crypto->hmacbuf, &finallen);
+      }
 #ifdef DEBUG_CRYPTO
       ssh_print_hexa("mac: ",data,hmac_digest_len(type));
       if (finallen != hmac_digest_len(type)) {
@@ -179,9 +202,7 @@ unsigned char *ssh_packet_encrypt(ssh_session session, void *data, uint32_t len)
       }
       ssh_print_hexa("Packet hmac", crypto->hmacbuf, hmac_digest_len(type));
 #endif
-      cipher->encrypt(cipher, data, out, len);
   }
-  memcpy(data, out, len);
   explicit_bzero(out, len);
   SAFE_FREE(out);
 
diff --git a/src/wrapper.c b/src/wrapper.c
index f068e5b9..8563ecfd 100644
--- a/src/wrapper.c
+++ b/src/wrapper.c
@@ -56,14 +56,19 @@
 #include "libssh/curve25519.h"
 
 static struct ssh_hmac_struct ssh_hmac_tab[] = {
-  { "hmac-sha1",     SSH_HMAC_SHA1,          0 },
-  { "hmac-sha2-256", SSH_HMAC_SHA256,        0 },
-  { "hmac-sha2-384", SSH_HMAC_SHA384,        0 },
-  { "hmac-sha2-512", SSH_HMAC_SHA512,        0 },
-  { "hmac-md5",      SSH_HMAC_MD5,           0 },
-  { "aead-poly1305", SSH_HMAC_AEAD_POLY1305, 0 },
-  { "aead-gcm",      SSH_HMAC_AEAD_GCM,      0 },
-  { NULL,            0,                      0 }
+  { "hmac-sha1",                     SSH_HMAC_SHA1,          0 },
+  { "hmac-sha2-256",                 SSH_HMAC_SHA256,        0 },
+  { "hmac-sha2-384",                 SSH_HMAC_SHA384,        0 },
+  { "hmac-sha2-512",                 SSH_HMAC_SHA512,        0 },
+  { "hmac-md5",                      SSH_HMAC_MD5,           0 },
+  { "aead-poly1305",                 SSH_HMAC_AEAD_POLY1305, 0 },
+  { "aead-gcm",                      SSH_HMAC_AEAD_GCM,      0 },
+  { "hmac-sha1-etm@openssh.com",     SSH_HMAC_SHA1,          1 },
+  { "hmac-sha2-256-etm@openssh.com", SSH_HMAC_SHA256,        1 },
+  { "hmac-sha2-384-etm@openssh.com", SSH_HMAC_SHA384,        1 },
+  { "hmac-sha2-512-etm@openssh.com", SSH_HMAC_SHA512,        1 },
+  { "hmac-md5-etm@openssh.com",      SSH_HMAC_MD5,           1 },
+  { NULL,                            0,                      0 }
 };
 
 struct ssh_hmac_struct *ssh_get_hmactab(void) {
-- 
2.11.0

