Skip to content

Commit

Permalink
use framed codecs to avoid unbounded buffer (#33)
Browse files Browse the repository at this point in the history
* using stream

* fix tests

* 💄

* 💄 fix review comments

* clean up buffered data

* 💄 fix review comments

* Refactor Delimited to be its own struct

* Add very_long_frame test to ensure behavior

Co-authored-by: Eric Zhang <[email protected]>
  • Loading branch information
cedric05 and ekzhang authored Apr 22, 2022
1 parent e613629 commit 9cd43f4
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 110 deletions.
115 changes: 96 additions & 19 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ path = "src/main.rs"
anyhow = { version = "1.0.56", features = ["backtrace"] }
clap = { version = "3.1.8", features = ["derive", "env"] }
dashmap = "5.2.0"
futures-util = { version = "0.3.21", features = ["sink"] }
hex = "0.4.3"
hmac = "0.12.1"
serde = { version = "1.0.136", features = ["derive"] }
serde_json = "1.0.79"
sha2 = "0.10.2"
tokio = { version = "1.17.0", features = ["rt-multi-thread", "io-util", "macros", "net", "time"] }
tokio-util = { version = "0.7.1", features = ["codec"] }
tracing = "0.1.32"
tracing-subscriber = "0.3.10"
uuid = { version = "0.8.2", features = ["serde", "v4"] }
Expand Down
20 changes: 10 additions & 10 deletions src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
use anyhow::{bail, ensure, Result};
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
use tokio::io::{AsyncBufRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite};
use uuid::Uuid;

use crate::shared::{recv_json_timeout, send_json, ClientMessage, ServerMessage};
use crate::shared::{ClientMessage, Delimited, ServerMessage};

/// Wrapper around a MAC used for authenticating clients that have a secret.
pub struct Authenticator(Hmac<Sha256>);
Expand Down Expand Up @@ -48,13 +48,13 @@ impl Authenticator {
}

/// As the server, send a challenge to the client and validate their response.
pub async fn server_handshake(
pub async fn server_handshake<T: AsyncRead + AsyncWrite + Unpin>(
&self,
stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin),
stream: &mut Delimited<T>,
) -> Result<()> {
let challenge = Uuid::new_v4();
send_json(stream, ServerMessage::Challenge(challenge)).await?;
match recv_json_timeout(stream).await? {
stream.send(ServerMessage::Challenge(challenge)).await?;
match stream.recv_timeout().await? {
Some(ClientMessage::Authenticate(tag)) => {
ensure!(self.validate(&challenge, &tag), "invalid secret");
Ok(())
Expand All @@ -64,16 +64,16 @@ impl Authenticator {
}

/// As the client, answer a challenge to attempt to authenticate with the server.
pub async fn client_handshake(
pub async fn client_handshake<T: AsyncRead + AsyncWrite + Unpin>(
&self,
stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin),
stream: &mut Delimited<T>,
) -> Result<()> {
let challenge = match recv_json_timeout(stream).await? {
let challenge = match stream.recv_timeout().await? {
Some(ServerMessage::Challenge(challenge)) => challenge,
_ => bail!("expected authentication challenge, but no secret was required"),
};
let tag = self.answer(&challenge);
send_json(stream, ClientMessage::Authenticate(tag)).await?;
stream.send(ClientMessage::Authenticate(tag)).await?;
Ok(())
}
}
32 changes: 16 additions & 16 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,21 @@
use std::sync::Arc;

use anyhow::{bail, Context, Result};
use tokio::{io::BufReader, net::TcpStream, time::timeout};

use tokio::io::AsyncWriteExt;
use tokio::{net::TcpStream, time::timeout};
use tracing::{error, info, info_span, warn, Instrument};
use uuid::Uuid;

use crate::auth::Authenticator;
use crate::shared::{
proxy, recv_json, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
NETWORK_TIMEOUT,
proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT,
};

/// State structure for the client.
pub struct Client {
/// Control connection to the server.
conn: Option<BufReader<TcpStream>>,
conn: Option<Delimited<TcpStream>>,

/// Destination address of the server.
to: String,
Expand All @@ -43,15 +44,14 @@ impl Client {
port: u16,
secret: Option<&str>,
) -> Result<Self> {
let mut stream = BufReader::new(connect_with_timeout(to, CONTROL_PORT).await?);

let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?);
let auth = secret.map(Authenticator::new);
if let Some(auth) = &auth {
auth.client_handshake(&mut stream).await?;
}

send_json(&mut stream, ClientMessage::Hello(port)).await?;
let remote_port = match recv_json_timeout(&mut stream).await? {
stream.send(ClientMessage::Hello(port)).await?;
let remote_port = match stream.recv_timeout().await? {
Some(ServerMessage::Hello(remote_port)) => remote_port,
Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
Some(ServerMessage::Challenge(_)) => {
Expand Down Expand Up @@ -82,10 +82,8 @@ impl Client {
pub async fn listen(mut self) -> Result<()> {
let mut conn = self.conn.take().unwrap();
let this = Arc::new(self);
let mut buf = Vec::new();
loop {
let msg = recv_json(&mut conn, &mut buf).await?;
match msg {
match conn.recv().await? {
Some(ServerMessage::Hello(_)) => warn!("unexpected hello"),
Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"),
Some(ServerMessage::Heartbeat) => (),
Expand All @@ -110,14 +108,16 @@ impl Client {

async fn handle_connection(&self, id: Uuid) -> Result<()> {
let mut remote_conn =
BufReader::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
Delimited::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
if let Some(auth) = &self.auth {
auth.client_handshake(&mut remote_conn).await?;
}
send_json(&mut remote_conn, ClientMessage::Accept(id)).await?;

let local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
proxy(local_conn, remote_conn).await?;
remote_conn.send(ClientMessage::Accept(id)).await?;
let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
let parts = remote_conn.into_parts();
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
local_conn.write_all(&parts.read_buf).await?; // mostly of the cases, this will be empty
proxy(local_conn, parts.io).await?;
Ok(())
}
}
Expand Down
Loading

0 comments on commit 9cd43f4

Please sign in to comment.