The core frame-parsing logic of mod_websocket_data_framing now lives in a new function, mod_websocket_handle_incoming. This will make it easier to turn the processing loop into a combined read/write loop.
The state variables have all been moved into a separate WebSocketReadState struct. --- mod_websocket.c | 652 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 342 insertions(+), 310 deletions(-) diff --git a/mod_websocket.c b/mod_websocket.c index 62be765..1e88ef6 100644 --- a/mod_websocket.c +++ b/mod_websocket.c @@ -476,6 +476,327 @@ typedef struct _WebSocketFrameData unsigned int utf8_state; } WebSocketFrameData; +/* Variables that need to persist across calls to mod_websocket_handle_incoming */ +typedef struct +{ + int framing_state; + unsigned short status_code; + /* XXX fin and opcode appear to be duplicated with frame; can they be removed? */ + unsigned char fin; + unsigned char opcode; + WebSocketFrameData control_frame; + WebSocketFrameData message_frame; + WebSocketFrameData *frame; + apr_int64_t payload_length; + apr_int64_t mask_offset; + apr_int64_t extension_bytes_remaining; + int payload_length_bytes_remaining; + int masking; + int mask_index; + unsigned char mask[4]; +} WebSocketReadState; + +static void mod_websocket_handle_incoming(const WebSocketServer *server, + unsigned char *block, + apr_int64_t block_size, + WebSocketReadState *state, + websocket_config_rec *conf, + void *plugin_private) +{ + apr_int64_t block_offset = 0; + + while (block_offset < block_size) { + switch (state->framing_state) { + case DATA_FRAMING_START: + /* + * Since we don't currently support any extensions, + * the reserve bits must be 0 + */ + if ((FRAME_GET_RSV1(block[block_offset]) != 0) || + (FRAME_GET_RSV2(block[block_offset]) != 0) || + (FRAME_GET_RSV3(block[block_offset]) != 0)) { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + state->fin = FRAME_GET_FIN(block[block_offset]); + state->opcode = FRAME_GET_OPCODE(block[block_offset++]); + + state->framing_state = DATA_FRAMING_PAYLOAD_LENGTH; + + if (state->opcode >= 0x8) { /* Control frame */ + if (state->fin) { + state->frame = 
&state->control_frame; + state->frame->opcode = state->opcode; + state->frame->utf8_state = UTF8_VALID; + } + else { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + } + else { /* Message frame */ + state->frame = &state->message_frame; + if (state->opcode) { + if (state->frame->fin) { + state->frame->opcode = state->opcode; + state->frame->utf8_state = UTF8_VALID; + } + else { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + } + else if (state->frame->fin || + ((state->opcode = state->frame->opcode) == 0)) { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + state->frame->fin = state->fin; + } + state->payload_length = 0; + state->payload_length_bytes_remaining = 0; + + if (block_offset >= block_size) { + break; /* Only break if we need more data */ + } + case DATA_FRAMING_PAYLOAD_LENGTH: + state->payload_length = (apr_int64_t) + FRAME_GET_PAYLOAD_LEN(block[block_offset]); + state->masking = FRAME_GET_MASK(block[block_offset++]); + + if (state->payload_length == 126) { + state->payload_length = 0; + state->payload_length_bytes_remaining = 2; + } + else if (state->payload_length == 127) { + state->payload_length = 0; + state->payload_length_bytes_remaining = 8; + } + else { + state->payload_length_bytes_remaining = 0; + } + if ((state->masking == 0) || /* Client-side mask is required */ + ((state->opcode >= 0x8) && /* Control opcodes cannot have a payload larger than 125 bytes */ + (state->payload_length_bytes_remaining != 0))) { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + else { + state->framing_state = DATA_FRAMING_PAYLOAD_LENGTH_EXT; + } + if (block_offset >= block_size) { + break; /* Only break if we need more data */ + } + case DATA_FRAMING_PAYLOAD_LENGTH_EXT: + while ((state->payload_length_bytes_remaining > 0) 
&& + (block_offset < block_size)) { + state->payload_length *= 256; + state->payload_length += block[block_offset++]; + state->payload_length_bytes_remaining--; + } + if (state->payload_length_bytes_remaining == 0) { + if ((state->payload_length < 0) || + (state->payload_length > conf->payload_limit)) { + /* Invalid payload length */ + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = (server->state->protocol_version >= 13) ? + STATUS_CODE_MESSAGE_TOO_LARGE : + STATUS_CODE_RESERVED; + break; + } + else if (state->masking != 0) { + state->framing_state = DATA_FRAMING_MASK; + } + else { + state->framing_state = DATA_FRAMING_EXTENSION_DATA; + break; + } + } + if (block_offset >= block_size) { + break; /* Only break if we need more data */ + } + case DATA_FRAMING_MASK: + while ((state->mask_index < 4) && (block_offset < block_size)) { + state->mask[state->mask_index++] = block[block_offset++]; + } + if (state->mask_index == 4) { + state->framing_state = DATA_FRAMING_EXTENSION_DATA; + state->mask_offset = 0; + state->mask_index = 0; + if ((state->mask[0] == 0) && (state->mask[1] == 0) && + (state->mask[2] == 0) && (state->mask[3] == 0)) { + state->masking = 0; + } + } + else { + break; + } + /* Fall through */ + case DATA_FRAMING_EXTENSION_DATA: + /* Deal with extension data when we support them -- FIXME */ + if (state->extension_bytes_remaining == 0) { + if (state->payload_length > 0) { + state->frame->application_data = (unsigned char *) + realloc(state->frame->application_data, + state->frame->application_data_offset + + state->payload_length); + if (state->frame->application_data == NULL) { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = (server->state->protocol_version >= 13) ? 
+ STATUS_CODE_INTERNAL_ERROR : + STATUS_CODE_GOING_AWAY; + break; + } + } + state->framing_state = DATA_FRAMING_APPLICATION_DATA; + } + /* Fall through */ + case DATA_FRAMING_APPLICATION_DATA: + { + apr_int64_t block_data_length; + apr_int64_t block_length = 0; + apr_uint64_t application_data_offset = + state->frame->application_data_offset; + unsigned char *application_data = + state->frame->application_data; + + block_length = block_size - block_offset; + block_data_length = + (state->payload_length > + block_length) ? block_length : state->payload_length; + + if (state->masking) { + apr_int64_t i; + + if (state->opcode == OPCODE_TEXT) { + unsigned int utf8_state = state->frame->utf8_state; + unsigned char c; + + for (i = 0; i < block_data_length; i++) { + c = block[block_offset++] ^ + state->mask[state->mask_offset++ & 3]; + utf8_state = + validate_utf8[utf8_state + c]; + if (utf8_state == UTF8_INVALID) { + state->payload_length = block_data_length; + break; + } + application_data + [application_data_offset++] = c; + } + state->frame->utf8_state = utf8_state; + } + else { + /* Need to optimize the unmasking -- FIXME */ + for (i = 0; i < block_data_length; i++) { + application_data + [application_data_offset++] = + block[block_offset++] ^ + state->mask[state->mask_offset++ & 3]; + } + } + } + else if (block_data_length > 0) { + memcpy(&application_data[application_data_offset], + &block[block_offset], block_data_length); + if (state->opcode == OPCODE_TEXT) { + apr_int64_t i, application_data_end = + application_data_offset + + block_data_length; + unsigned int utf8_state = state->frame->utf8_state; + + for (i = application_data_offset; + i < application_data_end; i++) { + utf8_state = + validate_utf8[utf8_state + + application_data[i]]; + if (utf8_state == UTF8_INVALID) { + state->payload_length = block_data_length; + break; + } + } + state->frame->utf8_state = utf8_state; + } + application_data_offset += block_data_length; + block_offset += block_data_length; + 
} + state->payload_length -= block_data_length; + + if (state->payload_length == 0) { + int message_type = MESSAGE_TYPE_INVALID; + + switch (state->opcode) { + case OPCODE_TEXT: + if ((state->fin && + (state->frame->utf8_state != UTF8_VALID)) || + (state->frame->utf8_state == UTF8_INVALID)) { + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_INVALID_UTF8; + } + else { + message_type = MESSAGE_TYPE_TEXT; + } + break; + case OPCODE_BINARY: + message_type = MESSAGE_TYPE_BINARY; + break; + case OPCODE_CLOSE: + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_OK; + break; + case OPCODE_PING: + apr_thread_mutex_lock(server->state->mutex); + mod_websocket_send_internal(server->state, + MESSAGE_TYPE_PONG, + application_data, + application_data_offset); + apr_thread_mutex_unlock(server->state->mutex); + break; + case OPCODE_PONG: + break; + default: + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + if (state->fin && (message_type != MESSAGE_TYPE_INVALID)) { + conf->plugin->on_message(plugin_private, + server, message_type, + application_data, + application_data_offset); + } + if (state->framing_state != DATA_FRAMING_CLOSE) { + state->framing_state = DATA_FRAMING_START; + + if (state->fin) { + if (state->frame->application_data != NULL) { + free(state->frame->application_data); + state->frame->application_data = NULL; + } + application_data_offset = 0; + } + } + } + state->frame->application_data_offset = + application_data_offset; + } + break; + case DATA_FRAMING_CLOSE: + block_offset = block_size; + break; + default: + state->framing_state = DATA_FRAMING_CLOSE; + state->status_code = STATUS_CODE_PROTOCOL_ERROR; + break; + } + } +} + /* * The data framing handler requires that the server state mutex is locked by * the caller upon entering this function. It will be locked when leaving too. 
@@ -509,19 +830,19 @@ static void mod_websocket_data_framing(const WebSocketServer *server, (apr_pollset_create(&pollset, 1, pool, 0) == APR_SUCCESS)) { unsigned char block[BLOCK_DATA_SIZE]; apr_int64_t block_size; - apr_int64_t extension_bytes_remaining = 0; - apr_int64_t payload_length = 0; - apr_int64_t mask_offset = 0; - int framing_state = DATA_FRAMING_START; - int payload_length_bytes_remaining = 0; - int mask_index = 0, masking = 0; - unsigned char mask[4] = { 0, 0, 0, 0 }; - unsigned char fin = 0, opcode = 0xFF; - WebSocketFrameData control_frame = { 0, NULL, 1, 8, UTF8_VALID }; - WebSocketFrameData message_frame = { 0, NULL, 1, 0, UTF8_VALID }; - WebSocketFrameData *frame = &control_frame; - unsigned short status_code = STATUS_CODE_OK; unsigned char status_code_buffer[2]; + WebSocketReadState read_state = { 0 }; + + read_state.framing_state = DATA_FRAMING_START; + read_state.status_code = STATUS_CODE_OK; + read_state.control_frame.fin = 1; + read_state.control_frame.opcode = 8; + read_state.control_frame.utf8_state = UTF8_VALID; + read_state.message_frame.fin = 1; + read_state.message_frame.opcode = 0; + read_state.message_frame.utf8_state = UTF8_VALID; + read_state.frame = &read_state.control_frame; + read_state.opcode = 0xFF; /* Initialize the pollset */ pollfd.p = pool; @@ -534,8 +855,7 @@ static void mod_websocket_data_framing(const WebSocketServer *server, state->obb = obb; apr_thread_mutex_unlock(state->mutex); - while ((framing_state != DATA_FRAMING_CLOSE)) { - apr_int64_t block_offset = 0; + while ((read_state.framing_state != DATA_FRAMING_CLOSE)) { apr_status_t rv; do { @@ -548,307 +868,19 @@ static void mod_websocket_data_framing(const WebSocketServer *server, break; } - while (block_offset < block_size) { - switch (framing_state) { - case DATA_FRAMING_START: - /* - * Since we don't currently support any extensions, - * the reserve bits must be 0 - */ - if ((FRAME_GET_RSV1(block[block_offset]) != 0) || - (FRAME_GET_RSV2(block[block_offset]) != 0) 
|| - (FRAME_GET_RSV3(block[block_offset]) != 0)) { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; - } - fin = FRAME_GET_FIN(block[block_offset]); - opcode = FRAME_GET_OPCODE(block[block_offset++]); - - framing_state = DATA_FRAMING_PAYLOAD_LENGTH; - - if (opcode >= 0x8) { /* Control frame */ - if (fin) { - frame = &control_frame; - frame->opcode = opcode; - frame->utf8_state = UTF8_VALID; - } - else { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; - } - } - else { /* Message frame */ - frame = &message_frame; - if (opcode) { - if (frame->fin) { - frame->opcode = opcode; - frame->utf8_state = UTF8_VALID; - } - else { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; - } - } - else if (frame->fin || - ((opcode = frame->opcode) == 0)) { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; - } - frame->fin = fin; - } - payload_length = 0; - payload_length_bytes_remaining = 0; - - if (block_offset >= block_size) { - break; /* Only break if we need more data */ - } - case DATA_FRAMING_PAYLOAD_LENGTH: - payload_length = (apr_int64_t) - FRAME_GET_PAYLOAD_LEN(block[block_offset]); - masking = FRAME_GET_MASK(block[block_offset++]); - - if (payload_length == 126) { - payload_length = 0; - payload_length_bytes_remaining = 2; - } - else if (payload_length == 127) { - payload_length = 0; - payload_length_bytes_remaining = 8; - } - else { - payload_length_bytes_remaining = 0; - } - if ((masking == 0) || /* Client-side mask is required */ - ((opcode >= 0x8) && /* Control opcodes cannot have a payload larger than 125 bytes */ - (payload_length_bytes_remaining != 0))) { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; - } - else { - framing_state = DATA_FRAMING_PAYLOAD_LENGTH_EXT; - } - if (block_offset >= block_size) { - break; /* Only break if we need more data */ - } - case 
DATA_FRAMING_PAYLOAD_LENGTH_EXT: - while ((payload_length_bytes_remaining > 0) && - (block_offset < block_size)) { - payload_length *= 256; - payload_length += block[block_offset++]; - payload_length_bytes_remaining--; - } - if (payload_length_bytes_remaining == 0) { - if ((payload_length < 0) || - (payload_length > conf->payload_limit)) { - /* Invalid payload length */ - framing_state = DATA_FRAMING_CLOSE; - status_code = (state->protocol_version >= 13) ? - STATUS_CODE_MESSAGE_TOO_LARGE : - STATUS_CODE_RESERVED; - break; - } - else if (masking != 0) { - framing_state = DATA_FRAMING_MASK; - } - else { - framing_state = DATA_FRAMING_EXTENSION_DATA; - break; - } - } - if (block_offset >= block_size) { - break; /* Only break if we need more data */ - } - case DATA_FRAMING_MASK: - while ((mask_index < 4) && (block_offset < block_size)) { - mask[mask_index++] = block[block_offset++]; - } - if (mask_index == 4) { - framing_state = DATA_FRAMING_EXTENSION_DATA; - mask_offset = 0; - mask_index = 0; - if ((mask[0] == 0) && (mask[1] == 0) && - (mask[2] == 0) && (mask[3] == 0)) { - masking = 0; - } - } - else { - break; - } - /* Fall through */ - case DATA_FRAMING_EXTENSION_DATA: - /* Deal with extension data when we support them -- FIXME */ - if (extension_bytes_remaining == 0) { - if (payload_length > 0) { - frame->application_data = (unsigned char *) - realloc(frame->application_data, - frame->application_data_offset + - payload_length); - if (frame->application_data == NULL) { - framing_state = DATA_FRAMING_CLOSE; - status_code = (state->protocol_version >= 13) ? 
- STATUS_CODE_INTERNAL_ERROR : - STATUS_CODE_GOING_AWAY; - break; - } - } - framing_state = DATA_FRAMING_APPLICATION_DATA; - } - /* Fall through */ - case DATA_FRAMING_APPLICATION_DATA: - { - apr_int64_t block_data_length; - apr_int64_t block_length = 0; - apr_uint64_t application_data_offset = - frame->application_data_offset; - unsigned char *application_data = - frame->application_data; - - block_length = block_size - block_offset; - block_data_length = - (payload_length > - block_length) ? block_length : payload_length; - - if (masking) { - apr_int64_t i; - - if (opcode == OPCODE_TEXT) { - unsigned int utf8_state = frame->utf8_state; - unsigned char c; - - for (i = 0; i < block_data_length; i++) { - c = block[block_offset++] ^ - mask[mask_offset++ & 3]; - utf8_state = - validate_utf8[utf8_state + c]; - if (utf8_state == UTF8_INVALID) { - payload_length = block_data_length; - break; - } - application_data - [application_data_offset++] = c; - } - frame->utf8_state = utf8_state; - } - else { - /* Need to optimize the unmasking -- FIXME */ - for (i = 0; i < block_data_length; i++) { - application_data - [application_data_offset++] = - block[block_offset++] ^ - mask[mask_offset++ & 3]; - } - } - } - else if (block_data_length > 0) { - memcpy(&application_data[application_data_offset], - &block[block_offset], block_data_length); - if (opcode == OPCODE_TEXT) { - apr_int64_t i, application_data_end = - application_data_offset + - block_data_length; - unsigned int utf8_state = frame->utf8_state; - - for (i = application_data_offset; - i < application_data_end; i++) { - utf8_state = - validate_utf8[utf8_state + - application_data[i]]; - if (utf8_state == UTF8_INVALID) { - payload_length = block_data_length; - break; - } - } - frame->utf8_state = utf8_state; - } - application_data_offset += block_data_length; - block_offset += block_data_length; - } - payload_length -= block_data_length; - - if (payload_length == 0) { - int message_type = MESSAGE_TYPE_INVALID; - - switch 
(opcode) { - case OPCODE_TEXT: - if ((fin && - (frame->utf8_state != UTF8_VALID)) || - (frame->utf8_state == UTF8_INVALID)) { - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_INVALID_UTF8; - } - else { - message_type = MESSAGE_TYPE_TEXT; - } - break; - case OPCODE_BINARY: - message_type = MESSAGE_TYPE_BINARY; - break; - case OPCODE_CLOSE: - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_OK; - break; - case OPCODE_PING: - apr_thread_mutex_lock(state->mutex); - mod_websocket_send_internal(state, - MESSAGE_TYPE_PONG, - application_data, - application_data_offset); - apr_thread_mutex_unlock(state->mutex); - break; - case OPCODE_PONG: - break; - default: - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; - } - if (fin && (message_type != MESSAGE_TYPE_INVALID)) { - conf->plugin->on_message(plugin_private, - server, message_type, - application_data, - application_data_offset); - } - if (framing_state != DATA_FRAMING_CLOSE) { - framing_state = DATA_FRAMING_START; - - if (fin) { - if (frame->application_data != NULL) { - free(frame->application_data); - frame->application_data = NULL; - } - application_data_offset = 0; - } - } - } - frame->application_data_offset = - application_data_offset; - } - break; - case DATA_FRAMING_CLOSE: - block_offset = block_size; - break; - default: - framing_state = DATA_FRAMING_CLOSE; - status_code = STATUS_CODE_PROTOCOL_ERROR; - break; - } - } + mod_websocket_handle_incoming(server, block, block_size, + &read_state, conf, plugin_private); } - if (message_frame.application_data != NULL) { - free(message_frame.application_data); + if (read_state.message_frame.application_data != NULL) { + free(read_state.message_frame.application_data); } - if (control_frame.application_data != NULL) { - free(control_frame.application_data); + if (read_state.control_frame.application_data != NULL) { + free(read_state.control_frame.application_data); } /* Send server-side closing 
handshake */ - status_code_buffer[0] = (status_code >> 8) & 0xFF; - status_code_buffer[1] = status_code & 0xFF; + status_code_buffer[0] = (read_state.status_code >> 8) & 0xFF; + status_code_buffer[1] = read_state.status_code & 0xFF; apr_thread_mutex_lock(state->mutex); mod_websocket_send_internal(state, MESSAGE_TYPE_CLOSE, -- 2.1.1