diff --git a/etherparse/src/transport/igmp/mod.rs b/etherparse/src/transport/igmp/mod.rs index 303533fe..cb2cfb8c 100644 --- a/etherparse/src/transport/igmp/mod.rs +++ b/etherparse/src/transport/igmp/mod.rs @@ -31,6 +31,9 @@ pub use report_group_record_type::*; mod report_group_record_v3_header; pub use report_group_record_v3_header::*; +mod report_group_record_v3_slice; +pub use report_group_record_v3_slice::*; + mod unknown_header; pub use unknown_header::*; diff --git a/etherparse/src/transport/igmp/report_group_record_v3_slice.rs b/etherparse/src/transport/igmp/report_group_record_v3_slice.rs new file mode 100644 index 00000000..9f12317d --- /dev/null +++ b/etherparse/src/transport/igmp/report_group_record_v3_slice.rs @@ -0,0 +1,414 @@ +use super::ReportGroupRecordV3Header; +use crate::*; + +/// A zero-copy slice of a single IGMPv3 group record. +/// +/// Provides access to the 8-byte fixed header fields, the source +/// address list, and the auxiliary data without copying. +/// +/// ```text +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Record Type | Aux Data Len | Number of Sources (N) | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Multicast Address | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Source Address [1] | +/// +- . -+ +/// . . . +/// +- -+ +/// | Source Address [N] | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Auxiliary Data | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// ``` +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ReportGroupRecordV3Slice<'a> { + /// The full record bytes (header + sources + aux data). + slice: &'a [u8], +} + +impl<'a> ReportGroupRecordV3Slice<'a> { + /// Creates a group record slice from raw bytes. + /// + /// Validates that the slice is at least + /// [`ReportGroupRecordV3Header::LEN`] bytes and that it contains + /// enough data for the declared source addresses and auxiliary data. + /// + /// Returns a tuple of the group record slice and the remaining + /// bytes after this record. + /// + /// # Errors + /// + /// Returns an [`err::LenError`] if the slice is too short. + #[inline] + pub fn from_slice( + slice: &'a [u8], + ) -> Result<(ReportGroupRecordV3Slice<'a>, &'a [u8]), err::LenError> { + if slice.len() < ReportGroupRecordV3Header::LEN { + return Err(err::LenError { + required_len: ReportGroupRecordV3Header::LEN, + len: slice.len(), + len_source: LenSource::Slice, + layer: err::Layer::Igmp, + layer_start_offset: 0, + }); + } + + // SAFETY: Safe as the length was checked to be >= LEN (8). + let num_of_sources = + u16::from_be_bytes(unsafe { [*slice.get_unchecked(2), *slice.get_unchecked(3)] }); + let aux_data_len = unsafe { *slice.get_unchecked(1) }; + + let record_len = ReportGroupRecordV3Header::LEN + + usize::from(num_of_sources) * 4 + + usize::from(aux_data_len) * 4; + + if slice.len() < record_len { + return Err(err::LenError { + required_len: record_len, + len: slice.len(), + len_source: LenSource::Slice, + layer: err::Layer::Igmp, + layer_start_offset: 0, + }); + } + + Ok(( + ReportGroupRecordV3Slice { + slice: &slice[..record_len], + }, + &slice[record_len..], + )) + } + + /// Decode the fixed header into a [`ReportGroupRecordV3Header`]. + #[inline] + pub fn header(&self) -> ReportGroupRecordV3Header { + // SAFETY: from_slice guarantees at least LEN bytes. + let (header, _) = ReportGroupRecordV3Header::from_slice(self.slice).unwrap(); + header + } + + /// Returns the group record type. + #[inline] + pub fn record_type(&self) -> igmp::ReportGroupRecordType { + // SAFETY: Safe as from_slice checks that the slice has at least LEN (8) bytes. + igmp::ReportGroupRecordType(unsafe { *self.slice.get_unchecked(0) }) + } + + /// Returns the auxiliary data length in units of 32-bit words. + #[inline] + pub fn aux_data_len(&self) -> u8 { + // SAFETY: Safe as from_slice checks that the slice has at least LEN (8) bytes. + unsafe { *self.slice.get_unchecked(1) } + } + + /// Returns the number of source addresses. + #[inline] + pub fn num_of_sources(&self) -> u16 { + // SAFETY: Safe as from_slice checks that the slice has at least LEN (8) bytes. + unsafe { get_unchecked_be_u16(self.slice.as_ptr().add(2)) } + } + + /// Returns the multicast address. + #[inline] + pub fn multicast_address(&self) -> [u8; 4] { + // SAFETY: Safe as from_slice checks that the slice has at least LEN (8) bytes. + unsafe { + [ + *self.slice.get_unchecked(4), + *self.slice.get_unchecked(5), + *self.slice.get_unchecked(6), + *self.slice.get_unchecked(7), + ] + } + } + + /// Returns the raw source address bytes. + /// + /// The returned slice contains `num_of_sources * 4` bytes. Each 4 + /// consecutive bytes represent one IPv4 source address. + #[inline] + pub fn source_addrs_bytes(&self) -> &'a [u8] { + let start = ReportGroupRecordV3Header::LEN; + let len = usize::from(self.num_of_sources()) * 4; + // SAFETY: Safe as from_slice validates the total record length. + unsafe { core::slice::from_raw_parts(self.slice.as_ptr().add(start), len) } + } + + /// Returns the auxiliary data bytes. + #[inline] + pub fn aux_data(&self) -> &'a [u8] { + let start = ReportGroupRecordV3Header::LEN + usize::from(self.num_of_sources()) * 4; + let len = usize::from(self.aux_data_len()) * 4; + // SAFETY: Safe as from_slice validates the total record length. + unsafe { core::slice::from_raw_parts(self.slice.as_ptr().add(start), len) } + } + + /// Returns the full slice of this group record. + #[inline] + pub fn slice(&self) -> &'a [u8] { + self.slice + } +} + +/// An iterator over IGMPv3 group record slices in a report payload. +#[derive(Clone, Debug)] +pub struct ReportGroupRecordV3SliceIter<'a> { + remaining: &'a [u8], + count: u16, +} + +impl<'a> ReportGroupRecordV3SliceIter<'a> { + /// Creates a new iterator over `count` group records starting at + /// the beginning of `slice`. + #[inline] + pub fn new(slice: &'a [u8], count: u16) -> ReportGroupRecordV3SliceIter<'a> { + ReportGroupRecordV3SliceIter { + remaining: slice, + count, + } + } +} + +impl<'a> Iterator for ReportGroupRecordV3SliceIter<'a> { + type Item = Result, err::LenError>; + + fn next(&mut self) -> Option { + if self.count == 0 { + return None; + } + self.count -= 1; + match ReportGroupRecordV3Slice::from_slice(self.remaining) { + Ok((record, rest)) => { + self.remaining = rest; + Some(Ok(record)) + } + Err(e) => { + // Stop iteration on error. + self.count = 0; + Some(Err(e)) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(usize::from(self.count))) + } +} + +#[cfg(test)] +mod test { + use super::*; + use alloc::{format, vec, vec::Vec}; + use proptest::prelude::*; + + fn make_record_bytes( + record_type: u8, + aux_data_len: u8, + num_sources: u16, + multicast_addr: [u8; 4], + ) -> Vec { + let n = num_sources.to_be_bytes(); + let mut bytes = vec![record_type, aux_data_len, n[0], n[1]]; + bytes.extend_from_slice(&multicast_addr); + // source addresses (4 bytes each) + for i in 0..num_sources { + bytes.extend_from_slice(&[10, 0, 0, (i + 1) as u8]); + } + // aux data (4 bytes per word) + for _ in 0..aux_data_len { + bytes.extend_from_slice(&[0xAA, 0xBB, 0xCC, 0xDD]); + } + bytes + } + + #[test] + fn from_slice_no_sources() { + let bytes = make_record_bytes(1, 0, 0, [224, 0, 0, 1]); + let mut with_trailer = bytes.clone(); + with_trailer.extend_from_slice(&[0xEE]); + + let (slice, rest) = ReportGroupRecordV3Slice::from_slice(&with_trailer).unwrap(); + assert_eq!(rest, &[0xEE]); + assert_eq!(slice.slice(), &bytes[..]); + assert_eq!(slice.record_type(), igmp::ReportGroupRecordType(1)); + assert_eq!(slice.aux_data_len(), 0); + assert_eq!(slice.num_of_sources(), 0); + assert_eq!(slice.multicast_address(), [224, 0, 0, 1]); + assert_eq!(slice.source_addrs_bytes(), &[]); + assert_eq!(slice.aux_data(), &[]); + } + + #[test] + fn from_slice_with_sources() { + let bytes = make_record_bytes(2, 0, 2, [224, 0, 0, 1]); + let mut with_trailer = bytes.clone(); + with_trailer.push(0xFF); + + let (slice, rest) = ReportGroupRecordV3Slice::from_slice(&with_trailer).unwrap(); + assert_eq!(rest, &[0xFF]); + assert_eq!(slice.num_of_sources(), 2); + assert_eq!(slice.source_addrs_bytes(), &[10, 0, 0, 1, 10, 0, 0, 2]); + } + + #[test] + fn from_slice_with_aux_data() { + let bytes = make_record_bytes(1, 1, 1, [224, 0, 0, 1]); + + let (slice, rest) = ReportGroupRecordV3Slice::from_slice(&bytes).unwrap(); + assert!(rest.is_empty()); + assert_eq!(slice.source_addrs_bytes(), &[10, 0, 0, 1]); + assert_eq!(slice.aux_data(), &[0xAA, 0xBB, 0xCC, 0xDD]); + } + + #[test] + fn from_slice_too_short_header() { + for bad_len in 0..ReportGroupRecordV3Header::LEN { + let bytes = vec![0u8; bad_len]; + assert_eq!( + ReportGroupRecordV3Slice::from_slice(&bytes).unwrap_err(), + err::LenError { + required_len: ReportGroupRecordV3Header::LEN, + len: bad_len, + len_source: LenSource::Slice, + layer: err::Layer::Igmp, + layer_start_offset: 0, + } + ); + } + } + + #[test] + fn from_slice_too_short_sources() { + // Declare 2 sources but only provide 1 + let n = 2u16.to_be_bytes(); + let mut bytes = vec![1, 0, n[0], n[1], 224, 0, 0, 1]; + bytes.extend_from_slice(&[10, 0, 0, 1]); // only 4 bytes, need 8 + + assert_eq!( + ReportGroupRecordV3Slice::from_slice(&bytes).unwrap_err(), + err::LenError { + required_len: 8 + 8, // header + 2 sources + len: 12, + len_source: LenSource::Slice, + layer: err::Layer::Igmp, + layer_start_offset: 0, + } + ); + } + + #[test] + fn from_slice_too_short_aux_data() { + // Declare 1 word aux data but don't provide it + let bytes = vec![1, 1, 0, 0, 224, 0, 0, 1]; // aux_data_len=1, 0 sources + + assert_eq!( + ReportGroupRecordV3Slice::from_slice(&bytes).unwrap_err(), + err::LenError { + required_len: 8 + 4, // header + 1 word aux + len: 8, + len_source: LenSource::Slice, + layer: err::Layer::Igmp, + layer_start_offset: 0, + } + ); + } + + #[test] + fn header_accessor() { + let bytes = make_record_bytes(3, 0, 1, [239, 1, 2, 3]); + let (slice, _) = ReportGroupRecordV3Slice::from_slice(&bytes).unwrap(); + let header = slice.header(); + assert_eq!(header.record_type, igmp::ReportGroupRecordType(3)); + assert_eq!(header.aux_data_len, 0); + assert_eq!(header.num_of_sources, 1); + assert_eq!(header.multicast_address, [239, 1, 2, 3]); + } + + proptest! { + #[test] + fn field_accessors( + record_type in any::(), + aux_data_len in 0u8..4, + num_sources in 0u16..4, + multicast_address in any::<[u8; 4]>(), + ) { + let bytes = make_record_bytes(record_type, aux_data_len, num_sources, multicast_address); + let (slice, rest) = ReportGroupRecordV3Slice::from_slice(&bytes).unwrap(); + prop_assert!(rest.is_empty()); + prop_assert_eq!(record_type, slice.record_type().0); + prop_assert_eq!(aux_data_len, slice.aux_data_len()); + prop_assert_eq!(num_sources, slice.num_of_sources()); + prop_assert_eq!(multicast_address, slice.multicast_address()); + prop_assert_eq!(usize::from(num_sources) * 4, slice.source_addrs_bytes().len()); + prop_assert_eq!(usize::from(aux_data_len) * 4, slice.aux_data().len()); + } + } + + proptest! { + #[test] + fn clone_eq(multicast_address in any::<[u8; 4]>()) { + let bytes = make_record_bytes(1, 0, 0, multicast_address); + let (slice, _) = ReportGroupRecordV3Slice::from_slice(&bytes).unwrap(); + prop_assert_eq!(&slice, &slice.clone()); + } + } + + #[test] + fn debug_fmt() { + let bytes = make_record_bytes(1, 0, 0, [224, 0, 0, 1]); + let (slice, _) = ReportGroupRecordV3Slice::from_slice(&bytes).unwrap(); + let dbg = format!("{:?}", slice); + assert!(dbg.starts_with("ReportGroupRecordV3Slice")); + } + + // Iterator tests + + #[test] + fn iterator_empty() { + let iter = ReportGroupRecordV3SliceIter::new(&[], 0); + assert_eq!(0, iter.count()); + } + + #[test] + fn iterator_single() { + let bytes = make_record_bytes(1, 0, 0, [224, 0, 0, 1]); + let mut iter = ReportGroupRecordV3SliceIter::new(&bytes, 1); + let record = iter.next().unwrap().unwrap(); + assert_eq!(record.multicast_address(), [224, 0, 0, 1]); + assert!(iter.next().is_none()); + } + + #[test] + fn iterator_multiple() { + let mut bytes = make_record_bytes(1, 0, 1, [224, 0, 0, 1]); + bytes.extend_from_slice(&make_record_bytes(2, 0, 0, [224, 0, 0, 2])); + + let records: Vec<_> = ReportGroupRecordV3SliceIter::new(&bytes, 2) + .collect::, _>>() + .unwrap(); + assert_eq!(2, records.len()); + assert_eq!(records[0].record_type().0, 1); + assert_eq!(records[0].multicast_address(), [224, 0, 0, 1]); + assert_eq!(records[1].record_type().0, 2); + assert_eq!(records[1].multicast_address(), [224, 0, 0, 2]); + } + + #[test] + fn iterator_error_stops() { + // Declare 2 records but only provide 1 + let bytes = make_record_bytes(1, 0, 0, [224, 0, 0, 1]); + let mut iter = ReportGroupRecordV3SliceIter::new(&bytes, 2); + assert!(iter.next().unwrap().is_ok()); + assert!(iter.next().unwrap().is_err()); + assert!(iter.next().is_none()); + } + + #[test] + fn iterator_size_hint() { + let bytes = make_record_bytes(1, 0, 0, [224, 0, 0, 1]); + let iter = ReportGroupRecordV3SliceIter::new(&bytes, 3); + assert_eq!((0, Some(3)), iter.size_hint()); + } +} diff --git a/etherparse/src/transport/igmp_slice.rs b/etherparse/src/transport/igmp_slice.rs new file mode 100644 index 00000000..8048d24a --- /dev/null +++ b/etherparse/src/transport/igmp_slice.rs @@ -0,0 +1,500 @@ +use crate::{igmp::*, *}; + +/// A slice containing an IGMP network packet. +/// +/// Struct allows the selective read of fields in the IGMP +/// packet without copying the data. +/// +/// # Important: Caller must trim to IGMP message length +/// +/// For `0x11` "Membership Query" messages, the IGMP version is +/// determined by message length per [RFC 9776 ยง7.1]( +/// https://datatracker.ietf.org/doc/html/rfc9776#section-7.1): +/// +/// * IGMPv1/v2 Query: length = 8 octets +/// * IGMPv3 Query: length >= 12 octets +/// +/// The caller **must** trim the input slice to the exact IGMP message +/// boundary (typically derived from the IP payload length) before +/// calling [`IgmpSlice::from_slice`]. If extra trailing bytes are +/// present, a query may be misidentified as IGMPv3. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct IgmpSlice<'a> { + slice: &'a [u8], +} + +impl<'a> IgmpSlice<'a> { + /// Creates a slice containing an IGMP packet. + /// + /// # Errors + /// + /// The function will return an `Err` [`err::LenError`] if the given + /// slice is too small to contain a valid IGMP header (minimum 8 + /// bytes), or has a length of 9-11 bytes for a `0x11` Membership + /// Query (which is invalid per RFC 9776). + #[inline] + pub fn from_slice(slice: &'a [u8]) -> Result, err::LenError> { + // Validate by attempting to parse the header. This checks both + // the minimum length and the 9-11 byte invalid range for queries. + let _ = IgmpHeader::from_slice(slice)?; + Ok(IgmpSlice { slice }) + } + + /// Decode the header values into an [`IgmpHeader`] struct. + #[inline] + pub fn header(&self) -> IgmpHeader { + // SAFETY: from_slice already validated the slice, so this cannot fail. + let (header, _) = IgmpHeader::from_slice(self.slice).unwrap(); + header + } + + /// Number of bytes/octets that will be converted into an + /// [`IgmpHeader`] when [`IgmpSlice::header`] gets called. + #[inline] + pub fn header_len(&self) -> usize { + // SAFETY: Safe as from_slice checks that the slice has at least + // IgmpHeader::MIN_LEN (8) bytes. + let type_u8 = unsafe { *self.slice.get_unchecked(0) }; + match type_u8 { + IGMP_TYPE_MEMBERSHIP_QUERY if self.slice.len() >= MembershipQueryWithSourcesHeader::LEN => { + MembershipQueryWithSourcesHeader::LEN + } + _ => IgmpHeader::MIN_LEN, + } + } + + /// Decode the header values (excluding the checksum) into an [`IgmpType`] enum. + #[inline] + pub fn igmp_type(&self) -> IgmpType { + self.header().igmp_type + } + + /// Returns the "type" byte value in the IGMP header. + #[inline] + pub fn type_u8(&self) -> u8 { + // SAFETY: Safe as from_slice checks that the slice has at least + // IgmpHeader::MIN_LEN (8) bytes. + unsafe { *self.slice.get_unchecked(0) } + } + + /// Returns the second byte of the IGMP header. + /// + /// The meaning of this byte depends on the message type: + /// - Membership Query: Max Response Time (v1: 0, v2: non-zero) + /// - Membership Report V3: Reserved (0) + /// - All other types: unused/reserved + #[inline] + pub fn max_resp_code_or_reserved(&self) -> u8 { + // SAFETY: Safe as from_slice checks that the slice has at least + // IgmpHeader::MIN_LEN (8) bytes. + unsafe { *self.slice.get_unchecked(1) } + } + + /// Returns the "checksum" value in the IGMP header. + #[inline] + pub fn checksum(&self) -> u16 { + // SAFETY: Safe as from_slice checks that the slice has at least + // IgmpHeader::MIN_LEN (8) bytes. + unsafe { get_unchecked_be_u16(self.slice.as_ptr().add(2)) } + } + + /// Returns the bytes from position 4 through 7 in the IGMP header. + /// + /// For most message types this is the Group Address. For IGMPv3 + /// Membership Reports, bytes 4-5 are flags and bytes 6-7 are the + /// Number of Group Records. + #[inline] + pub fn bytes4to7(&self) -> [u8; 4] { + // SAFETY: Safe as from_slice checks that the slice has at least + // IgmpHeader::MIN_LEN (8) bytes. + unsafe { + [ + *self.slice.get_unchecked(4), + *self.slice.get_unchecked(5), + *self.slice.get_unchecked(6), + *self.slice.get_unchecked(7), + ] + } + } + + /// Returns a slice to the bytes not covered by `.header()`. + /// + /// The contents of the payload depend on the message type: + /// + /// | Message Type | Payload Content | + /// |---|---| + /// | [`IgmpType::MembershipQuery`] (v1/v2) | Nothing (empty) | + /// | [`IgmpType::MembershipQueryWithSources`] (v3) | Source Address list | + /// | [`IgmpType::MembershipReportV1`] | Nothing (empty, unless trailing data) | + /// | [`IgmpType::MembershipReportV2`] | Nothing (empty, unless trailing data) | + /// | [`IgmpType::MembershipReportV3`] | Group Records | + /// | [`IgmpType::LeaveGroup`] | Nothing (empty, unless trailing data) | + /// | [`IgmpType::Unknown`] | Everything after the 8th byte | + #[inline] + pub fn payload(&self) -> &'a [u8] { + let header_len = self.header_len(); + // SAFETY: Safe as from_slice validated that the slice length is + // at least header_len. + unsafe { + core::slice::from_raw_parts( + self.slice.as_ptr().add(header_len), + self.slice.len() - header_len, + ) + } + } + + /// Returns the slice containing the entire IGMP packet. + #[inline] + pub fn slice(&self) -> &'a [u8] { + self.slice + } +} + +#[cfg(test)] +mod test { + use super::*; + use alloc::{format, vec}; + use proptest::prelude::*; + + #[test] + fn from_slice_too_small() { + for bad_len in 0..IgmpHeader::MIN_LEN { + let bytes = [0u8; 8]; + assert_eq!( + IgmpSlice::from_slice(&bytes[..bad_len]).unwrap_err(), + err::LenError { + required_len: IgmpHeader::MIN_LEN, + len: bad_len, + len_source: LenSource::Slice, + layer: err::Layer::Igmp, + layer_start_offset: 0, + } + ); + } + } + + #[test] + fn from_slice_query_invalid_length() { + // 9-11 bytes with type 0x11 should fail + for bad_len in 9..12 { + let mut bytes = [0u8; 12]; + bytes[0] = IGMP_TYPE_MEMBERSHIP_QUERY; + assert!(IgmpSlice::from_slice(&bytes[..bad_len]).is_err()); + } + } + + #[test] + fn from_slice_v1_query() { + // 8 bytes, type 0x11, max_resp_time = 0 => v1 query + let mut bytes = [0u8; 8]; + bytes[0] = IGMP_TYPE_MEMBERSHIP_QUERY; + bytes[4] = 224; + bytes[7] = 1; + + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + assert_eq!(slice.type_u8(), IGMP_TYPE_MEMBERSHIP_QUERY); + assert_eq!(slice.max_resp_code_or_reserved(), 0); + assert_eq!(slice.header_len(), 8); + assert_eq!(slice.payload(), &[]); + + match slice.igmp_type() { + IgmpType::MembershipQuery(q) => { + assert_eq!(q.max_response_time, 0); + assert_eq!(q.group_address.octets, [224, 0, 0, 1]); + } + _ => panic!("expected MembershipQuery"), + } + } + + #[test] + fn from_slice_v2_query() { + // 8 bytes, type 0x11, max_resp_time != 0 => v2 query + let mut bytes = [0u8; 8]; + bytes[0] = IGMP_TYPE_MEMBERSHIP_QUERY; + bytes[1] = 100; // max_resp_time + bytes[4] = 224; + bytes[7] = 1; + + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + assert_eq!(slice.max_resp_code_or_reserved(), 100); + assert_eq!(slice.header_len(), 8); + + match slice.igmp_type() { + IgmpType::MembershipQuery(q) => { + assert_eq!(q.max_response_time, 100); + } + _ => panic!("expected MembershipQuery"), + } + } + + #[test] + fn from_slice_v3_query() { + // >= 12 bytes, type 0x11 => v3 query + let mut bytes = [0u8; 16]; + bytes[0] = IGMP_TYPE_MEMBERSHIP_QUERY; + bytes[1] = 50; // max_resp_code + bytes[4] = 224; + bytes[7] = 1; + bytes[8] = 0x0A; // flags|S|QRV + bytes[9] = 125; // QQIC + bytes[10] = 0; + bytes[11] = 1; // 1 source + + // 12 bytes header + 4 bytes payload (1 source address) + let slice = IgmpSlice::from_slice(&bytes[..16]).unwrap(); + assert_eq!(slice.header_len(), MembershipQueryWithSourcesHeader::LEN); + assert_eq!(slice.payload().len(), 4); // 16 - 12 + + match slice.igmp_type() { + IgmpType::MembershipQueryWithSources(q) => { + assert_eq!(q.max_response_code.0, 50); + assert_eq!(q.group_address.octets, [224, 0, 0, 1]); + assert_eq!(q.raw_byte_8, 0x0A); + assert_eq!(q.qqic, 125); + assert_eq!(q.num_of_sources, 1); + } + _ => panic!("expected MembershipQueryWithSources"), + } + } + + #[test] + fn from_slice_v1_report() { + let mut bytes = [0u8; 8]; + bytes[0] = IGMPV1_TYPE_MEMBERSHIP_REPORT; + bytes[4] = 224; + bytes[7] = 1; + + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + assert_eq!(slice.type_u8(), IGMPV1_TYPE_MEMBERSHIP_REPORT); + assert_eq!(slice.header_len(), 8); + assert_eq!(slice.payload(), &[]); + + match slice.igmp_type() { + IgmpType::MembershipReportV1(r) => { + assert_eq!(r.group_address.octets, [224, 0, 0, 1]); + } + _ => panic!("expected MembershipReportV1"), + } + } + + #[test] + fn from_slice_v2_report() { + let mut bytes = [0u8; 8]; + bytes[0] = IGMPV2_TYPE_MEMBERSHIP_REPORT; + bytes[4] = 224; + bytes[7] = 2; + + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + assert_eq!(slice.type_u8(), IGMPV2_TYPE_MEMBERSHIP_REPORT); + + match slice.igmp_type() { + IgmpType::MembershipReportV2(r) => { + assert_eq!(r.group_address.octets, [224, 0, 0, 2]); + } + _ => panic!("expected MembershipReportV2"), + } + } + + #[test] + fn from_slice_v3_report() { + // type 0x22, 8-byte header + group record payload + let mut bytes = vec![0u8; 16]; + bytes[0] = IGMPV3_TYPE_MEMBERSHIP_REPORT; + bytes[1] = 0; // reserved + // bytes[2..4] = checksum (0) + bytes[4] = 0; // flags[0] + bytes[5] = 0; // flags[1] + bytes[6] = 0; // num_of_records high + bytes[7] = 1; // num_of_records low = 1 + // group record (8 bytes) + bytes[8] = 1; // record type (MODE_IS_INCLUDE) + bytes[9] = 0; // aux data len + bytes[10] = 0; // num sources high + bytes[11] = 0; // num sources low + bytes[12] = 224; + bytes[13] = 0; + bytes[14] = 0; + bytes[15] = 1; + + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + assert_eq!(slice.type_u8(), IGMPV3_TYPE_MEMBERSHIP_REPORT); + assert_eq!(slice.header_len(), 8); + assert_eq!(slice.payload().len(), 8); + + match slice.igmp_type() { + IgmpType::MembershipReportV3(r) => { + assert_eq!(r.num_of_records, 1); + assert_eq!(r.flags, [0, 0]); + } + _ => panic!("expected MembershipReportV3"), + } + } + + #[test] + fn from_slice_leave_group() { + let mut bytes = [0u8; 8]; + bytes[0] = IGMPV2_TYPE_LEAVE_GROUP; + bytes[4] = 224; + bytes[7] = 1; + + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + assert_eq!(slice.type_u8(), IGMPV2_TYPE_LEAVE_GROUP); + + match slice.igmp_type() { + IgmpType::LeaveGroup(l) => { + assert_eq!(l.group_address.octets, [224, 0, 0, 1]); + } + _ => panic!("expected LeaveGroup"), + } + } + + #[test] + fn from_slice_unknown_type() { + let mut bytes = [0u8; 8]; + bytes[0] = 0xFF; + bytes[1] = 0xAB; + bytes[4] = 1; + bytes[5] = 2; + bytes[6] = 3; + bytes[7] = 4; + + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + assert_eq!(slice.type_u8(), 0xFF); + + match slice.igmp_type() { + IgmpType::Unknown(u) => { + assert_eq!(u.igmp_type, 0xFF); + assert_eq!(u.raw_byte_1, 0xAB); + assert_eq!(u.raw_bytes_4_7, [1, 2, 3, 4]); + } + _ => panic!("expected Unknown"), + } + } + + #[test] + fn from_slice_with_trailing_payload() { + // v1 report with trailing data + let mut bytes = [0u8; 12]; + bytes[0] = IGMPV1_TYPE_MEMBERSHIP_REPORT; + bytes[4] = 224; + bytes[8] = 0xDE; + bytes[9] = 0xAD; + bytes[10] = 0xBE; + bytes[11] = 0xEF; + + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + assert_eq!(slice.header_len(), 8); + assert_eq!(slice.payload(), &[0xDE, 0xAD, 0xBE, 0xEF]); + } + + proptest! { + #[test] + fn header_roundtrip(bytes in proptest::collection::vec(any::(), 8..=8)) { + // Avoid type 0x11 (query) to sidestep the length-based version detection + let mut bytes = bytes; + if bytes[0] == IGMP_TYPE_MEMBERSHIP_QUERY { + bytes[0] = 0xFF; + } + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + let header = slice.header(); + assert_eq!(header.checksum, slice.checksum()); + } + } + + proptest! { + #[test] + fn type_u8_accessor(bytes in any::<[u8; 8]>()) { + // Avoid 0x11 with exactly 8 bytes -> fine, but avoid invalid 9-11 + let slice_result = IgmpSlice::from_slice(&bytes); + if let Ok(slice) = slice_result { + assert_eq!(bytes[0], slice.type_u8()); + } + } + } + + proptest! { + #[test] + fn checksum_accessor(bytes in any::<[u8; 8]>()) { + if let Ok(slice) = IgmpSlice::from_slice(&bytes) { + assert_eq!( + u16::from_be_bytes([bytes[2], bytes[3]]), + slice.checksum() + ); + } + } + } + + proptest! { + #[test] + fn bytes4to7_accessor(bytes in any::<[u8; 8]>()) { + if let Ok(slice) = IgmpSlice::from_slice(&bytes) { + assert_eq!( + [bytes[4], bytes[5], bytes[6], bytes[7]], + slice.bytes4to7() + ); + } + } + } + + proptest! { + #[test] + fn slice_accessor(bytes in proptest::collection::vec(any::(), 8..64)) { + let mut bytes = bytes; + // Avoid query type to prevent 9-11 byte rejection + if bytes[0] == IGMP_TYPE_MEMBERSHIP_QUERY { + bytes[0] = 0xFF; + } + let igmp_slice = IgmpSlice::from_slice(&bytes).unwrap(); + assert_eq!(&bytes[..], igmp_slice.slice()); + } + } + + proptest! { + #[test] + fn clone_eq(bytes in any::<[u8; 12]>()) { + // Use v3 query type so 12 bytes is valid + let mut bytes = bytes; + bytes[0] = IGMP_TYPE_MEMBERSHIP_QUERY; + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + assert_eq!(slice, slice.clone()); + } + } + + proptest! { + #[test] + fn debug_fmt(bytes in any::<[u8; 8]>()) { + let mut bytes = bytes; + if bytes[0] == IGMP_TYPE_MEMBERSHIP_QUERY { + bytes[0] = 0xFF; + } + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + let dbg = format!("{:?}", slice); + assert!(dbg.starts_with("IgmpSlice")); + } + } + + #[test] + fn payload_v3_query_sources() { + // 12-byte header + 8 bytes (2 source addresses) + let mut bytes = [0u8; 20]; + bytes[0] = IGMP_TYPE_MEMBERSHIP_QUERY; + bytes[1] = 10; // max_resp_code + bytes[10] = 0; + bytes[11] = 2; // 2 sources + // source 1: 10.0.0.1 + bytes[12] = 10; + bytes[15] = 1; + // source 2: 10.0.0.2 + bytes[16] = 10; + bytes[19] = 2; + + let slice = IgmpSlice::from_slice(&bytes).unwrap(); + assert_eq!(slice.header_len(), 12); + let payload = slice.payload(); + assert_eq!(payload.len(), 8); + assert_eq!(payload[0], 10); // first byte of source 1 + assert_eq!(payload[3], 1); + assert_eq!(payload[4], 10); // first byte of source 2 + assert_eq!(payload[7], 2); + } +} diff --git a/etherparse/src/transport/mod.rs b/etherparse/src/transport/mod.rs index c104b4b4..f3ffbe58 100644 --- a/etherparse/src/transport/mod.rs +++ b/etherparse/src/transport/mod.rs @@ -34,6 +34,9 @@ pub use igmp_type::*; mod igmp_header; pub use igmp_header::*; +mod igmp_slice; +pub use igmp_slice::*; + mod tcp_header; pub use tcp_header::*;