Skip to content

Commit

Permalink
allow different implementations of connection for DTLS
Browse files Browse the repository at this point in the history
remove need for client transport to need an address
rename transport to ClientTransport since only client uses it
fixed leak cycle issue on client where connection is never dropped
  • Loading branch information
osobiehl committed Apr 7, 2024
1 parent 9bd773a commit d80c56c
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 83 deletions.
6 changes: 3 additions & 3 deletions examples/echo_with_dtls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
/// a look at the test in dtls.rs
extern crate coap;
use coap::client::CoAPClient;
use coap::dtls::DtlsConfig;
use coap::dtls::UdpDtlsConfig;
use coap::request::RequestBuilder;
use coap::Server;
use coap_lite::{CoapRequest, RequestType as Method};
Expand Down Expand Up @@ -60,7 +60,7 @@ async fn main() {
.await
.unwrap();

let dtls_config = DtlsConfig {
let dtls_config = UdpDtlsConfig {
config,
dest_addr: ("127.0.0.1", server_port)
.to_socket_addrs()
Expand All @@ -69,7 +69,7 @@ async fn main() {
.unwrap(),
};

let client = CoAPClient::from_dtls_config(dtls_config)
let client = CoAPClient::from_udp_dtls_config(dtls_config)
.await
.expect("could not create client");
let domain = format!("127.0.0.1:{}", server_port);
Expand Down
90 changes: 54 additions & 36 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(feature = "dtls")]
use crate::dtls::{DtlsConfig, DtlsConnection};
use crate::dtls::{DtlsConnection, UdpDtlsConfig};
use crate::request::RequestBuilder;
use alloc::string::String;
use alloc::vec::Vec;
Expand All @@ -18,7 +18,7 @@ use regex::Regex;
use std::{
collections::BTreeMap,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::atomic::AtomicU16,
sync::{atomic::AtomicU16, Weak},
};
use std::{
io::{Error, ErrorKind, Result as IoResult},
Expand All @@ -44,24 +44,24 @@ pub enum ObserveMessage {
use async_trait::async_trait;

#[async_trait]
/// A basic interface for a transport on both the client and transport
/// A basic interface for a transport on the client
/// representing a one-to-one connection between a client and server
/// timeouts and retries do not need to be implemented by the transport
/// if confirmable messages are sent
pub trait Transport: Send + Sync {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)>;
pub trait ClientTransport: Send + Sync {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize>;
async fn send(&self, buf: &[u8]) -> std::io::Result<usize>;
}

trait TransportExt {
async fn receive_packet(&self) -> IoResult<Option<(Packet, SocketAddr)>>;
async fn receive_packet(&self) -> IoResult<Option<Packet>>;
}

impl<T: Transport> TransportExt for T {
async fn receive_packet(&self) -> IoResult<Option<(Packet, SocketAddr)>> {
impl<T: ClientTransport> TransportExt for T {
async fn receive_packet(&self) -> IoResult<Option<Packet>> {
let mut buf = [0; 1500];
let (nread, src) = self.recv(&mut buf).await?;
let parse_opt = Packet::from_bytes(&buf[..nread]).ok().map(|p| (p, src));
let nread = self.recv(&mut buf).await?;
let parse_opt = Packet::from_bytes(&buf[..nread]).ok();
return Ok(parse_opt);
}
}
Expand Down Expand Up @@ -133,22 +133,34 @@ impl TransportSynchronizer {
}
}

async fn receive_loop<T: Transport + 'static>(
transport: Arc<T>,
async fn receive_loop<T: ClientTransport + 'static>(
transport: Weak<T>,
transport_sync: TransportSynchronizer,
) -> std::io::Result<()> {
let err = loop {
let recv_res = transport.receive_packet().await;
let Some(transport_instance) = transport.upgrade() else {
// nobody else is listening so we can drop our reference
return Ok(());
};
// we do a timeout here to ensure that we do not block forever
let Ok(recv_res) = timeout(
Duration::from_millis(300),
transport_instance.receive_packet(),
)
.await
else {
continue;
};
let option_packet = match recv_res {
Err(e) => break e,
Ok(o) => o,
};
let Some((packet, _src)) = option_packet else {
let Some(packet) = option_packet else {
trace!("unexpected malformed packet received");
continue;
};
if let Some(ack) = parse_for_ack(&packet) {
transport.send(&ack).await?;
transport_instance.send(&ack).await?;
}

let MessageClass::Response(_) = packet.header.code else {
Expand All @@ -168,7 +180,7 @@ async fn receive_loop<T: Transport + 'static>(
}
};
let Ok(_) = sender.send(Ok(packet)) else {
debug!("unexpected drop of oneshot sender");
debug!("unexpected drop of sender");
continue;
};
};
Expand All @@ -194,14 +206,14 @@ pub fn make_ack(packet: &Packet) -> Vec<u8> {
}

/// a wrapper for transports responsible for retries and timeouts
struct ClientTransport<T: Transport> {
struct CoapClientTransport<T: ClientTransport> {
pub(crate) transport: Arc<T>,
pub(crate) synchronizer: TransportSynchronizer,
pub(crate) retries: usize,
pub(crate) timeout: Duration,
}

impl<T: Transport> Clone for ClientTransport<T> {
impl<T: ClientTransport> Clone for CoapClientTransport<T> {
fn clone(&self) -> Self {
Self {
transport: self.transport.clone(),
Expand All @@ -212,7 +224,7 @@ impl<T: Transport> Clone for ClientTransport<T> {
}
}

impl<T: Transport> ClientTransport<T> {
impl<T: ClientTransport> CoapClientTransport<T> {
pub const DEFAULT_NUM_RETRIES: usize = 5;
async fn establish_receiver_for(&self, msg: &Packet) -> UnboundedReceiver<IoResult<Packet>> {
let (tx, rx) = unbounded_channel();
Expand Down Expand Up @@ -296,9 +308,12 @@ pub struct UdpTransport {
pub peer_addr: SocketAddr,
}
#[async_trait]
impl Transport for UdpTransport {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
self.socket.recv_from(buf).await
impl ClientTransport for UdpTransport {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
self.socket
.recv_from(buf)
.await
.map(|(recv_size, _addr)| recv_size)
}
async fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
self.socket.send_to(buf, self.peer_addr).await
Expand All @@ -308,13 +323,13 @@ impl Transport for UdpTransport {
/// A CoAP client over UDP. This client can send multicast and broadcasts
pub type UdpCoAPClient = CoAPClient<UdpTransport>;

pub struct CoAPClient<T: Transport> {
transport: ClientTransport<T>,
pub struct CoAPClient<T: ClientTransport> {
transport: CoapClientTransport<T>,
block1_size: usize,
message_id: Arc<AtomicU16>,
}

impl<T: Transport> Clone for CoAPClient<T> {
impl<T: ClientTransport> Clone for CoAPClient<T> {
fn clone(&self) -> Self {
Self {
transport: self.transport.clone(),
Expand Down Expand Up @@ -485,14 +500,14 @@ impl UdpCoAPClient {

#[cfg(feature = "dtls")]
impl CoAPClient<DtlsConnection> {
pub async fn from_dtls_config(config: DtlsConfig) -> IoResult<Self> {
pub async fn from_udp_dtls_config(config: UdpDtlsConfig) -> IoResult<Self> {
Ok(CoAPClient::from_transport(
DtlsConnection::try_new(config).await?,
))
}
}

impl<T: Transport + 'static> CoAPClient<T> {
impl<T: ClientTransport + 'static> CoAPClient<T> {
const MAX_PAYLOAD_BLOCK: usize = 1024;
/// Create a CoAP client with a chosen transport type
Expand All @@ -501,9 +516,12 @@ impl<T: Transport + 'static> CoAPClient<T> {
let transport_arc = Arc::new(transport);
let message_id: u16 = rand::random();
// spawn receive loop to handle responses
tokio::spawn(receive_loop(transport_arc.clone(), synchronizer.clone()));
tokio::spawn(receive_loop(
Arc::downgrade(&transport_arc),
synchronizer.clone(),
));
CoAPClient {
transport: ClientTransport::from_transport(transport_arc.clone(), synchronizer),
transport: CoapClientTransport::from_transport(transport_arc.clone(), synchronizer),
block1_size: Self::MAX_PAYLOAD_BLOCK,
message_id: Arc::new(AtomicU16::new(message_id)),
}
Expand Down Expand Up @@ -1195,8 +1213,8 @@ mod test {
}

#[async_trait]
impl Transport for FaultyUdp {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
impl ClientTransport for FaultyUdp {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
self.udp.recv(buf).await
}

Expand Down Expand Up @@ -1247,7 +1265,7 @@ mod test {
let server_addr = format!("127.0.0.1:{}", server_port);
let mut client = get_faulty_client(
&server_addr,
ClientTransport::<FaultyUdp>::DEFAULT_NUM_RETRIES as u32 + 1,
CoapClientTransport::<FaultyUdp>::DEFAULT_NUM_RETRIES as u32 + 1,
)
.await;
let request_gen = || {
Expand All @@ -1260,7 +1278,7 @@ mod test {
//this request will work, we do this to reset the state of the faulty udp
client.send(request_gen()).await.unwrap();

client.set_transport_retries(ClientTransport::<UdpTransport>::DEFAULT_NUM_RETRIES + 2);
client.set_transport_retries(CoapClientTransport::<UdpTransport>::DEFAULT_NUM_RETRIES + 2);
let resp = client.send(request_gen()).await.unwrap();

assert_eq!(resp.message.payload, b"Rust".to_vec());
Expand All @@ -1287,7 +1305,7 @@ mod test {
assert!(req.is_err());
}

async fn do_wait_request<T: Transport + 'static>(
async fn do_wait_request<T: ClientTransport + 'static>(
client: Arc<CoAPClient<T>>,
path: &str,
token: Vec<u8>,
Expand Down Expand Up @@ -1362,8 +1380,8 @@ mod test {
pub should_fail: Mutex<oneshot::Receiver<std::io::Error>>,
}
#[async_trait]
impl Transport for FaultyReceiver {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
impl ClientTransport for FaultyReceiver {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
let mut mutex = self.should_fail.lock().await;
tokio::select! {
e = mutex.deref_mut() => {
Expand Down
Loading

0 comments on commit d80c56c

Please sign in to comment.