Greetings. I've been trying to write a patch for Squid that reads the
TLS ClientHello on CONNECT requests and records its SNI without using
ssl_bump, which requires generating certificates and is too
complicated for my needs. Here's the patch I've come up with. It seems
to work, but under load I end up with a lot of connections stuck in
the CLOSE_WAIT state. I can't reproduce this locally, so I suspect I'm
missing something or did something wrong. Could someone review this
patch, please? Also, I'm not sure this is the right place to post it.
The patch applies to 4.15, the latest release in the 4.x series.

-- 
HisShadow
diff --git a/src/SquidConfig.h b/src/SquidConfig.h
index b696ffc..e5fbc2d 100644
--- a/src/SquidConfig.h
+++ b/src/SquidConfig.h
@@ -365,6 +365,7 @@ public:
         acl_access *sendHit;
         acl_access *storeMiss;
         acl_access *stats_collection;
+        acl_access *banned_domains;
 #if SQUID_SNMP
 
         acl_access *snmp;
diff --git a/src/cf.data.pre b/src/cf.data.pre
index 4aef432..3250545 100644
--- a/src/cf.data.pre
+++ b/src/cf.data.pre
@@ -10157,4 +10157,13 @@ DOC_START
 		server_pconn_for_nonretriable allow SpeedIsWorthTheRisk
 DOC_END
 
+NAME: banned_domains
+TYPE: acl_access
+DEFAULT: none
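+DEFAULT_DOC: No check is done; all tunnels are allowed.
+LOC: Config.accessList.banned_domains
+DOC_START
+	Checked when a CONNECT tunnel to port 443 presents a TLS SNI; "deny" makes Squid send a TLS alert to the client and close the tunnel.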
+DOC_END
+
 EOF
diff --git a/src/client_side.h b/src/client_side.h
index 9fe8463..a1b861e 100644
--- a/src/client_side.h
+++ b/src/client_side.h
@@ -120,6 +120,8 @@ public:
      */
     void setAuth(const Auth::UserRequest::Pointer &aur, const char *cause);
 #endif
+    /// TLS client delivered SNI value. Empty string if none has been received.
+    SBuf tlsClientSni_;
 
     Ip::Address log_addr;
 
@@ -413,8 +415,6 @@ private:
     unsigned short tlsConnectPort; ///< The TLS server port number as passed in the CONNECT request
     SBuf sslCommonName_; ///< CN name for SSL certificate generation
 
-    /// TLS client delivered SNI value. Empty string if none has been received.
-    SBuf tlsClientSni_;
     SBuf sslBumpCertKey; ///< Key to use to store/retrieve generated certificate
 
     /// HTTPS server cert. fetching state for bump-ssl-server-first
diff --git a/src/tunnel.cc b/src/tunnel.cc
index 217e947..6a015ca 100644
--- a/src/tunnel.cc
+++ b/src/tunnel.cc
@@ -79,6 +79,7 @@ public:
     static void ReadServer(const Comm::ConnectionPointer &, char *buf, size_t len, Comm::Flag errcode, int xerrno, void *data);
     static void WriteClientDone(const Comm::ConnectionPointer &, char *buf, size_t len, Comm::Flag flag, int xerrno, void *data);
     static void WriteServerDone(const Comm::ConnectionPointer &, char *buf, size_t len, Comm::Flag flag, int xerrno, void *data);
+    static void CloseConnections(const Comm::ConnectionPointer &, char *buf, size_t len, Comm::Flag flag, int xerrno, void *data);
 
     /// Starts reading peer response to our CONNECT request.
     void readConnectResponse();
@@ -177,6 +178,10 @@ public:
     SBuf preReadServerData;
     time_t startTime; ///< object creation time, before any peer selection/connection attempts
 
+    SBuf tlsData;
+    size_t tlsBodySize, tlsAlreadyRead, tlsHeaderLeftToRead;
+    bool tlsFirstByteChecked;
+
     void copyRead(Connection &from, IOCB *completion);
 
     /// continue to set up connection to a peer, going async for SSL peers
@@ -224,6 +229,7 @@ public:
     void readConnectResponseDone(char *buf, size_t len, Comm::Flag errcode, int xerrno);
     void copyClientBytes();
     void copyServerBytes();
+    void copyAlert();
 };
 
 static const char *const conn_established = "HTTP/1.1 200 Connection established\r\n\r\n";
@@ -872,11 +878,13 @@ static void
 tunnelStartShoveling(TunnelStateData *tunnelState)
 {
     assert(!tunnelState->waitingForConnectExchange());
+    if (!tunnelState->tlsData.isEmpty()) {
+        tunnelState->tlsData.consume();
+    }
     *tunnelState->status_ptr = Http::scOkay;
     if (tunnelState->logTag_ptr)
         *tunnelState->logTag_ptr = LOG_TCP_TUNNEL;
     if (cbdataReferenceValid(tunnelState)) {
-
         // Shovel any payload already pushed into reply buffer by the server response
         if (!tunnelState->server.len)
             tunnelState->copyServerBytes();
@@ -895,6 +903,248 @@ tunnelStartShoveling(TunnelStateData *tunnelState)
     }
 }
 
+static bool isSNICompatible(SBuf &header) { // Decides whether the buffered record header can start an SNI-capable ClientHello; indexes bytes 0-2 unchecked, callers pass the full 5-byte record header.
+    if (((uint8_t)header[0] & 0x80) && header[2] == 1) { // looks like an SSLv2 CLIENT-HELLO (no extensions). NOTE(review): callers only get here after seeing 0x16 in byte 0, so this branch appears unreachable
+        return false;
+    }
+
+    if (header[1] < 3) { // record major version below 3 (pre-SSLv3): no extensions, hence no SNI
+        return false;
+    }
+    return true;
+}
+
+void
+TunnelStateData::CloseConnections(const Comm::ConnectionPointer &, char *, size_t, Comm::Flag, int, void *data) { // Write-completion callback used by copyAlert(): closes both tunnel sides whatever the write outcome (flag is ignored).
+    TunnelStateData *tunnelState = (TunnelStateData *)data;
+    CbcPointer<TunnelStateData> safetyLock(tunnelState); // holds a cbdata reference to tunnelState across the close() calls below
+
+    if (Comm::IsConnOpen(tunnelState->client.conn))
+        tunnelState->client.conn->close();
+
+    if (Comm::IsConnOpen(tunnelState->server.conn))
+        tunnelState->server.conn->close();
+}
+
+// Sent to the client when banned_domains denies the tunnel: a TLS 1.0 alert record (type 0x15, length 2) with level fatal (2) and description handshake_failure (40).
+static const unsigned char alert[] = {0x15, 0x03, 0x01, 0x00, 0x02, 0x02, 0x28};
+
+void
+TunnelStateData::copyAlert() { // Queues the canned TLS alert to the client through server.buf; CloseConnections() tears the tunnel down once the write completes.
+    const size_t copyBytes = sizeof(alert); // derive the length from the array instead of a hard-coded 7 that can drift from it
+    memcpy(server.buf, alert, copyBytes); // overwrites server.buf; assumes no server read into it is in flight (none is started before sniffing ends)
+    server.bytesIn(copyBytes);
+    if (keepGoingAfterRead(copyBytes, Comm::OK, 0, server, client))
+        copy(copyBytes, server, client, TunnelStateData::CloseConnections);
+}
+
+static void doneACLCheck(allow_t answer, void *data) { // nonBlockingCheck() callback for banned_domains: "allowed" replays the sniffed ClientHello and starts the tunnel; otherwise a TLS alert is sent and both sides are closed.
+    TunnelStateData *tunnelState = (TunnelStateData *)data;
+    if (answer.allowed()) {
+        debugs(26, 3, "banned_domains returned allowed");
+        tunnelState->preReadClientData.append(tunnelState->tlsData); // forward the bytes consumed while sniffing
+        tunnelStartShoveling(tunnelState);
+    } else {
+        debugs(26, 3, "banned_domains returned denied");
+        tunnelState->copyAlert();
+    }
+}
+
+static void readClientHelloDone(TunnelStateData *tunnelState) { // Parses the sniffed TLS record, stores any SNI on the client's ConnStateData, then runs banned_domains or starts the tunnel.
+    Security::HandshakeParser parser(Security::HandshakeParser::fromClient);
+    SBuf finalTlsData;
+    if (tunnelState->tlsBodySize + 5 < tunnelState->tlsData.length()) { // bytes beyond the first record are buffered: parse only that record (5-byte header + body)
+        finalTlsData = tunnelState->tlsData.substr(0, tunnelState->tlsBodySize + 5);
+    } else {
+        finalTlsData = tunnelState->tlsData;
+    }
+
+    debugs(26, 3, "Parsing tls data");
+    bool unsupportedProtocol = false;
+    try {
+        parser.parseHello(finalTlsData); // NOTE(review): a ClientHello split across several records is cut off here, so this presumably throws and SNI is skipped
+        debugs(26, 3, "Parse successful");
+    } catch(const std::exception& e) { // a parse failure only disables SNI handling; the tunnel still starts below
+        unsupportedProtocol = true;
+        debugs(26, 3, "Failed to parse TLS: " << e.what());
+    }
+
+    if (!unsupportedProtocol && !parser.details->serverName.isEmpty()) {
+        debugs(26, 3, "Found server name " << parser.details->serverName);
+        if (tunnelState->al && tunnelState->al->request && tunnelState->al->request->clientConnectionManager.valid()) {
+            tunnelState->al->request->clientConnectionManager->tlsClientSni_ = parser.details->serverName; // requires tlsClientSni_ to be public (client_side.h hunk)
+        }
+        if (Config.accessList.banned_domains && tunnelState->http.valid()) {
+            ACLFilledChecklist *ch = clientAclChecklistCreate(Config.accessList.banned_domains, tunnelState->http.get());
+            ch->nonBlockingCheck(doneACLCheck, (void *)tunnelState); // asynchronous: doneACLCheck() continues or aborts the tunnel
+        } else {
+            tunnelState->preReadClientData.append(tunnelState->tlsData); // replay every sniffed byte to the server
+            tunnelStartShoveling(tunnelState);
+        }
+    } else {
+        debugs(26, 3, "No SNI found, proceeding without it");
+        tunnelState->preReadClientData.append(tunnelState->tlsData);
+        tunnelStartShoveling(tunnelState);
+    }
+}
+
+static void readClientHello(const Comm::ConnectionPointer &conn, char *, size_t len, Comm::Flag flag, int, void *data) { // Comm::Read callback: appends the remaining body bytes of the first TLS record to tlsData, re-arming itself until tlsBodySize bytes are buffered.
+    TunnelStateData *tunnelState = (TunnelStateData *)data;
+
+    if (flag != Comm::OK) {
+        *tunnelState->status_ptr = Http::scInternalServerError;
+        tunnelErrorComplete(conn->fd, data, 0);
+        return;
+    }
+
+    size_t toRead = tunnelState->tlsBodySize - tunnelState->tlsAlreadyRead;
+    debugs(26, 3, "Reading tls body, toRead=" << toRead << ", tlsAlreadyRead=" << tunnelState->tlsAlreadyRead);
+    CommIoCbParams rd(tunnelState);
+    rd.conn = tunnelState->client.conn;
+    rd.size = toRead; // meant to cap this read at the missing body bytes
+    SBuf body;
+    body.reserveSpace(toRead);
+
+    switch(Comm::ReadNow(rd, body)) {
+    case Comm::ENDFILE:
+    case Comm::COMM_ERROR: // client EOF or read error: abort via tunnelErrorComplete()
+        *tunnelState->status_ptr = Http::scInternalServerError;
+        tunnelErrorComplete(conn->fd, data, 0);
+        return;
+    default: // any other result continues below; body may still be empty
+        break;
+    }
+
+    if (!body.isEmpty()) {
+        tunnelState->tlsAlreadyRead += body.length();
+        tunnelState->tlsData.append(body);
+    }
+    debugs(26, 3, "Read " << rd.size << " bytes"); // NOTE(review): body.length() is the reliable byte count; rd.size may still be the requested size when nothing was read
+
+    if (tunnelState->tlsBodySize <= tunnelState->tlsAlreadyRead) {
+        debugs(26, 3, "tls body is done reading");
+        readClientHelloDone(tunnelState);
+        return;
+    }
+    AsyncCall::Pointer call = commCbCall(5,5, "readClientHello", // body incomplete: wait for more client bytes
+                                        CommIoCbPtrFun(readClientHello, tunnelState));
+    Comm::Read(tunnelState->client.conn, call);
+}
+
+static void readClientHelloHeader(const Comm::ConnectionPointer &conn, char *, size_t len, Comm::Flag flag, int, void *data) { // Comm::Read callback: collects the 5-byte TLS record header into tlsData, falls back to a blind tunnel if byte 0 is not 0x16 (handshake), then schedules readClientHello() for the body.
+    TunnelStateData *tunnelState = (TunnelStateData *)data;
+
+    if (flag != Comm::OK) {
+        *tunnelState->status_ptr = Http::scInternalServerError;
+        tunnelErrorComplete(conn->fd, data, 0);
+        return;
+    }
+
+    CommIoCbParams rd(tunnelState);
+    rd.conn = tunnelState->client.conn;
+    rd.size = tunnelState->tlsHeaderLeftToRead;
+    debugs(26, 3, "Reading tls header, left to read is " << rd.size);
+
+    SBuf header;
+    header.reserveSpace(rd.size);
+    switch(Comm::ReadNow(rd, header)) {
+    case Comm::ENDFILE:
+    case Comm::COMM_ERROR: // client EOF or read error: abort via tunnelErrorComplete()
+        *tunnelState->status_ptr = Http::scInternalServerError;
+        tunnelErrorComplete(conn->fd, data, 0);
+        return;
+    default: // any other result continues below; header may still be empty
+        break;
+    }
+
+    debugs(26, 3, "Read " << header.length() << " bytes"); // report what actually arrived, not the requested rd.size
+    if (!header.isEmpty()) {
+        tunnelState->tlsData.append(header);
+        tunnelState->tlsHeaderLeftToRead -= header.length(); // header is a fresh buffer, so its length is exactly this read
+    }
+
+    if (!tunnelState->tlsFirstByteChecked && !tunnelState->tlsData.isEmpty() && tunnelState->tlsData[0] != 0x16) { // test the buffer we index: rd.size stays non-zero even when nothing was appended and tlsData is still empty
+        debugs(26, 3, "First byte isn't 0x16, initiating blind pump");
+        tunnelState->tlsFirstByteChecked = true;
+        tunnelState->preReadClientData.append(tunnelState->tlsData);
+        tunnelStartShoveling(tunnelState);
+        return;
+    }
+
+    if (tunnelState->tlsHeaderLeftToRead == 0) { // full 5-byte record header buffered
+        if (!isSNICompatible(tunnelState->tlsData)) {
+            tunnelState->preReadClientData.append(tunnelState->tlsData);
+            tunnelStartShoveling(tunnelState);
+            return;
+        }
+        uint8_t firstByte = tunnelState->tlsData[3]; // bytes 3-4: record body length, big-endian
+        uint8_t secondByte = tunnelState->tlsData[4];
+        tunnelState->tlsBodySize = ((uint32_t)firstByte << 8) | (uint32_t)secondByte;
+        tunnelState->tlsAlreadyRead = 0;
+        debugs(26, 3, "Header is done reading, tlsBodySize=" << tunnelState->tlsBodySize);
+        AsyncCall::Pointer call = commCbCall(5,5, "readClientHello",
+                                            CommIoCbPtrFun(readClientHello, tunnelState));
+        Comm::Read(tunnelState->client.conn, call);
+    } else {
+        AsyncCall::Pointer call = commCbCall(5,5, "readClientHelloHeader", // header incomplete: wait for more client bytes
+                                            CommIoCbPtrFun(readClientHelloHeader, tunnelState));
+        Comm::Read(tunnelState->client.conn, call);
+    }
+}
+
+static void // NOTE(review): while the ClientHello is sniffed and banned_domains runs, nothing in this patch reads server.conn (copyServerBytes() is only reached via tunnelStartShoveling()), so a server-side close goes unnoticed and that socket presumably sits in CLOSE_WAIT until the client side sends, closes, or times out; likely cause of the reported CLOSE_WAIT build-up. Consider starting server->client shoveling before sniffing (copyAlert() would then need its own buffer instead of server.buf).
+startClientHelloRead(TunnelStateData *tunnelState) { // Starts sniffing the client's first TLS record for port-443 CONNECTs (see tunnelConnectedWriteDone), reusing any bytes ConnStateData already buffered.
+    SBuf::size_type initialLength = 0;
+    if (tunnelState->http.valid() && tunnelState->http->getConn() && !tunnelState->http->getConn()->inBuf.isEmpty()) {
+        SBuf * const in = &tunnelState->http->getConn()->inBuf;
+        debugs(26, 3, "We have non-empty buffer from client, check first byte to equal 0x16");
+        tunnelState->tlsFirstByteChecked = true;
+        if (in->at(0) != 0x16) { // 0x16 = TLS handshake record type
+            debugs(26, 3, "First byte isn't 0x16, just shovel data");
+            tunnelStartShoveling(tunnelState);
+            return;
+        } else {
+            debugs(26, 3, "First byte is 0x16, continue reading client hello");
+            tunnelState->tlsData.append(*in);
+            initialLength = tunnelState->tlsData.length();
+            in->consume(); // the bytes now live in tlsData; drain the ConnStateData buffer
+        }
+    }
+
+    debugs(26, 3, "initialLength=" << initialLength);
+
+    if (initialLength >= 0 && initialLength < 5) { // NOTE(review): SBuf::size_type is presumably unsigned, making ">= 0" always true
+        tunnelState->tlsHeaderLeftToRead = 5 - initialLength;
+
+        debugs(26, 3, "Left to read header: " << tunnelState->tlsHeaderLeftToRead);
+        AsyncCall::Pointer call = commCbCall(5,5, "readClientHelloHeader",
+                                            CommIoCbPtrFun(readClientHelloHeader, tunnelState));
+        Comm::Read(tunnelState->client.conn, call);
+    } else if (initialLength >= 5) { // whole 5-byte record header already buffered
+        if (!isSNICompatible(tunnelState->tlsData)) {
+            tunnelState->preReadClientData.append(tunnelState->tlsData);
+            tunnelStartShoveling(tunnelState);
+            return;
+        }
+        uint8_t firstByte = tunnelState->tlsData[3]; // bytes 3-4: record body length, big-endian
+        uint8_t secondByte = tunnelState->tlsData[4];
+        uint32_t tlsBodySize = ((uint32_t)firstByte << 8) | (uint32_t)secondByte;
+        tunnelState->tlsBodySize = tlsBodySize;
+        initialLength -= 5; // body bytes already buffered
+        tunnelState->tlsAlreadyRead = initialLength;
+        debugs(26, 3, "tlsBodySize=" << tlsBodySize << ", tlsAlreadyRead=" << initialLength);
+        if (tlsBodySize <= tunnelState->tlsAlreadyRead) {
+            debugs(26, 3, "tlsBodySize <= tlsAlreadyRead, client hello is done reading");
+            readClientHelloDone(tunnelState);
+            return;
+        }
+
+        AsyncCall::Pointer call = commCbCall(5,5, "readClientHello",
+                                            CommIoCbPtrFun(readClientHello, tunnelState));
+        Comm::Read(tunnelState->client.conn, call);
+    }
+}
+
 /**
  * All the pieces we need to write to client and/or server connection
  * have been written.
@@ -918,7 +1168,14 @@ tunnelConnectedWriteDone(const Comm::ConnectionPointer &conn, char *, size_t len
         http->out.size += len;
     }
 
-    tunnelStartShoveling(tunnelState);
+    unsigned short port = tunnelState->request->url.port();
+    if (port == 443) {
+        debugs(26, 3, "Port=443, starting TLS hello read");
+        tunnelState->tlsFirstByteChecked = false;
+        startClientHelloRead(tunnelState);
+    } else {
+        tunnelStartShoveling(tunnelState);
+    }
 }
 
 /// Called when we are done writing CONNECT request to a peer.
_______________________________________________
squid-users mailing list
squid-users@lists.squid-cache.org
http://lists.squid-cache.org/listinfo/squid-users

Reply via email to