Skip to content

Commit

Permalink
allows initializing WeightedShuffle with an ExactSizeIterator
Browse files Browse the repository at this point in the history
Using an ExactSizeIterator instead of a slice may allow to bypass the
allocation to collect an iterator into a vector.
  • Loading branch information
behzadnouri committed Feb 5, 2025
1 parent 6972d47 commit 1a9e088
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 22 deletions.
2 changes: 1 addition & 1 deletion core/src/repair/serve_repair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ impl ServeRepair {
.compute_weights_exclude_nonfrozen(slot, &repair_peers)
.into_iter()
.unzip();
let peers = WeightedShuffle::new("repair_request_ancestor_hashes", &weights)
let peers = WeightedShuffle::new("repair_request_ancestor_hashes", weights)
.shuffle(&mut rand::thread_rng())
.map(|i| index[i])
.filter_map(|i| {
Expand Down
4 changes: 2 additions & 2 deletions gossip/benches/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fn bench_weighted_shuffle_new(c: &mut Criterion) {
c.bench_function("bench_weighted_shuffle_new", |b| {
b.iter(|| {
let weights = make_weights(&mut rng);
black_box(WeightedShuffle::new("", &weights));
black_box(WeightedShuffle::<u64>::new("", &weights));
})
});
}
Expand All @@ -26,7 +26,7 @@ fn bench_weighted_shuffle_shuffle(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut rng = rand::thread_rng();
let weights = make_weights(&mut rng);
let weighted_shuffle = WeightedShuffle::new("", &weights);
let weighted_shuffle = WeightedShuffle::new("", weights);
c.bench_function("bench_weighted_shuffle_shuffle", |b| {
b.iter(|| {
rng.fill(&mut seed[..]);
Expand Down
2 changes: 1 addition & 1 deletion gossip/src/cluster_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1794,7 +1794,7 @@ impl ClusterInfo {
return packet_batch;
}
let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new("handle-pull-requests", &scores).shuffle(&mut rng);
let shuffle = WeightedShuffle::new("handle-pull-requests", scores).shuffle(&mut rng);
let mut total_bytes = 0;
let mut sent = 0;
for (addr, response) in shuffle.map(|i| &responses[i]) {
Expand Down
2 changes: 1 addition & 1 deletion gossip/src/push_active_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl PushActiveSetEntry {
) {
debug_assert_eq!(nodes.len(), weights.len());
debug_assert!(weights.iter().all(|&weight| weight != 0u64));
let shuffle = WeightedShuffle::new("rotate-active-set", weights).shuffle(rng);
let shuffle = WeightedShuffle::<u64>::new("rotate-active-set", weights).shuffle(rng);
for node in shuffle.map(|k| &nodes[k]) {
// We intend to discard the oldest/first entry in the index-map.
if self.0.len() > size {
Expand Down
39 changes: 24 additions & 15 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use {
distributions::uniform::{SampleUniform, UniformSampler},
Rng,
},
std::ops::{AddAssign, SubAssign},
std::{
borrow::Borrow,
ops::{AddAssign, SubAssign},
},
};

// Each internal tree node has FANOUT many child nodes with indices:
Expand Down Expand Up @@ -52,15 +55,21 @@ where
{
/// If weights are negative or overflow the total sum
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
pub fn new<I>(name: &'static str, weights: I) -> Self
where
I: IntoIterator<Item: Borrow<T>>,
<I as IntoIterator>::IntoIter: ExactSizeIterator,
{
let weights = weights.into_iter();
let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len());
debug_assert!(size <= num_nodes);
let mut tree = vec![[Self::ZERO; FANOUT]; size];
let mut sum = Self::ZERO;
let mut zeros = Vec::default();
let mut num_negative: usize = 0;
let mut num_overflow: usize = 0;
for (k, &weight) in weights.iter().enumerate() {
for (k, weight) in weights.enumerate() {
let weight = *weight.borrow();
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
if !(weight >= Self::ZERO) {
Expand Down Expand Up @@ -348,7 +357,7 @@ mod tests {
fn test_weighted_shuffle_empty_weights() {
let weights = Vec::<u64>::new();
let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::new("", weights);
assert!(shuffle.clone().shuffle(&mut rng).next().is_none());
assert!(shuffle.first(&mut rng).is_none());
}
Expand All @@ -359,7 +368,7 @@ mod tests {
let weights = vec![0u64; 5];
let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::new("", weights);
assert_eq!(
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
[1, 4, 2, 3, 0]
Expand All @@ -377,14 +386,14 @@ mod tests {
let weights = [1, 0, 1000, 0, 0, 10, 100, 0];
let mut counts = [0; 8];
for _ in 0..100000 {
let mut shuffle = WeightedShuffle::new("", &weights).shuffle(&mut rng);
let mut shuffle = WeightedShuffle::<i32>::new("", weights).shuffle(&mut rng);
counts[shuffle.next().unwrap()] += 1;
let _ = shuffle.count(); // consume the rest.
}
assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]);
let mut counts = [0; 8];
for _ in 0..100000 {
let mut shuffle = WeightedShuffle::new("", &weights);
let mut shuffle = WeightedShuffle::<i32>::new("", weights);
shuffle.remove_index(5);
shuffle.remove_index(3);
shuffle.remove_index(1);
Expand All @@ -400,15 +409,15 @@ mod tests {
const SEED: [u8; 32] = [48u8; 32];
let weights = [19i64, 23, 7, 0, 0, 23, 3, 0, 5, 0, 19, 29];
let mut rng = ChaChaRng::from_seed(SEED);
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::new("", weights);
assert_eq!(
shuffle.shuffle(&mut rng).collect::<Vec<_>>(),
[8, 1, 5, 10, 11, 0, 2, 6, 9, 4, 3, 7]
);
// Negative weights and overflowing ones are treated as zero.
let weights = [19, 23, 7, -57, i64::MAX, 23, 3, i64::MAX, 5, -79, 19, 29];
let mut rng = ChaChaRng::from_seed(SEED);
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::new("", weights);
assert_eq!(
shuffle.shuffle(&mut rng).collect::<Vec<_>>(),
[8, 1, 5, 10, 11, 0, 2, 6, 9, 4, 3, 7]
Expand All @@ -422,7 +431,7 @@ mod tests {
];
let seed = [48u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let mut shuffle = WeightedShuffle::new("", &weights);
let mut shuffle = WeightedShuffle::<i32>::new("", weights);
assert_eq!(
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
[2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
Expand All @@ -442,7 +451,7 @@ mod tests {
assert_eq!(shuffle.first(&mut rng), Some(4));
let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let mut shuffle = WeightedShuffle::new("", &weights);
let mut shuffle = WeightedShuffle::new("", weights);
assert_eq!(
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
[19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
Expand Down Expand Up @@ -503,7 +512,7 @@ mod tests {
weights.iter().fold(0u64, |a, &b| a.checked_add(b).unwrap()),
weights.iter().sum::<u64>()
);
let mut shuffle = WeightedShuffle::new("", &weights);
let mut shuffle = WeightedShuffle::<u64>::new("", &weights);
let shuffle1 = shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>();
// Assert that all indices appear in the shuffle.
assert_eq!(shuffle1.len(), num_weights);
Expand Down Expand Up @@ -544,13 +553,13 @@ mod tests {
let mut seed = [0u8; 32];
rng.fill(&mut seed[..]);
let mut rng = ChaChaRng::from_seed(seed);
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::<u64>::new("", &weights);
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
let mut rng = ChaChaRng::from_seed(seed);
let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
assert_eq!(shuffle, shuffle_slow);
let mut rng = ChaChaRng::from_seed(seed);
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::<u64>::new("", &weights);
assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0]));
}
}
Expand All @@ -563,7 +572,7 @@ mod tests {
let seed = rng.gen::<[u8; 32]>();
let mut rng = ChaChaRng::from_seed(seed);
let shuffle_slow = weighted_shuffle_slow(&mut rng.clone(), weights.clone());
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::new("", weights);
if size > 0 {
assert_eq!(shuffle.first(&mut rng.clone()), Some(shuffle_slow[0]));
}
Expand Down
4 changes: 2 additions & 2 deletions turbine/src/cluster_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ pub fn new_cluster_nodes<T: 'static>(
.map(|(ix, node)| (*node.pubkey(), ix))
.collect();
let broadcast = TypeId::of::<T>() == TypeId::of::<BroadcastStage>();
let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect();
let mut weighted_shuffle = WeightedShuffle::new("cluster-nodes", &stakes);
let stakes = nodes.iter().map(|node| node.stake);
let mut weighted_shuffle = WeightedShuffle::new("cluster-nodes", stakes);
if broadcast {
weighted_shuffle.remove_index(index[&self_pubkey]);
}
Expand Down

0 comments on commit 1a9e088

Please sign in to comment.