Skip to content

Commit

Permalink
reuses a thread-local weighted-shuffle for get_retransmit_addrs
Browse files Browse the repository at this point in the history
  • Loading branch information
behzadnouri committed Feb 5, 2025
1 parent b538fab commit 5c403a1
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 21 deletions.
3 changes: 2 additions & 1 deletion gossip/src/cluster_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1794,7 +1794,8 @@ impl ClusterInfo {
return packet_batch;
}
let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new("handle-pull-requests", scores).shuffle(&mut rng);
let mut weighted_shuffle = WeightedShuffle::new("handle-pull-requests", scores);
let shuffle = weighted_shuffle.shuffle(&mut rng);
let mut total_bytes = 0;
let mut sent = 0;
for (addr, response) in shuffle.map(|i| &responses[i]) {
Expand Down
4 changes: 2 additions & 2 deletions gossip/src/push_active_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ impl PushActiveSetEntry {
) {
debug_assert_eq!(nodes.len(), weights.len());
debug_assert!(weights.iter().all(|&weight| weight != 0u64));
let shuffle = WeightedShuffle::<u64>::new("rotate-active-set", weights).shuffle(rng);
for node in shuffle.map(|k| &nodes[k]) {
let mut weighted_shuffle = WeightedShuffle::<u64>::new("rotate-active-set", weights);
for node in weighted_shuffle.shuffle(rng).map(|k| &nodes[k]) {
// We intend to discard the oldest/first entry in the index-map.
if self.0.len() > size {
break;
Expand Down
2 changes: 1 addition & 1 deletion gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ impl<'a, T: 'a> WeightedShuffle<T>
where
T: Copy + ConstZero + PartialOrd + SampleUniform + SubAssign,
{
pub fn shuffle<R: Rng>(mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
pub fn shuffle<R: Rng>(&'a mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
std::iter::from_fn(move || {
if self.weight > Self::ZERO {
let sample =
Expand Down
43 changes: 26 additions & 17 deletions turbine/src/cluster_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use {
solana_streamer::socket::SocketAddrSpace,
std::{
any::TypeId,
cell::RefCell,
cmp::Ordering,
collections::{HashMap, HashSet},
iter::repeat_with,
Expand All @@ -37,6 +38,12 @@ use {
thiserror::Error,
};

thread_local! {
static THREAD_LOCAL_WEIGHTED_SHUFFLE: RefCell<WeightedShuffle<u64>> = RefCell::new(
WeightedShuffle::<u64>::new::<[u64; 0]>("get_retransmit_addrs", []),
);
}

const DATA_PLANE_FANOUT: usize = 200;
pub(crate) const MAX_NUM_TURBINE_HOPS: usize = 4;

Expand Down Expand Up @@ -220,30 +227,32 @@ impl ClusterNodes<RetransmitStage> {
fanout: usize,
socket_addr_space: &SocketAddrSpace,
) -> Result<(/*root_distance:*/ usize, Vec<SocketAddr>), Error> {
let mut weighted_shuffle = self.weighted_shuffle.clone();
// Exclude slot leader from list of nodes.
if slot_leader == &self.pubkey {
return Err(Error::Loopback {
leader: *slot_leader,
shred: *shred,
});
}
if let Some(index) = self.index.get(slot_leader) {
weighted_shuffle.remove_index(*index);
}
let mut rng = get_seeded_rng(slot_leader, shred);
let (index, peers) = get_retransmit_peers(
fanout,
|k| self.nodes[k].pubkey() == &self.pubkey,
weighted_shuffle.shuffle(&mut rng),
);
let protocol = get_broadcast_protocol(shred);
let peers = peers
.filter_map(|k| self.nodes[k].contact_info()?.tvu(protocol))
.filter(|addr| socket_addr_space.check(addr))
.collect();
let root_distance = get_root_distance(index, fanout);
Ok((root_distance, peers))
THREAD_LOCAL_WEIGHTED_SHUFFLE.with_borrow_mut(|weighted_shuffle| {
weighted_shuffle.clone_from(&self.weighted_shuffle);
if let Some(index) = self.index.get(slot_leader) {
weighted_shuffle.remove_index(*index);
}
let mut rng = get_seeded_rng(slot_leader, shred);
let (index, peers) = get_retransmit_peers(
fanout,
|k| self.nodes[k].pubkey() == &self.pubkey,
weighted_shuffle.shuffle(&mut rng),
);
let protocol = get_broadcast_protocol(shred);
let peers = peers
.filter_map(|k| self.nodes[k].contact_info()?.tvu(protocol))
.filter(|addr| socket_addr_space.check(addr))
.collect();
let root_distance = get_root_distance(index, fanout);
Ok((root_distance, peers))
})
}

// Returns the parent node in the turbine broadcast tree.
Expand Down

0 comments on commit 5c403a1

Please sign in to comment.