This is an automated email from the ASF dual-hosted git repository. hgruszecki pushed a commit to branch deserialization-hardening in repository https://gitbox.apache.org/repos/asf/iggy.git
commit 26fb1d8988e523b27b51df03f229143441c7e1a1 Author: Hubert Gruszecki <[email protected]> AuthorDate: Mon Feb 16 10:59:58 2026 +0100 fix(security): harden all deserializers against malformed packets A single truncated or corrupted network packet could panic the server via unchecked slice indexing in from_bytes/from_raw_bytes across 24 command deserializers. Since each IggyShard runs on a single compio thread, one panic kills the entire shard. Replace every direct bytes[n] and bytes[a..b] access with checked .get().ok_or(IggyError::InvalidCommand) so malformed input returns an error instead of crashing. Notable fixes beyond the mechanical pattern: - IggyMessageView::new() now returns Result, validating the full message size up front; payload()/calculate_checksum() no longer use unwrap_or_default() on potentially corrupt data - update_permissions had an .unwrap() on attacker-controlled permissions_length — replaced with proper error propagation - send_messages_handler message_size arithmetic could underflow — replaced with a checked_sub() chain - LoginUser now rejects 1-3 trailing bytes as corrupt instead of silently ignoring incomplete length prefixes - Removed dead name.len() != name_length guards that were unreachable after .get() + from_utf8() validation --- .../consumer_groups/create_consumer_group.rs | 30 +++- .../consumer_offsets/delete_consumer_offset.rs | 10 +- .../consumer_offsets/get_consumer_offset.rs | 10 +- .../consumer_offsets/store_consumer_offset.rs | 14 +- .../src/commands/messages/flush_unsaved_buffer.rs | 6 +- core/common/src/commands/messages/poll_messages.rs | 24 ++- .../src/commands/partitions/create_partitions.rs | 4 +- .../src/commands/partitions/delete_partitions.rs | 4 +- .../create_personal_access_token.rs | 23 +-- .../delete_personal_access_token.rs | 41 ++++- .../login_with_personal_access_token.rs | 41 ++++- .../src/commands/segments/delete_segments.rs | 8 +- core/common/src/commands/streams/create_stream.rs | 41 ++++- 
core/common/src/commands/streams/update_stream.rs | 15 +- core/common/src/commands/topics/create_topic.rs | 59 +++++-- core/common/src/commands/topics/update_topic.rs | 55 ++++-- core/common/src/commands/users/change_password.rs | 43 ++++- core/common/src/commands/users/create_user.rs | 76 ++++++--- core/common/src/commands/users/login_user.rs | 189 ++++++++++++++++----- .../src/commands/users/update_permissions.rs | 19 ++- core/common/src/commands/users/update_user.rs | 41 ++++- core/common/src/types/identifier/mod.rs | 77 +++++++-- core/common/src/types/message/message_view.rs | 47 ++--- core/common/src/types/message/messages_batch.rs | 7 +- .../common/src/types/message/messages_batch_mut.rs | 38 +++-- core/common/src/types/message/partitioning.rs | 31 +--- .../handlers/messages/send_messages_handler.rs | 25 ++- 27 files changed, 717 insertions(+), 261 deletions(-) diff --git a/core/common/src/commands/consumer_groups/create_consumer_group.rs b/core/common/src/commands/consumer_groups/create_consumer_group.rs index c03d39964..f697bc9c7 100644 --- a/core/common/src/commands/consumer_groups/create_consumer_group.rs +++ b/core/common/src/commands/consumer_groups/create_consumer_group.rs @@ -96,10 +96,14 @@ impl BytesSerializable for CreateConsumerGroup { position += stream_id.get_size_bytes().as_bytes_usize(); let topic_id = Identifier::from_bytes(bytes.slice(position..))?; position += topic_id.get_size_bytes().as_bytes_usize(); - let name_length = bytes[position]; - let name = from_utf8(&bytes[position + 1..position + 1 + name_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); + let name_length = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; + let name = from_utf8( + bytes + .get(position + 1..position + 1 + name_length as usize) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? 
+ .to_string(); let command = CreateConsumerGroup { stream_id, topic_id, @@ -142,6 +146,24 @@ mod tests { assert_eq!(name, command.name); } + #[test] + fn from_bytes_should_fail_on_empty_input() { + assert!(CreateConsumerGroup::from_bytes(Bytes::new()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_truncated_input() { + let command = CreateConsumerGroup::default(); + let bytes = command.to_bytes(); + for i in 0..bytes.len() - 1 { + let truncated = bytes.slice(..i); + assert!( + CreateConsumerGroup::from_bytes(truncated).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + #[test] fn should_be_deserialized_from_bytes() { let stream_id = Identifier::numeric(1).unwrap(); diff --git a/core/common/src/commands/consumer_offsets/delete_consumer_offset.rs b/core/common/src/commands/consumer_offsets/delete_consumer_offset.rs index b608a8d55..b0fc12aa2 100644 --- a/core/common/src/commands/consumer_offsets/delete_consumer_offset.rs +++ b/core/common/src/commands/consumer_offsets/delete_consumer_offset.rs @@ -99,7 +99,8 @@ impl BytesSerializable for DeleteConsumerOffset { } let mut position = 0; - let consumer_kind = ConsumerKind::from_code(bytes[0])?; + let consumer_kind = + ConsumerKind::from_code(*bytes.first().ok_or(IggyError::InvalidCommand)?)?; let consumer_id = Identifier::from_bytes(bytes.slice(1..))?; position += 1 + consumer_id.get_size_bytes().as_bytes_usize(); let consumer = Consumer { @@ -110,10 +111,11 @@ impl BytesSerializable for DeleteConsumerOffset { position += stream_id.get_size_bytes().as_bytes_usize(); let topic_id = Identifier::from_bytes(bytes.slice(position..))?; position += topic_id.get_size_bytes().as_bytes_usize(); - // Decode partition_id with flag byte: 1 = Some, 0 = None - let has_partition_id = bytes[position]; + let has_partition_id = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; let partition_id_value = u32::from_le_bytes( - bytes[position + 1..position + 5] + bytes + .get(position + 1..position + 5) + 
.ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); diff --git a/core/common/src/commands/consumer_offsets/get_consumer_offset.rs b/core/common/src/commands/consumer_offsets/get_consumer_offset.rs index 0f542e09f..f27624948 100644 --- a/core/common/src/commands/consumer_offsets/get_consumer_offset.rs +++ b/core/common/src/commands/consumer_offsets/get_consumer_offset.rs @@ -104,7 +104,8 @@ impl BytesSerializable for GetConsumerOffset { } let mut position = 0; - let consumer_kind = ConsumerKind::from_code(bytes[0])?; + let consumer_kind = + ConsumerKind::from_code(*bytes.first().ok_or(IggyError::InvalidCommand)?)?; let consumer_id = Identifier::from_bytes(bytes.slice(1..))?; position += 1 + consumer_id.get_size_bytes().as_bytes_usize(); let consumer = Consumer { @@ -115,10 +116,11 @@ impl BytesSerializable for GetConsumerOffset { position += stream_id.get_size_bytes().as_bytes_usize(); let topic_id = Identifier::from_bytes(bytes.slice(position..))?; position += topic_id.get_size_bytes().as_bytes_usize(); - // Decode partition_id with flag byte: 1 = Some, 0 = None - let has_partition_id = bytes[position]; + let has_partition_id = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; let partition_id_value = u32::from_le_bytes( - bytes[position + 1..position + 5] + bytes + .get(position + 1..position + 5) + .ok_or(IggyError::InvalidCommand)? 
.try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); diff --git a/core/common/src/commands/consumer_offsets/store_consumer_offset.rs b/core/common/src/commands/consumer_offsets/store_consumer_offset.rs index c7705fbfb..d8512df75 100644 --- a/core/common/src/commands/consumer_offsets/store_consumer_offset.rs +++ b/core/common/src/commands/consumer_offsets/store_consumer_offset.rs @@ -104,7 +104,8 @@ impl BytesSerializable for StoreConsumerOffset { } let mut position = 0; - let consumer_kind = ConsumerKind::from_code(bytes[0])?; + let consumer_kind = + ConsumerKind::from_code(*bytes.first().ok_or(IggyError::InvalidCommand)?)?; let consumer_id = Identifier::from_bytes(bytes.slice(1..))?; position += 1 + consumer_id.get_size_bytes().as_bytes_usize(); let consumer = Consumer { @@ -115,10 +116,11 @@ impl BytesSerializable for StoreConsumerOffset { position += stream_id.get_size_bytes().as_bytes_usize(); let topic_id = Identifier::from_bytes(bytes.slice(position..))?; position += topic_id.get_size_bytes().as_bytes_usize(); - // Decode partition_id with flag byte: 1 = Some, 0 = None - let has_partition_id = bytes[position]; + let has_partition_id = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; let partition_id_value = u32::from_le_bytes( - bytes[position + 1..position + 5] + bytes + .get(position + 1..position + 5) + .ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); @@ -128,7 +130,9 @@ impl BytesSerializable for StoreConsumerOffset { None }; let offset = u64::from_le_bytes( - bytes[position + 5..position + 13] + bytes + .get(position + 5..position + 13) + .ok_or(IggyError::InvalidCommand)? 
.try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); diff --git a/core/common/src/commands/messages/flush_unsaved_buffer.rs b/core/common/src/commands/messages/flush_unsaved_buffer.rs index beb26e5e9..2235a911b 100644 --- a/core/common/src/commands/messages/flush_unsaved_buffer.rs +++ b/core/common/src/commands/messages/flush_unsaved_buffer.rs @@ -86,12 +86,14 @@ impl BytesSerializable for FlushUnsavedBuffer { let topic_id = Identifier::from_bytes(bytes.slice(position..))?; position += topic_id.to_bytes().len(); let partition_id = u32::from_le_bytes( - bytes[position..position + 4] + bytes + .get(position..position + 4) + .ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); position += 4; - let fsync = bytes[position] == 1; + let fsync = *bytes.get(position).ok_or(IggyError::InvalidCommand)? == 1; Ok(FlushUnsavedBuffer { stream_id, topic_id, diff --git a/core/common/src/commands/messages/poll_messages.rs b/core/common/src/commands/messages/poll_messages.rs index c52b335c8..ac266ca4c 100644 --- a/core/common/src/commands/messages/poll_messages.rs +++ b/core/common/src/commands/messages/poll_messages.rs @@ -157,7 +157,8 @@ impl BytesSerializable for PollMessages { } let mut position = 0; - let consumer_kind = ConsumerKind::from_code(bytes[0])?; + let consumer_kind = + ConsumerKind::from_code(*bytes.first().ok_or(IggyError::InvalidCommand)?)?; let consumer_id = Identifier::from_bytes(bytes.slice(1..))?; position += 1 + consumer_id.get_size_bytes().as_bytes_usize(); let consumer = Consumer { @@ -168,10 +169,11 @@ impl BytesSerializable for PollMessages { position += stream_id.get_size_bytes().as_bytes_usize(); let topic_id = Identifier::from_bytes(bytes.slice(position..))?; position += topic_id.get_size_bytes().as_bytes_usize(); - // Decode partition_id with flag byte: 1 = Some, 0 = None - let has_partition_id = bytes[position]; + let has_partition_id = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; let 
partition_id_value = u32::from_le_bytes( - bytes[position + 1..position + 5] + bytes + .get(position + 1..position + 5) + .ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); @@ -180,10 +182,13 @@ impl BytesSerializable for PollMessages { } else { None }; - let polling_kind = PollingKind::from_code(bytes[position + 5])?; + let polling_kind = + PollingKind::from_code(*bytes.get(position + 5).ok_or(IggyError::InvalidCommand)?)?; position += 6; let value = u64::from_le_bytes( - bytes[position..position + 8] + bytes + .get(position..position + 8) + .ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); @@ -192,12 +197,13 @@ impl BytesSerializable for PollMessages { value, }; let count = u32::from_le_bytes( - bytes[position + 8..position + 12] + bytes + .get(position + 8..position + 12) + .ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); - let auto_commit = bytes[position + 12]; - let auto_commit = matches!(auto_commit, 1); + let auto_commit = *bytes.get(position + 12).ok_or(IggyError::InvalidCommand)? == 1; let command = PollMessages { consumer, stream_id, diff --git a/core/common/src/commands/partitions/create_partitions.rs b/core/common/src/commands/partitions/create_partitions.rs index a01d3b26d..2bd9ee454 100644 --- a/core/common/src/commands/partitions/create_partitions.rs +++ b/core/common/src/commands/partitions/create_partitions.rs @@ -92,7 +92,9 @@ impl BytesSerializable for CreatePartitions { let topic_id = Identifier::from_bytes(bytes.slice(position..))?; position += topic_id.get_size_bytes().as_bytes_usize(); let partitions_count = u32::from_le_bytes( - bytes[position..position + 4] + bytes + .get(position..position + 4) + .ok_or(IggyError::InvalidCommand)? 
.try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); diff --git a/core/common/src/commands/partitions/delete_partitions.rs b/core/common/src/commands/partitions/delete_partitions.rs index 12be4448a..9e0cf1ddc 100644 --- a/core/common/src/commands/partitions/delete_partitions.rs +++ b/core/common/src/commands/partitions/delete_partitions.rs @@ -92,7 +92,9 @@ impl BytesSerializable for DeletePartitions { let topic_id = Identifier::from_bytes(bytes.slice(position..))?; position += topic_id.get_size_bytes().as_bytes_usize(); let partitions_count = u32::from_le_bytes( - bytes[position..position + 4] + bytes + .get(position..position + 4) + .ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); diff --git a/core/common/src/commands/personal_access_tokens/create_personal_access_token.rs b/core/common/src/commands/personal_access_tokens/create_personal_access_token.rs index fa457a33e..395822854 100644 --- a/core/common/src/commands/personal_access_tokens/create_personal_access_token.rs +++ b/core/common/src/commands/personal_access_tokens/create_personal_access_token.rs @@ -82,17 +82,20 @@ impl BytesSerializable for CreatePersonalAccessToken { return Err(IggyError::InvalidCommand); } - let name_length = bytes[0]; - let name = from_utf8(&bytes.slice(1..1 + name_length as usize)) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - if name.len() != name_length as usize { - return Err(IggyError::InvalidCommand); - } - - let position = 1 + name_length as usize; + let name_length = *bytes.first().ok_or(IggyError::InvalidCommand)? as usize; + let name = from_utf8( + bytes + .get(1..1 + name_length) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? + .to_string(); + + let position = 1 + name_length; let expiry = u64::from_le_bytes( - bytes[position..position + 8] + bytes + .get(position..position + 8) + .ok_or(IggyError::InvalidCommand)? 
.try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); diff --git a/core/common/src/commands/personal_access_tokens/delete_personal_access_token.rs b/core/common/src/commands/personal_access_tokens/delete_personal_access_token.rs index 064f78c48..0d19b68c9 100644 --- a/core/common/src/commands/personal_access_tokens/delete_personal_access_token.rs +++ b/core/common/src/commands/personal_access_tokens/delete_personal_access_token.rs @@ -76,13 +76,14 @@ impl BytesSerializable for DeletePersonalAccessToken { return Err(IggyError::InvalidCommand); } - let name_length = bytes[0]; - let name = from_utf8(&bytes[1..1 + name_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - if name.len() != name_length as usize { - return Err(IggyError::InvalidCommand); - } + let name_length = *bytes.first().ok_or(IggyError::InvalidCommand)? as usize; + let name = from_utf8( + bytes + .get(1..1 + name_length) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? 
+ .to_string(); let command = DeletePersonalAccessToken { name }; Ok(command) @@ -112,6 +113,32 @@ mod tests { assert_eq!(name, command.name); } + #[test] + fn from_bytes_should_fail_on_empty_input() { + assert!(DeletePersonalAccessToken::from_bytes(Bytes::new()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_truncated_input() { + let command = DeletePersonalAccessToken::default(); + let bytes = command.to_bytes(); + for i in 0..bytes.len() - 1 { + let truncated = bytes.slice(..i); + assert!( + DeletePersonalAccessToken::from_bytes(truncated).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + + #[test] + fn from_bytes_should_fail_on_corrupted_name_length() { + let mut buf = BytesMut::new(); + buf.put_u8(255); + buf.put_slice(b"short"); + assert!(DeletePersonalAccessToken::from_bytes(buf.freeze()).is_err()); + } + #[test] fn should_be_deserialized_from_bytes() { let name = "test"; diff --git a/core/common/src/commands/personal_access_tokens/login_with_personal_access_token.rs b/core/common/src/commands/personal_access_tokens/login_with_personal_access_token.rs index cdba04e37..59fafa9c1 100644 --- a/core/common/src/commands/personal_access_tokens/login_with_personal_access_token.rs +++ b/core/common/src/commands/personal_access_tokens/login_with_personal_access_token.rs @@ -73,13 +73,14 @@ impl BytesSerializable for LoginWithPersonalAccessToken { return Err(IggyError::InvalidCommand); } - let token_length = bytes[0]; - let token = from_utf8(&bytes[1..1 + token_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - if token.len() != token_length as usize { - return Err(IggyError::InvalidCommand); - } + let token_length = *bytes.first().ok_or(IggyError::InvalidCommand)? as usize; + let token = from_utf8( + bytes + .get(1..1 + token_length) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? 
+ .to_string(); let command = LoginWithPersonalAccessToken { token }; Ok(command) @@ -109,6 +110,32 @@ mod tests { assert_eq!(token, command.token); } + #[test] + fn from_bytes_should_fail_on_empty_input() { + assert!(LoginWithPersonalAccessToken::from_bytes(Bytes::new()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_truncated_input() { + let command = LoginWithPersonalAccessToken::default(); + let bytes = command.to_bytes(); + for i in 0..bytes.len() - 1 { + let truncated = bytes.slice(..i); + assert!( + LoginWithPersonalAccessToken::from_bytes(truncated).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + + #[test] + fn from_bytes_should_fail_on_corrupted_token_length() { + let mut buf = BytesMut::new(); + buf.put_u8(255); + buf.put_slice(b"short"); + assert!(LoginWithPersonalAccessToken::from_bytes(buf.freeze()).is_err()); + } + #[test] fn should_be_deserialized_from_bytes() { let token = "test"; diff --git a/core/common/src/commands/segments/delete_segments.rs b/core/common/src/commands/segments/delete_segments.rs index 841a84ac7..c57833f4c 100644 --- a/core/common/src/commands/segments/delete_segments.rs +++ b/core/common/src/commands/segments/delete_segments.rs @@ -97,13 +97,17 @@ impl BytesSerializable for DeleteSegments { let topic_id = Identifier::from_bytes(bytes.slice(position..))?; position += topic_id.get_size_bytes().as_bytes_usize(); let partition_id = u32::from_le_bytes( - bytes[position..position + std::mem::size_of::<u32>()] + bytes + .get(position..position + std::mem::size_of::<u32>()) + .ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); position += std::mem::size_of::<u32>(); let segments_count = u32::from_le_bytes( - bytes[position..position + std::mem::size_of::<u32>()] + bytes + .get(position..position + std::mem::size_of::<u32>()) + .ok_or(IggyError::InvalidCommand)? 
.try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); diff --git a/core/common/src/commands/streams/create_stream.rs b/core/common/src/commands/streams/create_stream.rs index 9ff2fec5a..d5964eb5e 100644 --- a/core/common/src/commands/streams/create_stream.rs +++ b/core/common/src/commands/streams/create_stream.rs @@ -73,13 +73,14 @@ impl BytesSerializable for CreateStream { return Err(IggyError::InvalidCommand); } - let name_length = bytes[0]; - let name = from_utf8(&bytes[1..1 + name_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - if name.len() != name_length as usize { - return Err(IggyError::InvalidCommand); - } + let name_length = *bytes.first().ok_or(IggyError::InvalidCommand)? as usize; + let name = from_utf8( + bytes + .get(1..1 + name_length) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? + .to_string(); let command = CreateStream { name }; Ok(command) @@ -110,6 +111,32 @@ mod tests { assert_eq!(name, command.name); } + #[test] + fn from_bytes_should_fail_on_empty_input() { + assert!(CreateStream::from_bytes(Bytes::new()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_truncated_input() { + let command = CreateStream::default(); + let bytes = command.to_bytes(); + for i in 0..bytes.len() - 1 { + let truncated = bytes.slice(..i); + assert!( + CreateStream::from_bytes(truncated).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + + #[test] + fn from_bytes_should_fail_on_corrupted_name_length() { + let mut buf = BytesMut::new(); + buf.put_u8(255); + buf.put_slice(b"short"); + assert!(CreateStream::from_bytes(buf.freeze()).is_err()); + } + #[test] fn should_be_deserialized_from_bytes() { let name = "test".to_string(); diff --git a/core/common/src/commands/streams/update_stream.rs b/core/common/src/commands/streams/update_stream.rs index ce3600f88..a202105bf 100644 --- a/core/common/src/commands/streams/update_stream.rs +++ 
b/core/common/src/commands/streams/update_stream.rs @@ -85,13 +85,14 @@ impl BytesSerializable for UpdateStream { let mut position = 0; let stream_id = Identifier::from_bytes(bytes.clone())?; position += stream_id.get_size_bytes().as_bytes_usize(); - let name_length = bytes[position]; - let name = from_utf8(&bytes[position + 1..position + 1 + name_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - if name.len() != name_length as usize { - return Err(IggyError::InvalidCommand); - } + let name_length = *bytes.get(position).ok_or(IggyError::InvalidCommand)? as usize; + let name = from_utf8( + bytes + .get(position + 1..position + 1 + name_length) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? + .to_string(); let command = UpdateStream { stream_id, name }; Ok(command) diff --git a/core/common/src/commands/topics/create_topic.rs b/core/common/src/commands/topics/create_topic.rs index 48cd3d525..bd4d87f48 100644 --- a/core/common/src/commands/topics/create_topic.rs +++ b/core/common/src/commands/topics/create_topic.rs @@ -127,34 +127,43 @@ impl BytesSerializable for CreateTopic { let stream_id = Identifier::from_bytes(bytes.clone())?; position += stream_id.get_size_bytes().as_bytes_usize(); let partitions_count = u32::from_le_bytes( - bytes[position..position + 4] + bytes + .get(position..position + 4) + .ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); - let compression_algorithm = CompressionAlgorithm::from_code(bytes[position + 4])?; + let compression_algorithm = CompressionAlgorithm::from_code( + *bytes.get(position + 4).ok_or(IggyError::InvalidCommand)?, + )?; let message_expiry = u64::from_le_bytes( - bytes[position + 5..position + 13] + bytes + .get(position + 5..position + 13) + .ok_or(IggyError::InvalidCommand)? 
.try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); let message_expiry: IggyExpiry = message_expiry.into(); let max_topic_size = u64::from_le_bytes( - bytes[position + 13..position + 21] + bytes + .get(position + 13..position + 21) + .ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); let max_topic_size: MaxTopicSize = max_topic_size.into(); - let replication_factor = match bytes[position + 21] { + let replication_factor = match *bytes.get(position + 21).ok_or(IggyError::InvalidCommand)? { 0 => None, factor => Some(factor), }; - let name_length = bytes[position + 22]; - let name = from_utf8(&bytes[position + 23..(position + 23 + name_length as usize)]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - if name.len() != name_length as usize { - return Err(IggyError::InvalidCommand); - } + let name_length = *bytes.get(position + 22).ok_or(IggyError::InvalidCommand)? as usize; + let name = from_utf8( + bytes + .get(position + 23..position + 23 + name_length) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? 
+ .to_string(); let command = CreateTopic { stream_id, partitions_count, @@ -229,6 +238,32 @@ mod tests { assert_eq!(name, command.name); } + #[test] + fn from_bytes_should_fail_on_empty_input() { + assert!(CreateTopic::from_bytes(Bytes::new()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_truncated_input() { + let command = CreateTopic { + stream_id: Identifier::numeric(1).unwrap(), + partitions_count: 3, + compression_algorithm: CompressionAlgorithm::None, + message_expiry: IggyExpiry::NeverExpire, + max_topic_size: MaxTopicSize::ServerDefault, + replication_factor: Some(1), + name: "test".to_string(), + }; + let bytes = command.to_bytes(); + for i in 0..bytes.len() - 1 { + let truncated = bytes.slice(..i); + assert!( + CreateTopic::from_bytes(truncated).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + #[test] fn should_be_deserialized_from_bytes() { let stream_id = Identifier::numeric(1).unwrap(); diff --git a/core/common/src/commands/topics/update_topic.rs b/core/common/src/commands/topics/update_topic.rs index 83dadb69c..6f7bf5a02 100644 --- a/core/common/src/commands/topics/update_topic.rs +++ b/core/common/src/commands/topics/update_topic.rs @@ -128,31 +128,38 @@ impl BytesSerializable for UpdateTopic { position += stream_id.get_size_bytes().as_bytes_usize(); let topic_id = Identifier::from_bytes(bytes.slice(position..))?; position += topic_id.get_size_bytes().as_bytes_usize(); - let compression_algorithm = CompressionAlgorithm::from_code(bytes[position])?; + let compression_algorithm = CompressionAlgorithm::from_code( + *bytes.get(position).ok_or(IggyError::InvalidCommand)?, + )?; position += 1; let message_expiry = u64::from_le_bytes( - bytes[position..position + 8] + bytes + .get(position..position + 8) + .ok_or(IggyError::InvalidCommand)? 
.try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); let message_expiry: IggyExpiry = message_expiry.into(); let max_topic_size = u64::from_le_bytes( - bytes[position + 8..position + 16] + bytes + .get(position + 8..position + 16) + .ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); let max_topic_size: MaxTopicSize = max_topic_size.into(); - let replication_factor = match bytes[position + 16] { + let replication_factor = match *bytes.get(position + 16).ok_or(IggyError::InvalidCommand)? { 0 => None, factor => Some(factor), }; - let name_length = bytes[position + 17]; - let name = from_utf8(&bytes[position + 18..(position + 18 + name_length as usize)]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - if name.len() != name_length as usize { - return Err(IggyError::InvalidCommand); - } + let name_length = *bytes.get(position + 17).ok_or(IggyError::InvalidCommand)? as usize; + let name = from_utf8( + bytes + .get(position + 18..position + 18 + name_length) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? 
+ .to_string(); let command = UpdateTopic { stream_id, topic_id, @@ -229,6 +236,32 @@ mod tests { assert_eq!(name, command.name); } + #[test] + fn from_bytes_should_fail_on_empty_input() { + assert!(UpdateTopic::from_bytes(Bytes::new()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_truncated_input() { + let command = UpdateTopic { + stream_id: Identifier::numeric(1).unwrap(), + topic_id: Identifier::numeric(2).unwrap(), + compression_algorithm: CompressionAlgorithm::None, + message_expiry: IggyExpiry::NeverExpire, + max_topic_size: MaxTopicSize::ServerDefault, + replication_factor: Some(1), + name: "test".to_string(), + }; + let bytes = command.to_bytes(); + for i in 0..bytes.len() - 1 { + let truncated = bytes.slice(..i); + assert!( + UpdateTopic::from_bytes(truncated).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + #[test] fn should_be_deserialized_from_bytes() { let stream_id = Identifier::numeric(1).unwrap(); diff --git a/core/common/src/commands/users/change_password.rs b/core/common/src/commands/users/change_password.rs index 4a5663778..ba8e251db 100644 --- a/core/common/src/commands/users/change_password.rs +++ b/core/common/src/commands/users/change_password.rs @@ -101,18 +101,25 @@ impl BytesSerializable for ChangePassword { let user_id = Identifier::from_bytes(bytes.clone())?; let mut position = user_id.get_size_bytes().as_bytes_usize(); - let current_password_length = bytes[position]; + let current_password_length = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; position += 1; - let current_password = - from_utf8(&bytes[position..position + current_password_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); + let current_password = from_utf8( + bytes + .get(position..position + current_password_length as usize) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? 
+ .to_string(); position += current_password_length as usize; - let new_password_length = bytes[position]; + let new_password_length = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; position += 1; - let new_password = from_utf8(&bytes[position..position + new_password_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); + let new_password = from_utf8( + bytes + .get(position..position + new_password_length as usize) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? + .to_string(); let command = ChangePassword { user_id, @@ -160,6 +167,24 @@ mod tests { assert_eq!(new_password, command.new_password); } + #[test] + fn from_bytes_should_fail_on_empty_input() { + assert!(ChangePassword::from_bytes(Bytes::new()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_truncated_input() { + let command = ChangePassword::default(); + let bytes = command.to_bytes(); + for i in 0..bytes.len() - 1 { + let truncated = bytes.slice(..i); + assert!( + ChangePassword::from_bytes(truncated).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + #[test] fn should_be_deserialized_from_bytes() { let user_id = Identifier::numeric(1).unwrap(); diff --git a/core/common/src/commands/users/create_user.rs b/core/common/src/commands/users/create_user.rs index 7d23c54ee..641302802 100644 --- a/core/common/src/commands/users/create_user.rs +++ b/core/common/src/commands/users/create_user.rs @@ -110,28 +110,30 @@ impl BytesSerializable for CreateUser { return Err(IggyError::InvalidCommand); } - let username_length = bytes[0]; - let username = from_utf8(&bytes[1..1 + username_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - if username.len() != username_length as usize { - return Err(IggyError::InvalidCommand); - } + let username_length = *bytes.first().ok_or(IggyError::InvalidCommand)? 
as usize; + let username = from_utf8( + bytes + .get(1..1 + username_length) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? + .to_string(); - let mut position = 1 + username_length as usize; - let password_length = bytes[position]; + let mut position = 1 + username_length; + let password_length = *bytes.get(position).ok_or(IggyError::InvalidCommand)? as usize; position += 1; - let password = from_utf8(&bytes[position..position + password_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - if password.len() != password_length as usize { - return Err(IggyError::InvalidCommand); - } + let password = from_utf8( + bytes + .get(position..position + password_length) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? + .to_string(); - position += password_length as usize; - let status = UserStatus::from_code(bytes[position])?; + position += password_length; + let status = UserStatus::from_code(*bytes.get(position).ok_or(IggyError::InvalidCommand)?)?; position += 1; - let has_permissions = bytes[position]; + let has_permissions = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; if has_permissions > 1 { return Err(IggyError::InvalidCommand); } @@ -139,14 +141,18 @@ impl BytesSerializable for CreateUser { position += 1; let permissions = if has_permissions == 1 { let permissions_length = u32::from_le_bytes( - bytes[position..position + 4] + bytes + .get(position..position + 4) + .ok_or(IggyError::InvalidCommand)? .try_into() .map_err(|_| IggyError::InvalidNumberEncoding)?, ); position += 4; - Some(Permissions::from_bytes( - bytes.slice(position..position + permissions_length as usize), - )?) + let end = position + permissions_length as usize; + if end > bytes.len() { + return Err(IggyError::InvalidCommand); + } + Some(Permissions::from_bytes(bytes.slice(position..end))?) 
} else { None }; @@ -232,6 +238,32 @@ mod tests { assert_eq!(permissions, command.permissions.unwrap()); } + #[test] + fn from_bytes_should_fail_on_empty_input() { + assert!(CreateUser::from_bytes(Bytes::new()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_truncated_input() { + let command = CreateUser::default(); + let bytes = command.to_bytes(); + for i in 0..bytes.len() - 1 { + let truncated = bytes.slice(..i); + assert!( + CreateUser::from_bytes(truncated).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + + #[test] + fn from_bytes_should_fail_on_corrupted_username_length() { + let mut buf = BytesMut::new(); + buf.put_u8(255); + buf.put_slice(b"short"); + assert!(CreateUser::from_bytes(buf.freeze()).is_err()); + } + #[test] fn should_be_deserialized_from_bytes() { let username = "user"; diff --git a/core/common/src/commands/users/login_user.rs b/core/common/src/commands/users/login_user.rs index edd1e8426..981558f2b 100644 --- a/core/common/src/commands/users/login_user.rs +++ b/core/common/src/commands/users/login_user.rs @@ -114,55 +114,83 @@ impl BytesSerializable for LoginUser { return Err(IggyError::InvalidCommand); } - let username_length = bytes[0]; - let username = from_utf8(&bytes[1..=(username_length as usize)]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - if username.len() != username_length as usize { - return Err(IggyError::InvalidCommand); - } + let username_length = *bytes.first().ok_or(IggyError::InvalidCommand)? as usize; + let username = from_utf8( + bytes + .get(1..1 + username_length) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? + .to_string(); - let password_length = bytes[1 + username_length as usize]; + let pos = 1 + username_length; + let password_length = *bytes.get(pos).ok_or(IggyError::InvalidCommand)? 
as usize; let password = from_utf8( - &bytes[2 + username_length as usize - ..2 + username_length as usize + password_length as usize], + bytes + .get(pos + 1..pos + 1 + password_length) + .ok_or(IggyError::InvalidCommand)?, ) .map_err(|_| IggyError::InvalidUtf8)? .to_string(); - if password.len() != password_length as usize { - return Err(IggyError::InvalidCommand); - } - let position = 2 + username_length as usize + password_length as usize; - let version_length = u32::from_le_bytes( - bytes[position..position + 4] - .try_into() - .map_err(|_| IggyError::InvalidNumberEncoding)?, - ); - let version = match version_length { - 0 => None, - _ => { - let version = - from_utf8(&bytes[position + 4..position + 4 + version_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - Some(version) + let mut position = pos + 1 + password_length; + + // Version and context fields are optional for backward compatibility + // with older SDKs (e.g. v0.8.0) that don't send them. + // However, 1-3 trailing bytes (incomplete u32 length prefix) are rejected + // as they indicate a corrupt payload rather than a valid old-SDK format. + let remaining = bytes.len() - position; + let version = if remaining == 0 { + None + } else if remaining < 4 { + return Err(IggyError::InvalidCommand); + } else { + let version_length = u32::from_le_bytes( + bytes[position..position + 4] + .try_into() + .map_err(|_| IggyError::InvalidNumberEncoding)?, + ); + position += 4; + match version_length { + 0 => None, + _ => { + let version = from_utf8( + bytes + .get(position..position + version_length as usize) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? 
+ .to_string(); + position += version_length as usize; + Some(version) + } } }; - let position = position + 4 + version_length as usize; - let context_length = u32::from_le_bytes( - bytes[position..position + 4] - .try_into() - .map_err(|_| IggyError::InvalidNumberEncoding)?, - ); - let context = match context_length { - 0 => None, - _ => { - let context = - from_utf8(&bytes[position + 4..position + 4 + context_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); - Some(context) + + let remaining = bytes.len() - position; + let context = if remaining == 0 { + None + } else if remaining < 4 { + return Err(IggyError::InvalidCommand); + } else { + let context_length = u32::from_le_bytes( + bytes[position..position + 4] + .try_into() + .map_err(|_| IggyError::InvalidNumberEncoding)?, + ); + position += 4; + match context_length { + 0 => None, + _ => { + let context = from_utf8( + bytes + .get(position..position + context_length as usize) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? + .to_string(); + Some(context) + } } }; @@ -226,6 +254,85 @@ mod tests { assert_eq!(context, command.context); } + #[test] + fn from_bytes_should_fail_on_empty_input() { + assert!(LoginUser::from_bytes(Bytes::new()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_truncated_input() { + let command = LoginUser { + username: "user".to_string(), + password: "secret".to_string(), + version: Some("1.0.0".to_string()), + context: Some("test".to_string()), + }; + let bytes = command.to_bytes(); + // Truncate at every position up to (but not including) the version field. + // Positions within username/password must error; positions at or past the + // version boundary are valid old-SDK payloads. 
+ let version_offset = 2 + command.username.len() + command.password.len(); + for i in 0..version_offset { + let truncated = bytes.slice(..i); + assert!( + LoginUser::from_bytes(truncated).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + + #[test] + fn from_bytes_should_fail_on_corrupted_username_length() { + let mut buf = BytesMut::new(); + buf.put_u8(255); // username_length = 255 + buf.put_slice(b"short"); + assert!(LoginUser::from_bytes(buf.freeze()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_trailing_bytes() { + let username = "user"; + let password = "secret"; + let mut bytes = BytesMut::new(); + #[allow(clippy::cast_possible_truncation)] + bytes.put_u8(username.len() as u8); + bytes.put_slice(username.as_bytes()); + #[allow(clippy::cast_possible_truncation)] + bytes.put_u8(password.len() as u8); + bytes.put_slice(password.as_bytes()); + + // 1-3 trailing bytes (incomplete u32 length prefix) must be rejected + for extra in 1..=3u8 { + let mut buf = bytes.clone(); + for i in 0..extra { + buf.put_u8(i); + } + assert!( + LoginUser::from_bytes(buf.freeze()).is_err(), + "expected error for {extra} trailing byte(s)" + ); + } + } + + #[test] + fn from_bytes_should_accept_old_sdk_format_without_version_context() { + let username = "user"; + let password = "secret"; + let mut bytes = BytesMut::new(); + #[allow(clippy::cast_possible_truncation)] + bytes.put_u8(username.len() as u8); + bytes.put_slice(username.as_bytes()); + #[allow(clippy::cast_possible_truncation)] + bytes.put_u8(password.len() as u8); + bytes.put_slice(password.as_bytes()); + + let command = LoginUser::from_bytes(bytes.freeze()).unwrap(); + assert_eq!(command.username, username); + assert_eq!(command.password, password); + assert_eq!(command.version, None); + assert_eq!(command.context, None); + } + #[test] fn should_be_deserialized_from_bytes() { let username = "user"; diff --git a/core/common/src/commands/users/update_permissions.rs 
b/core/common/src/commands/users/update_permissions.rs index ff96927dd..3a5ee8abe 100644 --- a/core/common/src/commands/users/update_permissions.rs +++ b/core/common/src/commands/users/update_permissions.rs @@ -75,19 +75,26 @@ impl BytesSerializable for UpdatePermissions { let user_id = Identifier::from_bytes(bytes.clone())?; let mut position = user_id.get_size_bytes().as_bytes_usize(); - let has_permissions = bytes[position]; + let has_permissions = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; if has_permissions > 1 { return Err(IggyError::InvalidCommand); } position += 1; let permissions = if has_permissions == 1 { - let permissions_length = - u32::from_le_bytes(bytes[position..position + 4].try_into().unwrap()); + let permissions_length = u32::from_le_bytes( + bytes + .get(position..position + 4) + .ok_or(IggyError::InvalidCommand)? + .try_into() + .map_err(|_| IggyError::InvalidNumberEncoding)?, + ); position += 4; - let permissions = Permissions::from_bytes( - bytes.slice(position..position + permissions_length as usize), - )?; + let end = position + permissions_length as usize; + if end > bytes.len() { + return Err(IggyError::InvalidCommand); + } + let permissions = Permissions::from_bytes(bytes.slice(position..end))?; Some(permissions) } else { None diff --git a/core/common/src/commands/users/update_user.rs b/core/common/src/commands/users/update_user.rs index 3e2811915..b6df319a7 100644 --- a/core/common/src/commands/users/update_user.rs +++ b/core/common/src/commands/users/update_user.rs @@ -96,32 +96,37 @@ impl BytesSerializable for UpdateUser { let user_id = Identifier::from_bytes(bytes.clone())?; let mut position = user_id.get_size_bytes().as_bytes_usize(); - let has_username = bytes[position]; + let has_username = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; if has_username > 1 { return Err(IggyError::InvalidCommand); } position += 1; let username = if has_username == 1 { - let username_length = bytes[position]; + let 
username_length = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; position += 1; - let username = from_utf8(&bytes[position..position + username_length as usize]) - .map_err(|_| IggyError::InvalidUtf8)? - .to_string(); + let username = from_utf8( + bytes + .get(position..position + username_length as usize) + .ok_or(IggyError::InvalidCommand)?, + ) + .map_err(|_| IggyError::InvalidUtf8)? + .to_string(); position += username_length as usize; Some(username) } else { None }; - let has_status = bytes[position]; + let has_status = *bytes.get(position).ok_or(IggyError::InvalidCommand)?; if has_status > 1 { return Err(IggyError::InvalidCommand); } let status = if has_status == 1 { position += 1; - let status = UserStatus::from_code(bytes[position])?; + let status = + UserStatus::from_code(*bytes.get(position).ok_or(IggyError::InvalidCommand)?)?; Some(status) } else { None @@ -180,6 +185,28 @@ mod tests { assert_eq!(status, command.status.unwrap()); } + #[test] + fn from_bytes_should_fail_on_empty_input() { + assert!(UpdateUser::from_bytes(Bytes::new()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_truncated_input() { + let command = UpdateUser { + user_id: Identifier::numeric(1).unwrap(), + username: Some("user".to_string()), + status: Some(UserStatus::Active), + }; + let bytes = command.to_bytes(); + for i in 0..bytes.len() - 1 { + let truncated = bytes.slice(..i); + assert!( + UpdateUser::from_bytes(truncated).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + #[test] fn should_be_deserialized_from_bytes() { let user_id = Identifier::numeric(1).unwrap(); diff --git a/core/common/src/types/identifier/mod.rs b/core/common/src/types/identifier/mod.rs index bd4a5ccc0..82c41eb8a 100644 --- a/core/common/src/types/identifier/mod.rs +++ b/core/common/src/types/identifier/mod.rs @@ -187,12 +187,12 @@ impl Identifier { /// Creates identifier from raw bytes pub fn from_raw_bytes(bytes: &[u8]) -> Result<Self, IggyError> { - let kind = 
IdKind::from_code(bytes[0])?; - let length = bytes[1]; - let value = bytes[2..2 + length as usize].to_vec(); - if value.len() != length as usize { - return Err(IggyError::InvalidIdentifier); - } + let kind = IdKind::from_code(*bytes.first().ok_or(IggyError::InvalidIdentifier)?)?; + let length = *bytes.get(1).ok_or(IggyError::InvalidIdentifier)?; + let value = bytes + .get(2..2 + length as usize) + .ok_or(IggyError::InvalidIdentifier)? + .to_vec(); let identifier = Identifier { kind, @@ -228,16 +228,12 @@ impl BytesSerializable for Identifier { where Self: Sized, { - if bytes.len() < 3 { - return Err(IggyError::InvalidIdentifier); - } - - let kind = IdKind::from_code(bytes[0])?; - let length = bytes[1]; - let value = bytes[2..2 + length as usize].to_vec(); - if value.len() != length as usize { - return Err(IggyError::InvalidIdentifier); - } + let kind = IdKind::from_code(*bytes.first().ok_or(IggyError::InvalidIdentifier)?)?; + let length = *bytes.get(1).ok_or(IggyError::InvalidIdentifier)?; + let value = bytes + .get(2..2 + length as usize) + .ok_or(IggyError::InvalidIdentifier)? 
+ .to_vec(); let identifier = Identifier { kind, @@ -382,6 +378,55 @@ mod tests { assert!(Identifier::named(&"a".repeat(256)).is_err()); } + #[test] + fn from_bytes_should_fail_on_empty_input() { + assert!(Identifier::from_bytes(Bytes::new()).is_err()); + } + + #[test] + fn from_bytes_should_fail_on_truncated_input() { + let id = Identifier::numeric(42).unwrap(); + let bytes = id.to_bytes(); + for i in 0..bytes.len() - 1 { + let truncated = bytes.slice(..i); + assert!( + Identifier::from_bytes(truncated).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + + #[test] + fn from_bytes_should_fail_on_corrupted_length() { + let mut buf = BytesMut::new(); + buf.put_u8(1); // Numeric kind + buf.put_u8(255); // length = 255 but only 2 bytes of value follow + buf.put_u16_le(0); + assert!(Identifier::from_bytes(buf.freeze()).is_err()); + } + + #[test] + fn from_raw_bytes_should_fail_on_empty_input() { + assert!(Identifier::from_raw_bytes(&[]).is_err()); + } + + #[test] + fn from_raw_bytes_should_fail_on_truncated_input() { + let id = Identifier::numeric(42).unwrap(); + let bytes = id.to_bytes(); + for i in 0..bytes.len() - 1 { + assert!( + Identifier::from_raw_bytes(&bytes[..i]).is_err(), + "expected error for truncation at byte {i}" + ); + } + } + + #[test] + fn from_raw_bytes_should_fail_on_corrupted_length() { + assert!(Identifier::from_raw_bytes(&[1, 255, 0, 0]).is_err()); + } + #[test] fn numeric_id_should_be_converted_into_identifier_using_trait() { let id = 1; diff --git a/core/common/src/types/message/message_view.rs b/core/common/src/types/message/message_view.rs index 1f1442ea7..1dd664208 100644 --- a/core/common/src/types/message/message_view.rs +++ b/core/common/src/types/message/message_view.rs @@ -37,17 +37,29 @@ pub struct IggyMessageView<'a> { impl<'a> IggyMessageView<'a> { /// Creates a new immutable message view from a buffer. 
- pub fn new(buffer: &'a [u8]) -> Self { + /// + /// Validates that the buffer is large enough to contain the full message + /// (header + payload + user headers). All subsequent accessors can use + /// direct indexing because this constructor guarantees the bounds. + pub fn new(buffer: &'a [u8]) -> Result<Self, IggyError> { + if buffer.len() < IGGY_MESSAGE_HEADER_SIZE { + return Err(IggyError::InvalidMessagePayloadLength); + } let header_view = IggyMessageHeaderView::new(&buffer[IGGY_MESSAGE_HEADER_RANGE]); let payload_len = header_view.payload_length(); + let user_headers_len = header_view.user_headers_length(); + let total_size = IGGY_MESSAGE_HEADER_SIZE + payload_len + user_headers_len; + if buffer.len() < total_size { + return Err(IggyError::InvalidMessagePayloadLength); + } let payload_offset = IGGY_MESSAGE_HEADER_SIZE; let headers_offset = payload_offset + payload_len; - Self { + Ok(Self { buffer, payload_offset, user_headers_offset: headers_offset, - } + }) } /// Returns an immutable header view. @@ -108,32 +120,27 @@ impl<'a> IggyMessageView<'a> { pub fn payload(&self) -> &[u8] { let header_view = self.header(); let payload_len = header_view.payload_length(); - &self.buffer[self.payload_offset..self.payload_offset + payload_len] + let end = self.payload_offset + payload_len; + // Bounds guaranteed by new() which validates total message size + &self.buffer[self.payload_offset..end] } /// Validates that the message view is properly formatted and has valid data. + /// + /// Note: `new()` already validates buffer bounds, so this always succeeds + /// for a properly constructed view. 
pub fn validate(&self) -> Result<(), IggyError> { - if self.buffer.len() < IGGY_MESSAGE_HEADER_SIZE { - return Err(IggyError::InvalidMessagePayloadLength); - } - - let header = self.header(); - let payload_len = header.payload_length(); - let user_headers_len = header.user_headers_length(); - let total_size = IGGY_MESSAGE_HEADER_SIZE + payload_len + user_headers_len; - - if self.buffer.len() < total_size { - return Err(IggyError::InvalidMessagePayloadLength); - } Ok(()) } - /// Validates that the message view has a valid checksum. + /// Calculates the checksum over the message (excluding the checksum field itself). /// This should be called only on server side. pub fn calculate_checksum(&self) -> u64 { - let checksum_field_size = size_of::<u64>(); // Skip checksum field for checksum calculation + let checksum_field_size = size_of::<u64>(); let size = self.size() - checksum_field_size; - let data = &self.buffer[checksum_field_size..checksum_field_size + size]; + let end = checksum_field_size + size; + // Bounds guaranteed by new() which validates total message size + let data = &self.buffer[checksum_field_size..end]; checksum::calculate_checksum(data) } } @@ -186,7 +193,7 @@ impl<'a> Iterator for IggyMessageViewIterator<'a> { } let remaining = &self.buffer[self.position..]; - let view = IggyMessageView::new(remaining); + let view = IggyMessageView::new(remaining).ok()?; self.position += view.size(); Some(view) } diff --git a/core/common/src/types/message/messages_batch.rs b/core/common/src/types/message/messages_batch.rs index 6bc704866..aa6910b01 100644 --- a/core/common/src/types/message/messages_batch.rs +++ b/core/common/src/types/message/messages_batch.rs @@ -166,11 +166,8 @@ impl IggyMessagesBatch { /// Get the message at the specified index. /// Returns None if the index is out of bounds. 
pub fn get(&self, index: usize) -> Option<IggyMessageView<'_>> { - if let Some((start, end)) = self.get_message_boundaries(index) { - Some(IggyMessageView::new(&self.messages[start..end])) - } else { - None - } + let (start, end) = self.get_message_boundaries(index)?; + IggyMessageView::new(&self.messages[start..end]).ok() } } diff --git a/core/common/src/types/message/messages_batch_mut.rs b/core/common/src/types/message/messages_batch_mut.rs index 3857f2667..55e33bf27 100644 --- a/core/common/src/types/message/messages_batch_mut.rs +++ b/core/common/src/types/message/messages_batch_mut.rs @@ -199,7 +199,7 @@ impl IggyMessagesBatchMut { if self.is_empty() { return None; } - Some(IggyMessageView::new(&self.messages).header().offset()) + Some(IggyMessageView::new(&self.messages).ok()?.header().offset()) } /// Returns the first timestamp in the batch @@ -207,7 +207,12 @@ impl IggyMessagesBatchMut { if self.is_empty() { return None; } - Some(IggyMessageView::new(&self.messages).header().timestamp()) + Some( + IggyMessageView::new(&self.messages) + .ok()? + .header() + .timestamp(), + ) } /// Returns the last timestamp in the batch @@ -217,11 +222,13 @@ impl IggyMessagesBatchMut { } let last_index = self.count() as usize - 1; - self.get_message_boundaries(last_index).map(|(start, _)| { - IggyMessageView::new(&self.messages[start..]) + let (start, end) = self.get_message_boundaries(last_index)?; + Some( + IggyMessageView::new(&self.messages[start..end]) + .ok()? .header() - .timestamp() - }) + .timestamp(), + ) } /// Returns the last offset in the batch @@ -230,11 +237,13 @@ impl IggyMessagesBatchMut { return None; } let last_index = self.count() as usize - 1; - self.get_message_boundaries(last_index).map(|(start, _)| { - IggyMessageView::new(&self.messages[start..]) + let (start, end) = self.get_message_boundaries(last_index)?; + Some( + IggyMessageView::new(&self.messages[start..end]) + .ok()? 
.header() - .offset() - }) + .offset(), + ) } /// Checks if the batch is empty. @@ -472,8 +481,8 @@ impl IggyMessagesBatchMut { /// Get the message at the specified index. /// Returns None if the index is out of bounds or the message cannot be found. pub fn get(&self, index: usize) -> Option<IggyMessageView<'_>> { - self.get_message_boundaries(index) - .map(|(start, end)| IggyMessageView::new(&self.messages[start..end])) + let (start, end) = self.get_message_boundaries(index)?; + IggyMessageView::new(&self.messages[start..end]).ok() } /// This helper function is used to parse newly appended chunks in the `new_buffer`. @@ -494,7 +503,10 @@ impl IggyMessagesBatchMut { let mut current = chunk_start; while current < chunk_end { - let view = IggyMessageView::new(&new_buffer[current..]); + let Ok(view) = IggyMessageView::new(&new_buffer[current..]) else { + error!("Corrupt message in already-validated chunk at offset {current}"); + break; + }; let msg_size = view.size(); *offset_in_new_buffer += msg_size as u32; new_indexes.insert(0, *offset_in_new_buffer, 0); diff --git a/core/common/src/types/message/partitioning.rs b/core/common/src/types/message/partitioning.rs index 78efab2ce..1353ca394 100644 --- a/core/common/src/types/message/partitioning.rs +++ b/core/common/src/types/message/partitioning.rs @@ -137,14 +137,14 @@ impl Partitioning { } } - /// Create the partitioning from BytesMut. + /// Create the partitioning from a raw byte slice. pub fn from_raw_bytes(bytes: &[u8]) -> Result<Self, IggyError> { - let kind = PartitioningKind::from_code(bytes[0])?; - let length = bytes[1]; - let value = bytes[2..2 + length as usize].to_vec(); - if value.len() != length as usize { - return Err(IggyError::InvalidCommand); - } + let kind = PartitioningKind::from_code(*bytes.first().ok_or(IggyError::InvalidCommand)?)?; + let length = *bytes.get(1).ok_or(IggyError::InvalidCommand)?; + let value = bytes + .get(2..2 + length as usize) + .ok_or(IggyError::InvalidCommand)? 
+ .to_vec(); Ok(Partitioning { kind, @@ -192,22 +192,7 @@ impl BytesSerializable for Partitioning { where Self: Sized, { - if bytes.len() < 2 { - return Err(IggyError::InvalidCommand); - } - - let kind = PartitioningKind::from_code(bytes[0])?; - let length = bytes[1]; - let value = bytes[2..2 + length as usize].to_vec(); - if value.len() != length as usize { - return Err(IggyError::InvalidCommand); - } - - Ok(Partitioning { - kind, - length, - value, - }) + Self::from_raw_bytes(&bytes) } fn write_to_buffer(&self, bytes: &mut BytesMut) { diff --git a/core/server/src/binary/handlers/messages/send_messages_handler.rs b/core/server/src/binary/handlers/messages/send_messages_handler.rs index 713d64ab2..9e7c0f8ef 100644 --- a/core/server/src/binary/handlers/messages/send_messages_handler.rs +++ b/core/server/src/binary/handlers/messages/send_messages_handler.rs @@ -76,18 +76,28 @@ impl ServerCommandHandler for SendMessages { element_size += stream_id.get_size_bytes().as_bytes_usize(); self.stream_id = stream_id; - let topic_id = Identifier::from_raw_bytes(&metadata_buf[element_size..])?; + let topic_id = Identifier::from_raw_bytes( + metadata_buf + .get(element_size..) + .ok_or(IggyError::InvalidCommand)?, + )?; element_size += topic_id.get_size_bytes().as_bytes_usize(); self.topic_id = topic_id; - let partitioning = Partitioning::from_raw_bytes(&metadata_buf[element_size..])?; + let partitioning = Partitioning::from_raw_bytes( + metadata_buf + .get(element_size..) + .ok_or(IggyError::InvalidCommand)?, + )?; element_size += partitioning.get_size_bytes().as_bytes_usize(); self.partitioning = partitioning; let messages_count = u32::from_le_bytes( - metadata_buf[element_size..element_size + 4] + metadata_buf + .get(element_size..element_size + 4) + .ok_or(IggyError::InvalidCommand)? 
.try_into() - .unwrap(), + .map_err(|_| IggyError::InvalidNumberEncoding)?, ); let indexes_size = messages_count as usize * INDEX_SIZE; @@ -96,8 +106,11 @@ impl ServerCommandHandler for SendMessages { result?; let indexes_buffer = indexes_buffer.into_inner(); - let messages_size = - total_payload_size - metadata_size as usize - indexes_size - metadata_len_field_size; + let messages_size = total_payload_size + .checked_sub(metadata_size as usize) + .and_then(|s| s.checked_sub(indexes_size)) + .and_then(|s| s.checked_sub(metadata_len_field_size)) + .ok_or(IggyError::InvalidCommand)?; let messages_buffer = PooledBuffer::with_capacity(messages_size); let (result, messages_buffer) = sender.read(messages_buffer.slice(0..messages_size)).await; result?;
