Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allows initializing WeightedShuffle with an ExactSizeIterator #4796

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/repair/serve_repair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ impl ServeRepair {
.compute_weights_exclude_nonfrozen(slot, &repair_peers)
.into_iter()
.unzip();
let peers = WeightedShuffle::new("repair_request_ancestor_hashes", &weights)
let peers = WeightedShuffle::new("repair_request_ancestor_hashes", weights)
.shuffle(&mut rand::thread_rng())
.map(|i| index[i])
.filter_map(|i| {
Expand Down
4 changes: 2 additions & 2 deletions gossip/benches/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fn bench_weighted_shuffle_new(c: &mut Criterion) {
c.bench_function("bench_weighted_shuffle_new", |b| {
b.iter(|| {
let weights = make_weights(&mut rng);
black_box(WeightedShuffle::new("", &weights));
black_box(WeightedShuffle::<u64>::new("", &weights));
})
});
}
Expand All @@ -26,7 +26,7 @@ fn bench_weighted_shuffle_shuffle(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut rng = rand::thread_rng();
let weights = make_weights(&mut rng);
let weighted_shuffle = WeightedShuffle::new("", &weights);
let weighted_shuffle = WeightedShuffle::new("", weights);
c.bench_function("bench_weighted_shuffle_shuffle", |b| {
b.iter(|| {
rng.fill(&mut seed[..]);
Expand Down
2 changes: 1 addition & 1 deletion gossip/src/cluster_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1794,7 +1794,7 @@ impl ClusterInfo {
return packet_batch;
}
let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new("handle-pull-requests", &scores).shuffle(&mut rng);
let shuffle = WeightedShuffle::new("handle-pull-requests", scores).shuffle(&mut rng);
let mut total_bytes = 0;
let mut sent = 0;
for (addr, response) in shuffle.map(|i| &responses[i]) {
Expand Down
2 changes: 1 addition & 1 deletion gossip/src/push_active_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl PushActiveSetEntry {
) {
debug_assert_eq!(nodes.len(), weights.len());
debug_assert!(weights.iter().all(|&weight| weight != 0u64));
let shuffle = WeightedShuffle::new("rotate-active-set", weights).shuffle(rng);
let shuffle = WeightedShuffle::<u64>::new("rotate-active-set", weights).shuffle(rng);
for node in shuffle.map(|k| &nodes[k]) {
// We intend to discard the oldest/first entry in the index-map.
if self.0.len() > size {
Expand Down
39 changes: 24 additions & 15 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use {
distributions::uniform::{SampleUniform, UniformSampler},
Rng,
},
std::ops::{AddAssign, SubAssign},
std::{
borrow::Borrow,
ops::{AddAssign, SubAssign},
},
};

// Each internal tree node has FANOUT many child nodes with indices:
Expand Down Expand Up @@ -52,15 +55,21 @@ where
{
/// If weights are negative or overflow the total sum
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
pub fn new<I>(name: &'static str, weights: I) -> Self
where
I: IntoIterator<Item: Borrow<T>>,
<I as IntoIterator>::IntoIter: ExactSizeIterator,
{
let weights = weights.into_iter();
let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len());
debug_assert!(size <= num_nodes);
let mut tree = vec![[Self::ZERO; FANOUT]; size];
let mut sum = Self::ZERO;
let mut zeros = Vec::default();
let mut num_negative: usize = 0;
let mut num_overflow: usize = 0;
for (k, &weight) in weights.iter().enumerate() {
for (k, weight) in weights.enumerate() {
let weight = *weight.borrow();
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
if !(weight >= Self::ZERO) {
Expand Down Expand Up @@ -348,7 +357,7 @@ mod tests {
fn test_weighted_shuffle_empty_weights() {
let weights = Vec::<u64>::new();
let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::new("", weights);
assert!(shuffle.clone().shuffle(&mut rng).next().is_none());
assert!(shuffle.first(&mut rng).is_none());
}
Expand All @@ -359,7 +368,7 @@ mod tests {
let weights = vec![0u64; 5];
let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::new("", weights);
assert_eq!(
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
[1, 4, 2, 3, 0]
Expand All @@ -377,14 +386,14 @@ mod tests {
let weights = [1, 0, 1000, 0, 0, 10, 100, 0];
let mut counts = [0; 8];
for _ in 0..100000 {
let mut shuffle = WeightedShuffle::new("", &weights).shuffle(&mut rng);
let mut shuffle = WeightedShuffle::new("", weights).shuffle(&mut rng);
counts[shuffle.next().unwrap()] += 1;
let _ = shuffle.count(); // consume the rest.
}
assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]);
let mut counts = [0; 8];
for _ in 0..100000 {
let mut shuffle = WeightedShuffle::new("", &weights);
let mut shuffle = WeightedShuffle::new("", weights);
shuffle.remove_index(5);
shuffle.remove_index(3);
shuffle.remove_index(1);
Expand All @@ -400,15 +409,15 @@ mod tests {
const SEED: [u8; 32] = [48u8; 32];
let weights = [19i64, 23, 7, 0, 0, 23, 3, 0, 5, 0, 19, 29];
let mut rng = ChaChaRng::from_seed(SEED);
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::new("", weights);
assert_eq!(
shuffle.shuffle(&mut rng).collect::<Vec<_>>(),
[8, 1, 5, 10, 11, 0, 2, 6, 9, 4, 3, 7]
);
// Negative weights and overflowing ones are treated as zero.
let weights = [19, 23, 7, -57, i64::MAX, 23, 3, i64::MAX, 5, -79, 19, 29];
let mut rng = ChaChaRng::from_seed(SEED);
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::new("", weights);
assert_eq!(
shuffle.shuffle(&mut rng).collect::<Vec<_>>(),
[8, 1, 5, 10, 11, 0, 2, 6, 9, 4, 3, 7]
Expand All @@ -422,7 +431,7 @@ mod tests {
];
let seed = [48u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let mut shuffle = WeightedShuffle::new("", &weights);
let mut shuffle = WeightedShuffle::new("", weights);
assert_eq!(
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
[2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
Expand All @@ -442,7 +451,7 @@ mod tests {
assert_eq!(shuffle.first(&mut rng), Some(4));
let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let mut shuffle = WeightedShuffle::new("", &weights);
let mut shuffle = WeightedShuffle::new("", weights);
assert_eq!(
shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>(),
[19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
Expand Down Expand Up @@ -503,7 +512,7 @@ mod tests {
weights.iter().fold(0u64, |a, &b| a.checked_add(b).unwrap()),
weights.iter().sum::<u64>()
);
let mut shuffle = WeightedShuffle::new("", &weights);
let mut shuffle = WeightedShuffle::<u64>::new("", &weights);
let shuffle1 = shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>();
// Assert that all indices appear in the shuffle.
assert_eq!(shuffle1.len(), num_weights);
Expand Down Expand Up @@ -544,13 +553,13 @@ mod tests {
let mut seed = [0u8; 32];
rng.fill(&mut seed[..]);
let mut rng = ChaChaRng::from_seed(seed);
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::<u64>::new("", &weights);
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
let mut rng = ChaChaRng::from_seed(seed);
let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
assert_eq!(shuffle, shuffle_slow);
let mut rng = ChaChaRng::from_seed(seed);
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::<u64>::new("", &weights);
assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0]));
}
}
Expand All @@ -563,7 +572,7 @@ mod tests {
let seed = rng.gen::<[u8; 32]>();
let mut rng = ChaChaRng::from_seed(seed);
let shuffle_slow = weighted_shuffle_slow(&mut rng.clone(), weights.clone());
let shuffle = WeightedShuffle::new("", &weights);
let shuffle = WeightedShuffle::new("", weights);
if size > 0 {
assert_eq!(shuffle.first(&mut rng.clone()), Some(shuffle_slow[0]));
}
Expand Down
4 changes: 2 additions & 2 deletions turbine/src/cluster_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ pub fn new_cluster_nodes<T: 'static>(
.map(|(ix, node)| (*node.pubkey(), ix))
.collect();
let broadcast = TypeId::of::<T>() == TypeId::of::<BroadcastStage>();
let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect();
let mut weighted_shuffle = WeightedShuffle::new("cluster-nodes", &stakes);
let stakes = nodes.iter().map(|node| node.stake);
let mut weighted_shuffle = WeightedShuffle::new("cluster-nodes", stakes);
if broadcast {
weighted_shuffle.remove_index(index[&self_pubkey]);
}
Expand Down
Loading