Skip to content

Commit

Permalink
Add read/write timeouts and avoid stuck peers
Browse files Browse the repository at this point in the history
  • Loading branch information
ikatson committed Dec 4, 2022
1 parent ae847ce commit 9e8f235
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 55 deletions.
89 changes: 37 additions & 52 deletions crates/librqbit/src/peer_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub enum WriterRequest {
/// Optional tuning knobs for a single peer connection.
///
/// Every field is an `Option`; `None` means the connection code picks its own
/// fallback value.
#[derive(Default, Copy, Clone)]
pub struct PeerConnectionOptions {
/// Upper bound on establishing the TCP connection to the peer.
/// The connection code falls back to 10 seconds when this is `None`.
pub connect_timeout: Option<Duration>,
/// Upper bound applied to each individual read from / write to the peer.
/// The connection code falls back to 10 seconds when this is `None`.
pub read_write_timeout: Option<Duration>,
/// NOTE(review): not exercised in the visible code; presumably the interval
/// between keep-alive messages sent to the peer — confirm against the writer loop.
pub keep_alive_interval: Option<Duration>,
}

Expand All @@ -48,36 +49,21 @@ pub struct PeerConnection<H> {
spawner: BlockingSpawner,
}

// async fn read_one<'a, R: AsyncReadExt + Unpin>(
// mut reader: R,
// read_buf: &'a mut Vec<u8>,
// read_so_far: &mut usize,
// ) -> anyhow::Result<(MessageBorrowed<'a>, usize)> {
// loop {
// match MessageBorrowed::deserialize(&read_buf[..*read_so_far]) {
// Ok((msg, size)) => return Ok((msg, size)),
// Err(MessageDeserializeError::NotEnoughData(d, _)) => {
// if read_buf.len() < *read_so_far + d {
// read_buf.reserve(d);
// read_buf.resize(read_buf.capacity(), 0);
// }

// let size = reader
// .read(&mut read_buf[*read_so_far..])
// .await
// .context("error reading from peer")?;
// if size == 0 {
// anyhow::bail!("disconnected while reading, read so far: {}", *read_so_far)
// }
// *read_so_far += size;
// }
// Err(e) => return Err(e.into()),
// }
// }
// }
async fn with_timeout<T, E>(
timeout_value: Duration,
fut: impl std::future::Future<Output = Result<T, E>>,
) -> anyhow::Result<T>
where
E: Into<anyhow::Error>,
{
timeout(timeout_value, fut)
.await
.with_context(|| format!("timeout at {timeout_value:?}"))?
.map_err(|e| e.into())
}

macro_rules! read_one {
($conn:ident, $read_buf:ident, $read_so_far:ident) => {{
($conn:ident, $read_buf:ident, $read_so_far:ident, $rwtimeout:ident) => {{
let (extended, size) = loop {
match MessageBorrowed::deserialize(&$read_buf[..$read_so_far]) {
Ok((msg, size)) => break (msg, size),
Expand All @@ -87,8 +73,7 @@ macro_rules! read_one {
$read_buf.resize($read_buf.capacity(), 0);
}

let size = $conn
.read(&mut $read_buf[$read_so_far..])
let size = with_timeout($rwtimeout, $conn.read(&mut $read_buf[$read_so_far..]))
.await
.context("error reading from peer")?;
if size == 0 {
Expand Down Expand Up @@ -130,29 +115,31 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
let mut conn = match timeout(
self.options
.connect_timeout
.unwrap_or_else(|| Duration::from_secs(10)),
tokio::net::TcpStream::connect(self.addr),
)
.await
{
Ok(conn) => conn.context("error connecting")?,
Err(_) => anyhow::bail!("timeout connecting to {}", self.addr),
};

let rwtimeout = self
.options
.read_write_timeout
.unwrap_or_else(|| Duration::from_secs(10));

let connect_timeout = self
.options
.connect_timeout
.unwrap_or_else(|| Duration::from_secs(10));

let mut conn = with_timeout(connect_timeout, tokio::net::TcpStream::connect(self.addr))
.await
.context("error connecting")?;

let mut write_buf = Vec::<u8>::with_capacity(PIECE_MESSAGE_DEFAULT_LEN);
let handshake = Handshake::new(self.info_hash, self.peer_id);
handshake.serialize(&mut write_buf);
conn.write_all(&write_buf)
with_timeout(rwtimeout, conn.write_all(&write_buf))
.await
.context("error writing handshake")?;
write_buf.clear();

let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
let mut read_so_far = conn
.read(&mut read_buf)
let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf))
.await
.context("error reading handshake")?;
if read_so_far == 0 {
Expand Down Expand Up @@ -188,12 +175,12 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
&my_extended
);
my_extended.serialize(&mut write_buf, None).unwrap();
conn.write_all(&write_buf)
with_timeout(rwtimeout, conn.write_all(&write_buf))
.await
.context("error writing extended handshake")?;
write_buf.clear();

let (extended, size) = read_one!(conn, read_buf, read_so_far);
let (extended, size) = read_one!(conn, read_buf, read_so_far, rwtimeout);
match extended {
Message::Extended(ExtendedMessage::Handshake(h)) => {
trace!("received from {}: {:?}", self.addr, &h);
Expand Down Expand Up @@ -222,8 +209,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
.handler
.serialize_bitfield_message_to_buf(&mut write_buf)
{
write_half
.write_all(&write_buf[..len])
with_timeout(rwtimeout, write_half.write_all(&write_buf[..len]))
.await
.context("error writing bitfield to peer")?;
debug!("sent bitfield to {}", self.addr);
Expand Down Expand Up @@ -256,7 +242,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
self.handler
.read_chunk(chunk, &mut write_buf[preamble_len..])
})
.with_context(|| format!("error reading chunk {:?}", chunk))?;
.with_context(|| format!("error reading chunk {chunk:?}"))?;

uploaded_add = Some(chunk.size);
full_len
Expand All @@ -265,8 +251,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {

debug!("sending to {}: {:?}, length={}", self.addr, &req, len);

write_half
.write_all(&write_buf[..len])
with_timeout(rwtimeout, write_half.write_all(&write_buf[..len]))
.await
.context("error writing the message to peer")?;
write_buf.clear();
Expand All @@ -283,7 +268,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {

let reader = async move {
loop {
let (message, size) = read_one!(read_half, read_buf, read_so_far);
let (message, size) = read_one!(read_half, read_buf, read_so_far, rwtimeout);
trace!("received from {}: {:?}", self.addr, &message);

self.handler
Expand Down
4 changes: 4 additions & 0 deletions crates/librqbit/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,10 @@ impl Session {
builder.peer_connect_timeout(t);
}

if let Some(t) = opts.peer_opts.unwrap_or(self.peer_opts).read_write_timeout {
builder.peer_read_write_timeout(t);
}

let handle = match builder
.start_manager()
.context("error starting torrent manager")
Expand Down
7 changes: 7 additions & 0 deletions crates/librqbit/src/torrent_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::{
struct TorrentManagerOptions {
force_tracker_interval: Option<Duration>,
peer_connect_timeout: Option<Duration>,
peer_read_write_timeout: Option<Duration>,
only_files: Option<Vec<usize>>,
peer_id: Option<Id20>,
overwrite: bool,
Expand Down Expand Up @@ -90,6 +91,11 @@ impl TorrentManagerBuilder {
self
}

/// Records `timeout` as the builder's `peer_read_write_timeout` option.
///
/// Returns `&mut Self` so further builder calls can be chained.
pub fn peer_read_write_timeout(&mut self, timeout: Duration) -> &mut Self {
    let opts = &mut self.options;
    opts.peer_read_write_timeout = Some(timeout);
    self
}

pub fn start_manager(self) -> anyhow::Result<TorrentManagerHandle> {
TorrentManager::start(
self.info,
Expand Down Expand Up @@ -256,6 +262,7 @@ impl TorrentManager {
#[allow(clippy::needless_update)]
let state_options = TorrentStateOptions {
peer_connect_timeout: options.peer_connect_timeout,
peer_read_write_timeout: options.peer_read_write_timeout,
..Default::default()
};

Expand Down
9 changes: 6 additions & 3 deletions crates/librqbit/src/torrent_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ impl StatsSnapshot {
/// Options carried by the torrent state and handed to each peer connection it creates.
#[derive(Default)]
pub struct TorrentStateOptions {
/// Forwarded as `PeerConnectionOptions::connect_timeout` for every new peer connection.
pub peer_connect_timeout: Option<Duration>,
/// Forwarded as `PeerConnectionOptions::read_write_timeout` for every new peer connection.
pub peer_read_write_timeout: Option<Duration>,
}

pub struct TorrentState {
Expand Down Expand Up @@ -286,6 +287,7 @@ impl TorrentState {
loop {
let (addr, out_rx) = peer_queue_rx.recv().await.unwrap();

let permit = state.peer_semaphore.acquire().await.unwrap();
match state.locked.write().peers.states.get_mut(&addr) {
Some(s @ PeerState::Queued) => *s = PeerState::Connecting,
s => {
Expand All @@ -294,15 +296,14 @@ impl TorrentState {
}
};

state.peer_semaphore.acquire().await.unwrap().forget();

let handler = PeerHandler {
addr,
state: state.clone(),
spawner,
};
let options = PeerConnectionOptions {
connect_timeout: state.options.peer_connect_timeout,
read_write_timeout: state.options.peer_read_write_timeout,
..Default::default()
};
let peer_connection = PeerConnection::new(
Expand All @@ -313,7 +314,9 @@ impl TorrentState {
Some(options),
spawner,
);
spawn(format!("manage_peer({})", addr), async move {

permit.forget();
spawn(format!("manage_peer({addr})"), async move {
if let Err(e) = peer_connection.manage_peer(out_rx).await {
debug!("error managing peer {}: {:#}", addr, e)
};
Expand Down
5 changes: 5 additions & 0 deletions crates/rqbit/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ struct Opts {
#[clap(long = "peer-connect-timeout")]
peer_connect_timeout: Option<ParsedDuration>,

/// The peer read/write timeout, e.g. 1s, 1.5s, 100ms etc.
#[clap(long = "peer-read-write-timeout")]
peer_read_write_timeout: Option<ParsedDuration>,

/// How many threads to spawn for the executor.
#[clap(short = 't', long)]
worker_threads: Option<usize>,
Expand Down Expand Up @@ -200,6 +204,7 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
peer_id: None,
peer_opts: Some(PeerConnectionOptions {
connect_timeout: opts.peer_connect_timeout.map(|d| d.0),
read_write_timeout: opts.peer_read_write_timeout.map(|d| d.0),
..Default::default()
}),
};
Expand Down

0 comments on commit 9e8f235

Please sign in to comment.