util: Add rng utilities (#686)

This adds new PRNG utilities that only use libstd and not the external
`rand` crate. The motivation for this change is that the tower middleware
that need a PRNG don't need the complexity and vast utilities of the
`rand` crate.

This adds a `Rng` trait which abstracts the simple PRNG features tower
needs. This also provides a `HasherRng` which uses the `RandomState`
type from libstd to generate random `u64` values. In addition, there is
an internal-only `sample_inplace` which is used within the balance p2c
middleware to randomly pick a ready service. This implementation is
crate-private since it's quite specific to the balance implementation.

The goal of this, in addition to removing `rand` from the balance
middleware, is to prepare for the upcoming `Retry` changes. The `next_f64`
method will be used in the jitter portion of the backoff utilities in #685.

Co-authored-by: Eliza Weisman <eliza@buoyant.io>
This commit is contained in:
Lucio Franco 2022-08-25 13:06:24 -04:00 committed by GitHub
parent aec7b8f417
commit e0558266a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 217 additions and 15 deletions

View File

@ -47,7 +47,7 @@ full = [
]
# FIXME: Use weak dependency once available (https://github.com/rust-lang/cargo/issues/8832)
log = ["tracing/log"]
balance = ["discover", "load", "ready-cache", "make", "rand", "slab"]
balance = ["discover", "load", "ready-cache", "make", "slab", "util"]
buffer = ["__common", "tokio/sync", "tokio/rt", "tokio-util", "tracing"]
discover = ["__common"]
filter = ["__common", "futures-util"]
@ -72,7 +72,6 @@ futures-core = { version = "0.3", optional = true }
futures-util = { version = "0.3", default-features = false, features = ["alloc"], optional = true }
hdrhistogram = { version = "7.0", optional = true, default-features = false }
indexmap = { version = "1.0.2", optional = true }
rand = { version = "0.8", features = ["small_rng"], optional = true }
slab = { version = "0.4", optional = true }
tokio = { version = "1.6", optional = true, features = ["sync"] }
tokio-stream = { version = "0.1.0", optional = true }
@ -88,9 +87,12 @@ tokio = { version = "1.6.2", features = ["macros", "sync", "test-util", "rt-mult
tokio-stream = "0.1"
tokio-test = "0.4"
tower-test = { version = "0.4", path = "../tower-test" }
tracing = { version = "0.1.2", default-features = false, features = ["std"] }
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] }
http = "0.2"
lazy_static = "1.4.0"
rand = { version = "0.8", features = ["small_rng"] }
quickcheck = "1"
[package.metadata.docs.rs]
all-features = true

View File

@ -2,10 +2,10 @@ use super::super::error;
use crate::discover::{Change, Discover};
use crate::load::Load;
use crate::ready_cache::{error::Failed, ReadyCache};
use crate::util::rng::{sample_inplace, HasherRng, Rng};
use futures_core::ready;
use futures_util::future::{self, TryFutureExt};
use pin_project_lite::pin_project;
use rand::{rngs::SmallRng, Rng, SeedableRng};
use std::hash::Hash;
use std::marker::PhantomData;
use std::{
@ -39,7 +39,7 @@ where
services: ReadyCache<D::Key, D::Service, Req>,
ready_index: Option<usize>,
rng: SmallRng,
rng: Box<dyn Rng + Send + Sync>,
_req: PhantomData<Req>,
}
@ -86,20 +86,20 @@ where
{
/// Constructs a load balancer that uses operating system entropy.
pub fn new(discover: D) -> Self {
Self::from_rng(discover, &mut rand::thread_rng()).expect("ThreadRNG must be valid")
Self::from_rng(discover, HasherRng::default())
}
/// Constructs a load balancer seeded with the provided random number generator.
pub fn from_rng<R: Rng>(discover: D, rng: R) -> Result<Self, rand::Error> {
let rng = SmallRng::from_rng(rng)?;
Ok(Self {
pub fn from_rng<R: Rng + Send + Sync + 'static>(discover: D, rng: R) -> Self {
let rng = Box::new(rng);
Self {
rng,
discover,
services: ReadyCache::default(),
ready_index: None,
_req: PhantomData,
})
}
}
/// Returns the number of endpoints currently tracked by the balancer.
@ -185,14 +185,14 @@ where
len => {
// Get two distinct random indexes (in a random order) and
// compare the loads of the service at each index.
let idxs = rand::seq::index::sample(&mut self.rng, len, 2);
let idxs = sample_inplace(&mut self.rng, len as u32, 2);
let aidx = idxs.index(0);
let bidx = idxs.index(1);
let aidx = idxs[0];
let bidx = idxs[1];
debug_assert_ne!(aidx, bidx, "random indices must be distinct");
let aload = self.ready_index_load(aidx);
let bload = self.ready_index_load(bidx);
let aload = self.ready_index_load(aidx as usize);
let bload = self.ready_index_load(bidx as usize);
let chosen = if aload <= bload { aidx } else { bidx };
trace!(
@ -203,7 +203,7 @@ where
chosen = if chosen == aidx { "a" } else { "b" },
"p2c",
);
Some(chosen)
Some(chosen as usize)
}
}
}

View File

@ -19,6 +19,8 @@ mod ready;
mod service_fn;
mod then;
pub mod rng;
pub use self::{
and_then::{AndThen, AndThenLayer},
boxed::{BoxLayer, BoxService, UnsyncBoxService},

198
tower/src/util/rng.rs Normal file
View File

@ -0,0 +1,198 @@
//! [PRNG] utilities for tower middleware.
//!
//! This module provides a generic [`Rng`] trait and a [`HasherRng`] that
//! implements the trait based on [`RandomState`] or any other [`BuildHasher`].
//!
//! These utilities replace tower's internal usage of `rand` with these smaller,
//! more lightweight methods. Most of the implementations are extracted from
//! their corresponding `rand` implementations.
//!
//! [PRNG]: https://en.wikipedia.org/wiki/Pseudorandom_number_generator
use std::{
collections::hash_map::RandomState,
hash::{BuildHasher, Hasher},
ops::Range,
};
/// A simple [PRNG] trait for use within tower middleware.
///
/// Implementors only have to provide [`Rng::next_u64`]; the other methods have
/// default implementations built on top of it.
///
/// [PRNG]: https://en.wikipedia.org/wiki/Pseudorandom_number_generator
pub trait Rng {
    /// Generate a random [`u64`].
    fn next_u64(&mut self) -> u64;

    /// Generate a random [`f64`] in the half-open interval `[0, 1)`.
    fn next_f64(&mut self) -> f64 {
        // Borrowed from:
        // https://github.com/rust-random/rand/blob/master/src/distributions/float.rs#L106
        let float_size = std::mem::size_of::<f64>() as u32 * 8;
        // 53 bits is the largest integer width an `f64` represents exactly.
        let precision = 52 + 1;
        let scale = 1.0 / ((1u64 << precision) as f64);

        // Keep only the top `precision` bits so that the scaled result stays
        // strictly below 1.0.
        let value = self.next_u64() >> (float_size - precision);
        scale * value as f64
    }

    /// Randomly pick a value within the half-open `range`
    /// (`range.start` inclusive, `range.end` exclusive).
    ///
    /// # Panics
    ///
    /// - If `range.start >= range.end` this will panic in debug mode.
    /// - Without debug assertions an empty range (`start == end`) still
    ///   panics, because the remainder is taken by zero.
    fn next_range(&mut self, range: Range<u64>) -> u64 {
        debug_assert!(
            range.start < range.end,
            "The range start must be smaller than the end"
        );
        let Range { start, end } = range;
        let span = end - start;

        // Fold the raw value into `0..span`, then shift it up to `start`.
        (self.next_u64() % span) + start
    }
}
impl<R: Rng + ?Sized> Rng for Box<R> {
fn next_u64(&mut self) -> u64 {
(**self).next_u64()
}
}
/// A [`Rng`] backed by a hasher.
///
/// Every call to [`Rng::next_u64`] feeds an internal counter to a hasher and
/// returns the resulting hash; the counter advances on each call.
///
/// # Default
///
/// The hasher type parameter defaults to [`RandomState`], i.e. the libstd way
/// of obtaining a random u64.
#[derive(Debug)]
pub struct HasherRng<H = RandomState> {
    /// Source of the hashers used to produce each value.
    hasher: H,
    /// Input for the next hash; starts at zero.
    counter: u64,
}

impl<H> HasherRng<H> {
    /// Create a new [`HasherRng`] that draws its values from `hasher`.
    pub fn with_hasher(hasher: H) -> Self {
        Self { hasher, counter: 0 }
    }
}

impl Default for HasherRng {
    fn default() -> Self {
        Self::with_hasher(RandomState::default())
    }
}

impl HasherRng {
    /// Create a new default [`HasherRng`].
    pub fn new() -> Self {
        Self::default()
    }
}
/// Produces values by hashing the internal counter with a freshly built hasher.
impl<H: BuildHasher> Rng for HasherRng<H> {
    fn next_u64(&mut self) -> u64 {
        // Take the current counter as this round's input and advance it,
        // wrapping around instead of overflowing.
        let input = self.counter;
        self.counter = input.wrapping_add(1);

        let mut state = self.hasher.build_hasher();
        state.write_u64(input);
        state.finish()
    }
}
/// An inplace sampler borrowed from the Rand implementation for use internally
/// for the balance middleware.
/// ref: https://github.com/rust-random/rand/blob/b73640705d6714509f8ceccc49e8df996fa19f51/src/seq/index.rs#L425
///
/// Docs from rand:
///
/// Randomly sample exactly `amount` indices from `0..length`, using an inplace
/// partial Fisher-Yates method.
///
/// This allocates the entire `length` of indices and randomizes only the first `amount`.
/// It then truncates to `amount` and returns.
///
/// This method is not appropriate for large `length` and potentially uses a lot
/// of memory; because of this we only implement for `u32` index (which improves
/// performance in all cases).
///
/// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time.
pub(crate) fn sample_inplace<R: Rng>(rng: &mut R, length: u32, amount: u32) -> Vec<u32> {
    debug_assert!(amount <= length);

    // Start from the identity permutation `0, 1, .., length - 1`.
    let mut pool: Vec<u32> = (0..length).collect();

    // Partial Fisher-Yates: fix one leading slot per iteration by swapping in
    // a randomly chosen element from the not-yet-fixed tail (slot included).
    for slot in 0..amount {
        let pick = rng.next_range(u64::from(slot)..u64::from(length));
        pool.swap(slot as usize, pick as usize);
    }

    // Only the shuffled prefix is part of the sample.
    pool.truncate(amount as usize);
    debug_assert_eq!(pool.len(), amount as usize);
    pool
}
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    quickcheck! {
        // `next_f64` must always land in `[0, 1)`, whatever the counter is.
        fn next_f64(counter: u64) -> TestResult {
            let mut rng = HasherRng::default();
            rng.counter = counter;
            let n = rng.next_f64();

            TestResult::from_bool(n < 1.0 && n >= 0.0)
        }

        // `next_range` must return a value inside the (non-empty) range.
        fn next_range(counter: u64, range: Range<u64>) -> TestResult {
            // Empty or inverted ranges violate the method's precondition.
            if range.start >= range.end {
                return TestResult::discard();
            }

            let mut rng = HasherRng::default();
            rng.counter = counter;

            let n = rng.next_range(range.clone());

            TestResult::from_bool(n >= range.start && n < range.end)
        }

        // `sample_inplace` must return `amount` distinct indices in `0..length`.
        fn sample_inplace(counter: u64, length: u32, amount: u32) -> TestResult {
            if amount > length || length > 256 || amount > 32 {
                return TestResult::discard();
            }

            let mut rng = HasherRng::default();
            rng.counter = counter;

            let indxs = super::sample_inplace(&mut rng, length, amount);

            if indxs.len() != amount as usize {
                return TestResult::failed();
            }

            for (pos, indx) in indxs.iter().enumerate() {
                // Valid indices are `0..length`, so `length` itself is out of
                // bounds; every index may also appear only once.
                if *indx >= length || indxs[..pos].contains(indx) {
                    return TestResult::failed();
                }
            }

            TestResult::passed()
        }
    }

    #[test]
    fn sample_inplace_boundaries() {
        let mut r = HasherRng::default();

        assert_eq!(super::sample_inplace(&mut r, 0, 0).len(), 0);
        assert_eq!(super::sample_inplace(&mut r, 1, 0).len(), 0);
        assert_eq!(super::sample_inplace(&mut r, 1, 1), vec![0]);
    }
}