Skip to content

Commit

Permalink
refactor with renaming coap-lite Packet to Message
Browse files Browse the repository at this point in the history
  • Loading branch information
michieldwitte committed Jun 22, 2024
1 parent d37f30a commit 759258b
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 68 deletions.
117 changes: 57 additions & 60 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use alloc::vec::Vec;
use coap_lite::{
block_handler::{extending_splice, BlockValue},
error::HandlingError,
CoapOption, CoapRequest, CoapResponse, MessageClass, MessageType, ObserveOption, Packet,
CoapOption, CoapRequest, CoapResponse, MessageClass, MessageType, ObserveOption, Packet as Message,
RequestType as Method, ResponseType as Status,
};
use core::mem;
Expand Down Expand Up @@ -38,9 +38,9 @@ use url::Url;
const DEFAULT_RECEIVE_TIMEOUT_SECONDS: u64 = 2; // 2s

#[derive(Debug, Clone)]
pub struct AddressedPacket {
pub address: SocketAddr,
pub packet: Packet,
pub struct Packet {
pub address: Option<SocketAddr>,
pub message: Message,
}

#[derive(Debug)]
Expand All @@ -55,29 +55,28 @@ use async_trait::async_trait;
/// timeouts and retries do not need to be implemented by the transport
/// if confirmable messages are sent
pub trait ClientTransport: Send + Sync {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)>;
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, Option<SocketAddr>)>;
async fn send(&self, buf: &[u8]) -> std::io::Result<usize>;
}

trait TransportExt {
async fn receive_packet(&self) -> IoResult<Option<AddressedPacket>>;
async fn receive_packet(&self) -> IoResult<Option<Packet>>;
}

impl<T: ClientTransport> TransportExt for T {
async fn receive_packet(&self) -> IoResult<Option<AddressedPacket>> {
async fn receive_packet(&self) -> IoResult<Option<Packet>> {
let mut buf = [0; 1500];
let (nread, address) = self.recv(&mut buf).await?;
// let parse_opt = Packet::from_bytes(&buf[..nread]).ok();
return match Packet::from_bytes(&buf[..nread]).ok() {
Some(packet) => Ok(Some(AddressedPacket {address, packet})),
return match Message::from_bytes(&buf[..nread]).ok() {
Some(message) => Ok(Some(Packet {address, message})),
None => Ok(None),
}
}
}

/// we only use the token as the identifier, and an empty token to represent empty requests
type Token = Vec<u8>;
type PacketRegistry = BTreeMap<Token, UnboundedSender<IoResult<AddressedPacket>>>;
type PacketRegistry = BTreeMap<Token, UnboundedSender<IoResult<Packet>>>;

#[derive(Clone)]
pub struct TransportSynchronizer {
Expand All @@ -93,7 +92,7 @@ impl TransportSynchronizer {
}
}

async fn check_for_error(&self, sender: &UnboundedSender<IoResult<AddressedPacket>>) -> Option<()> {
async fn check_for_error(&self, sender: &UnboundedSender<IoResult<Packet>>) -> Option<()> {
self.fail_error.read().await.as_ref();
if let Some(err) = self.fail_error.read().await.as_ref() {
let _ = sender.send(Err(Self::clone_err(err)));
Expand All @@ -120,7 +119,7 @@ impl TransportSynchronizer {
}
}

pub async fn get_sender(&self, key: &[u8]) -> Option<UnboundedSender<IoResult<AddressedPacket>>> {
pub async fn get_sender(&self, key: &[u8]) -> Option<UnboundedSender<IoResult<Packet>>> {
self.outgoing
.lock()
.await
Expand All @@ -132,12 +131,12 @@ impl TransportSynchronizer {
pub async fn set_sender(
&self,
key: Vec<u8>,
sender: UnboundedSender<IoResult<AddressedPacket>>,
) -> Option<UnboundedSender<IoResult<AddressedPacket>>> {
sender: UnboundedSender<IoResult<Packet>>,
) -> Option<UnboundedSender<IoResult<Packet>>> {
self.check_for_error(&sender).await?;
self.outgoing.lock().await.insert(key, sender)
}
pub async fn remove_sender(&self, key: &[u8]) -> Option<UnboundedSender<IoResult<AddressedPacket>>> {
pub async fn remove_sender(&self, key: &[u8]) -> Option<UnboundedSender<IoResult<Packet>>> {
self.outgoing.lock().await.remove(key)
}
}
Expand Down Expand Up @@ -172,16 +171,16 @@ async fn receive_loop<T: ClientTransport + 'static>(
transport_instance.send(&ack).await?;
}

let MessageClass::Response(_) = packet.packet.header.code else {
let MessageClass::Response(_) = packet.message.header.code else {
continue;
};

let token = packet.packet.get_token();
let token = packet.message.get_token();
let Some(sender) = transport_sync.get_sender(token).await else {
info!("received unexpected response for token {:?}", &token);
continue;
};
match packet.packet.header.code {
match packet.message.header.code {
MessageClass::Response(_) => {}
m => {
debug!("unknown message type {}", m);
Expand All @@ -199,17 +198,17 @@ async fn receive_loop<T: ClientTransport + 'static>(
return e;
}

pub fn parse_for_ack(packet: &AddressedPacket) -> Option<Vec<u8>> {
match (packet.packet.header.get_type(), packet.packet.header.code) {
pub fn parse_for_ack(packet: &Packet) -> Option<Vec<u8>> {
match (packet.message.header.get_type(), packet.message.header.code) {
(MessageType::Confirmable, MessageClass::Response(_)) => Some(make_ack(packet)),
_ => None,
}
}

pub fn make_ack(packet: &AddressedPacket) -> Vec<u8> {
let mut ack = Packet::new();
pub fn make_ack(packet: &Packet) -> Vec<u8> {
let mut ack = Message::new();
ack.header.set_type(MessageType::Acknowledgement);
ack.header.message_id = packet.packet.header.message_id;
ack.header.message_id = packet.message.header.message_id;
ack.header.code = MessageClass::Empty;
return ack.to_bytes().unwrap();
}
Expand All @@ -235,9 +234,9 @@ impl<T: ClientTransport> Clone for CoapClientTransport<T> {

impl<T: ClientTransport> CoapClientTransport<T> {
pub const DEFAULT_NUM_RETRIES: usize = 5;
async fn establish_receiver_for(&self, msg: &Packet) -> UnboundedReceiver<IoResult<AddressedPacket>> {
async fn establish_receiver_for(&self, packet: &Packet) -> UnboundedReceiver<IoResult<Packet>> {
let (tx, rx) = unbounded_channel();
let token = msg.get_token().to_owned();
let token = packet.message.get_token().to_owned();
self.synchronizer.set_sender(token, tx).await;
return rx;
}
Expand All @@ -246,8 +245,8 @@ impl<T: ClientTransport> CoapClientTransport<T> {
async fn try_send_confirmable_message(
&self,
msg: &Packet,
receiver: &mut UnboundedReceiver<IoResult<AddressedPacket>>,
) -> IoResult<AddressedPacket> {
receiver: &mut UnboundedReceiver<IoResult<Packet>>,
) -> IoResult<Packet> {
let mut res = Err(Error::new(ErrorKind::InvalidData, "not enough retries"));
for _ in 0..self.retries {
res = self.try_send_non_confirmable_message(&msg, receiver).await;
Expand All @@ -258,20 +257,20 @@ impl<T: ClientTransport> CoapClientTransport<T> {
return res;
}

fn encode_packet(packet: &Packet) -> IoResult<Vec<u8>> {
packet
fn encode_message(message: &Message) -> IoResult<Vec<u8>> {
message
.to_bytes()
.map_err(|e| std::io::Error::new(ErrorKind::InvalidData, e.to_string()))
}

async fn try_send_non_confirmable_message(
&self,
msg: &Packet,
receiver: &mut UnboundedReceiver<IoResult<AddressedPacket>>,
) -> IoResult<AddressedPacket> {
let bytes = Self::encode_packet(msg)?;
receiver: &mut UnboundedReceiver<IoResult<Packet>>,
) -> IoResult<Packet> {
let bytes = Self::encode_message(&msg.message)?;
self.transport.send(&bytes).await?;
let try_receive: Result<Option<Result<AddressedPacket, Error>>, tokio::time::error::Elapsed> =
let try_receive: Result<Option<Result<Packet, Error>>, tokio::time::error::Elapsed> =
timeout(self.timeout, receiver.recv()).await;
if let Ok(Some(res)) = try_receive {
return res;
Expand All @@ -282,9 +281,9 @@ impl<T: ClientTransport> CoapClientTransport<T> {
async fn do_request_response_for_packet_inner(
&self,
packet: &Packet,
receiver: &mut UnboundedReceiver<IoResult<AddressedPacket>>,
) -> IoResult<AddressedPacket> {
if packet.header.get_type() == MessageType::Confirmable {
receiver: &mut UnboundedReceiver<IoResult<Packet>>,
) -> IoResult<Packet> {
if packet.message.header.get_type() == MessageType::Confirmable {
return self.try_send_confirmable_message(&packet, receiver).await;
} else {
return self
Expand All @@ -298,11 +297,8 @@ impl<T: ClientTransport> CoapClientTransport<T> {
let result = self
.do_request_response_for_packet_inner(packet, &mut receiver)
.await;
self.synchronizer.remove_sender(packet.get_token()).await;
match result {
Ok(addr_packet) => Ok(addr_packet.packet),
Err(err) => Err(err),
}
self.synchronizer.remove_sender(packet.message.get_token()).await;
result
}

pub fn from_transport(transport: Arc<T>, synchronizer: TransportSynchronizer) -> Self {
Expand All @@ -321,10 +317,11 @@ pub struct UdpTransport {
}
#[async_trait]
impl ClientTransport for UdpTransport {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
self.socket
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, Option<SocketAddr>)> {
let (read, addr) = self.socket
.recv_from(buf)
.await
.await?;
return Ok((read, Some(addr)));
}
async fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
self.socket.send_to(buf, self.peer_addr).await
Expand Down Expand Up @@ -353,12 +350,12 @@ impl<T: ClientTransport> Clone for CoAPClient<T> {
/// a receiver used whenever you have a use case involving multiple responses to a single request
pub struct MessageReceiver {
synchronizer: TransportSynchronizer,
receiver: UnboundedReceiver<IoResult<AddressedPacket>>,
receiver: UnboundedReceiver<IoResult<Packet>>,
token: Vec<u8>,
}

impl MessageReceiver {
pub async fn receive(&mut self) -> IoResult<AddressedPacket> {
pub async fn receive(&mut self) -> IoResult<Packet> {
match self.receiver.recv().await {
Some(Ok(packet)) => Ok(packet),
Some(Err(e)) => Err(e),
Expand All @@ -370,7 +367,7 @@ impl MessageReceiver {
}
pub fn new(
synchronizer: TransportSynchronizer,
receiver: UnboundedReceiver<IoResult<AddressedPacket>>,
receiver: UnboundedReceiver<IoResult<Packet>>,
token: &[u8],
) -> Self {
Self {
Expand Down Expand Up @@ -492,7 +489,7 @@ impl UdpCoAPClient {
/// client.send_all_coap(&request, segment).await.unwrap();
/// loop {
/// let recv_packet = receiver.receive().await.unwrap();
/// assert_eq!(recv_packet.packet.payload, b"test-echo".to_vec());
/// assert_eq!(recv_packet.message.payload, b"test-echo".to_vec());
/// }
/// }
/// ```
Expand Down Expand Up @@ -630,7 +627,7 @@ impl<T: ClientTransport + 'static> CoAPClient<T> {
self.receive(&mut request).await
}

pub async fn observe<H: FnMut(Packet) + Send + 'static>(
pub async fn observe<H: FnMut(Message) + Send + 'static>(
&self,
resource_path: &str,
handler: H,
Expand All @@ -645,7 +642,7 @@ impl<T: ClientTransport + 'static> CoAPClient<T> {
/// Observe a resource with the handler and specified timeout using the given transport.
/// Use the oneshot sender to cancel observation. If this sender is dropped without explicitly
/// cancelling it, the observation will continue forever.
pub async fn observe_with_timeout<H: FnMut(Packet) + Send + 'static>(
pub async fn observe_with_timeout<H: FnMut(Message) + Send + 'static>(
&mut self,
resource_path: &str,
handler: H,
Expand All @@ -662,7 +659,7 @@ impl<T: ClientTransport + 'static> CoAPClient<T> {
/// Use this method if you need to set some specific options in your
/// requests. This method will add observe flags and a message id as a fallback
/// Use this method if you plan on re-using the same client for requests
pub async fn observe_with<H: FnMut(Packet) + Send + 'static>(
pub async fn observe_with<H: FnMut(Message) + Send + 'static>(
&self,
request: CoapRequest<SocketAddr>,
mut handler: H,
Expand Down Expand Up @@ -734,17 +731,17 @@ impl<T: ClientTransport + 'static> CoAPClient<T> {

let _ = self
.transport
.do_request_response_for_packet(&deregister_packet.message)
.do_request_response_for_packet(&Packet {address:None, message: deregister_packet.message})
.await;
}

async fn receive_and_handle_message_observe<H: FnMut(Packet) + Send + 'static>(
socket_result: IoResult<AddressedPacket>,
async fn receive_and_handle_message_observe<H: FnMut(Message) + Send + 'static>(
socket_result: IoResult<Packet>,
handler: &mut H,
) {
match socket_result {
Ok(packet) => {
handler(packet.packet);
handler(packet.message);
}
Err(e) => match e.kind() {
ErrorKind::WouldBlock => {
Expand All @@ -766,9 +763,9 @@ impl<T: ClientTransport + 'static> CoAPClient<T> {
) -> IoResult<CoapResponse> {
let response = self
.transport
.do_request_response_for_packet(&request.message)
.do_request_response_for_packet(&Packet {address:None, message:request.message.to_owned()})
.await?;
Ok(CoapResponse { message: response })
Ok(CoapResponse { message: response.message })
}

/// low-level method to send a a request supporting block1 option based on
Expand Down Expand Up @@ -1225,7 +1222,7 @@ mod test {

#[async_trait]
impl ClientTransport for FaultyUdp {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, Option<SocketAddr>)> {
self.udp.recv(buf).await
}

Expand Down Expand Up @@ -1392,7 +1389,7 @@ mod test {
}
#[async_trait]
impl ClientTransport for FaultyReceiver {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, Option<SocketAddr>)> {
let mut mutex = self.should_fail.lock().await;
tokio::select! {
e = mutex.deref_mut() => {
Expand Down
11 changes: 5 additions & 6 deletions src/dtls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::client::ClientTransport;
use crate::server::{Listener, Responder, TransportRequestSender};
use async_trait::async_trait;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::net::SocketAddr;
use std::time::Duration;
use std::{
io::{Error, ErrorKind, Result as IoResult},
Expand Down Expand Up @@ -42,14 +42,13 @@ pub struct DtlsResponse {

#[async_trait]
impl ClientTransport for DtlsConnection {
async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, Option<SocketAddr>)> {
let read = self
.conn
.read(buf, None)
.await
.map_err(|e| Error::new(ErrorKind::Other, e))?;

return Ok((read, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)));
return Ok((read, self.conn.remote_addr()));
}

async fn send(&self, buf: &[u8]) -> IoResult<usize> {
Expand Down Expand Up @@ -188,7 +187,7 @@ mod test {
use rcgen::KeyPair;
use std::fs::File;
use std::io::{BufReader, Read};
use std::net::{SocketAddr, ToSocketAddrs};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs};
use std::sync::atomic::AtomicBool;
use tokio::sync::mpsc;
use tokio::time::sleep;
Expand Down Expand Up @@ -636,7 +635,7 @@ mod test {
todo!("not needed");
}
fn remote_addr(&self) -> Option<SocketAddr> {
todo!("not needed")
Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0))
}
async fn close(&self) -> WebrtcResult<()> {
Ok(self.0.close().await?)
Expand Down
Loading

0 comments on commit 759258b

Please sign in to comment.