util: assert compatibility between LengthDelimitedCodec options (#6414)

This commit is contained in:
M.Amin Rayej 2024-03-24 01:12:24 +03:30 committed by GitHub
parent 4c453e9790
commit 8342e4b524
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 98 additions and 1 deletions

View File

@ -386,6 +386,10 @@ use std::{cmp, fmt, mem};
/// `Builder` enables constructing configured length delimited codecs. Note
/// that not all configuration settings apply to both encoding and decoding. See
/// the documentation for specific methods for more detail.
///
/// Note that the if the value of [`Builder::max_frame_length`] becomes larger than
/// what can actually fit in [`Builder::length_field_length`], it will be clipped to
/// the maximum value that can fit.
#[derive(Debug, Clone, Copy)]
pub struct Builder {
// Maximum frame length
@ -935,8 +939,12 @@ impl Builder {
/// # }
/// ```
pub fn new_codec(&self) -> LengthDelimitedCodec {
let mut builder = *self;
builder.adjust_max_frame_len();
LengthDelimitedCodec {
builder: *self,
builder,
state: DecodeState::Head,
}
}
@ -1018,6 +1026,35 @@ impl Builder {
self.num_skip
.unwrap_or(self.length_field_offset + self.length_field_len)
}
fn adjust_max_frame_len(&mut self) {
// This function is basically `std::u64::saturating_add_signed`. Since it
// requires MSRV 1.66, its implementation is copied here.
//
// TODO: use the method from std when MSRV becomes >= 1.66
fn saturating_add_signed(num: u64, rhs: i64) -> u64 {
let (res, overflow) = num.overflowing_add(rhs as u64);
if overflow == (rhs < 0) {
res
} else if overflow {
u64::MAX
} else {
0
}
}
// Calculate the maximum number that can be represented using `length_field_len` bytes.
let max_number = match 1u64.checked_shl((8 * self.length_field_len) as u32) {
Some(shl) => shl - 1,
None => u64::MAX,
};
let max_allowed_len = saturating_add_signed(max_number, self.length_adjustment as i64);
if self.max_frame_len as u64 > max_allowed_len {
self.max_frame_len = usize::try_from(max_allowed_len).unwrap_or(usize::MAX);
}
}
}
impl Default for Builder {

View File

@ -689,6 +689,66 @@ fn encode_overflow() {
codec.encode(Bytes::from("hello"), &mut buf).unwrap();
}
#[test]
fn frame_does_not_fit() {
let codec = LengthDelimitedCodec::builder()
.length_field_length(1)
.max_frame_length(256)
.new_codec();
assert_eq!(codec.max_frame_length(), 255);
}
#[test]
fn neg_adjusted_frame_does_not_fit() {
let codec = LengthDelimitedCodec::builder()
.length_field_length(1)
.length_adjustment(-1)
.new_codec();
assert_eq!(codec.max_frame_length(), 254);
}
#[test]
fn pos_adjusted_frame_does_not_fit() {
let codec = LengthDelimitedCodec::builder()
.length_field_length(1)
.length_adjustment(1)
.new_codec();
assert_eq!(codec.max_frame_length(), 256);
}
#[test]
fn max_allowed_frame_fits() {
let codec = LengthDelimitedCodec::builder()
.length_field_length(std::mem::size_of::<usize>())
.max_frame_length(usize::MAX)
.new_codec();
assert_eq!(codec.max_frame_length(), usize::MAX);
}
#[test]
fn smaller_frame_len_not_adjusted() {
let codec = LengthDelimitedCodec::builder()
.max_frame_length(10)
.length_field_length(std::mem::size_of::<usize>())
.new_codec();
assert_eq!(codec.max_frame_length(), 10);
}
#[test]
fn max_allowed_length_field() {
let codec = LengthDelimitedCodec::builder()
.length_field_length(8)
.max_frame_length(usize::MAX)
.new_codec();
assert_eq!(codec.max_frame_length(), usize::MAX);
}
// ===== Test utils =====
struct Mock {