Skip to content
This repository has been archived by the owner on Jun 27, 2022. It is now read-only.

add read timeout #10

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 89 additions & 19 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use util::{now_microseconds, ewma, abs_diff};
use packet::{Packet, PacketType, Encodable, Decodable, ExtensionType, HEADER_SIZE};
use rand::{self, Rng};
use time::SteadyTime;
use time;
use std::time::Duration;

// For simplicity's sake, let us assume no packet will ever exceed the
Expand Down Expand Up @@ -35,6 +36,7 @@ pub enum SocketError {
ConnectionClosed,
ConnectionReset,
ConnectionTimedOut,
UserTimedOut,
InvalidAddress,
InvalidPacket,
InvalidReply,
Expand All @@ -47,7 +49,7 @@ impl From<SocketError> for Error {
let (kind, message) = match error {
ConnectionClosed => (ErrorKind::NotConnected, "The socket is closed"),
ConnectionReset => (ErrorKind::ConnectionReset, "Connection reset by remote peer"),
ConnectionTimedOut => (ErrorKind::TimedOut, "Connection timed out"),
ConnectionTimedOut | UserTimedOut => (ErrorKind::TimedOut, "Connection timed out"),
InvalidAddress => (ErrorKind::InvalidInput, "Invalid address"),
InvalidPacket => (ErrorKind::Other, "Error parsing packet"),
InvalidReply => (ErrorKind::ConnectionRefused, "The remote peer sent an invalid reply"),
Expand Down Expand Up @@ -185,6 +187,14 @@ pub struct UtpSocket {

/// Maximum retransmission retries
pub max_retransmission_retries: u32,

/// Used by `set_read_timeout`.
user_read_timeout: u64,

/// The last time congestion algorithm was updated/handled-a-timeout
last_congestion_update: SteadyTime,

retries: u32,
}

impl UtpSocket {
Expand Down Expand Up @@ -233,6 +243,9 @@ impl UtpSocket {
congestion_timeout: INITIAL_CONGESTION_TIMEOUT,
cwnd: INIT_CWND * MSS,
max_retransmission_retries: MAX_RETRANSMISSION_RETRIES,
user_read_timeout: 0,
last_congestion_update: SteadyTime::now(),
retries: 0,
}
}

Expand Down Expand Up @@ -348,7 +361,7 @@ impl UtpSocket {
// Receive JAKE
let mut buf = [0; BUF_SIZE];
while self.state != SocketState::Closed {
try!(self.recv(&mut buf));
try!(self.recv(&mut buf, false));
}

Ok(())
Expand Down Expand Up @@ -377,7 +390,7 @@ impl UtpSocket {
return Ok((0, self.connected_to));
}

match self.recv(buf) {
match self.recv(buf, true) {
Ok((0, _src)) => continue,
Ok(x) => return Ok(x),
Err(e) => return Err(e)
Expand All @@ -386,41 +399,98 @@ impl UtpSocket {
}
}

fn recv(&mut self, buf: &mut[u8]) -> Result<(usize, SocketAddr)> {
/// Changes read operations to block for at most the specified number of
/// milliseconds.
pub fn set_read_timeout(&mut self, user_timeout: Option<u64>) {
self.user_read_timeout = match user_timeout {
Some(t) => {
if t > 0 {
t
} else {
0
}
},
None => 0
}
}

fn recv(&mut self, buf: &mut[u8], use_user_timeout: bool)
-> Result<(usize, SocketAddr)> {
let mut b = [0; BUF_SIZE + HEADER_SIZE];
let now = SteadyTime::now();
let (read, src);
let mut retries = 0;
let user_timeout = if use_user_timeout {
self.user_read_timeout
} else {
0
};
let use_user_timeout = user_timeout != 0;

// Try to receive a packet and handle timeouts
loop {
// Abort loop if the current try exceeds the maximum number of retransmission retries.
if retries >= self.max_retransmission_retries {
if self.retries >= self.max_retransmission_retries {
self.state = SocketState::Closed;
return Err(Error::from(SocketError::ConnectionTimedOut));
}

let timeout = if self.state != SocketState::New {
let timeout;
let congestion_timeout = if self.state != SocketState::New {
debug!("setting read timeout of {} ms", self.congestion_timeout);
Some(Duration::from_millis(self.congestion_timeout))
} else { None };
{
let user_timeout = Duration::from_millis(user_timeout);
timeout = if use_user_timeout {
match congestion_timeout {
Some(congestion_timeout) => {
use std::cmp::min;
Some(min(congestion_timeout, user_timeout))
},
None => Some(user_timeout),
}
} else {
congestion_timeout
};
}

if use_user_timeout {
let user_timeout
= time::Duration::milliseconds(user_timeout as i64);
if (SteadyTime::now() - now) >= user_timeout {
return Err(Error::from(SocketError::UserTimedOut));
}
}

self.socket.set_read_timeout(timeout).expect("Error setting read timeout");
match self.socket.recv_from(&mut b) {
Ok((r, s)) => { read = r; src = s; break },
Err(ref e) if (e.kind() == ErrorKind::WouldBlock ||
e.kind() == ErrorKind::TimedOut) => {
debug!("recv_from timed out");
try!(self.handle_receive_timeout());
let now = SteadyTime::now();
let congestion_timeout = {
time::Duration::milliseconds(self.congestion_timeout
as i64)
};
if !use_user_timeout
|| ((now - self.last_congestion_update)
>= congestion_timeout) {
self.last_congestion_update = now;
try!(self.handle_receive_timeout());
self.retries += 1;
}
},
Err(e) => return Err(e),
};

let elapsed = (SteadyTime::now() - now).num_milliseconds();
debug!("{} ms elapsed", elapsed);
retries += 1;
}

self.last_congestion_update = SteadyTime::now();
self.retries = 0;

// Decode received data into a packet
let packet = match Packet::from_bytes(&b[..read]) {
Ok(packet) => packet,
Expand Down Expand Up @@ -453,7 +523,7 @@ impl UtpSocket {
}

fn handle_receive_timeout(&mut self) -> Result<()> {
self.congestion_timeout = self.congestion_timeout * 2;
self.congestion_timeout *= 2;
self.cwnd = MSS;

// There are three possible cases here:
Expand Down Expand Up @@ -615,7 +685,7 @@ impl UtpSocket {
let mut buf = [0u8; BUF_SIZE];
while !self.send_window.is_empty() {
debug!("packets in send window: {}", self.send_window.len());
try!(self.recv(&mut buf));
try!(self.recv(&mut buf, false));
}

Ok(())
Expand Down Expand Up @@ -647,7 +717,7 @@ impl UtpSocket {
debug!("self.duplicate_ack_count: {}", self.duplicate_ack_count);
debug!("now_microseconds() - now = {}", now_microseconds() - now);
let mut buf = [0; BUF_SIZE];
try!(self.recv(&mut buf));
try!(self.recv(&mut buf, false));
}
debug!("out: now_microseconds() - now = {}", now_microseconds() - now);

Expand Down Expand Up @@ -1369,7 +1439,7 @@ mod test {
let child = thread::spawn(move || {
// Make the server listen for incoming connections
let mut buf = [0u8; BUF_SIZE];
let _resp = server.recv(&mut buf);
let _resp = server.recv(&mut buf, false);
tx.send(server.seq_nr).unwrap();

// Close the connection
Expand Down Expand Up @@ -1737,7 +1807,7 @@ mod test {

let mut buf = [0; BUF_SIZE];
// Expect SYN
iotry!(server.recv(&mut buf));
iotry!(server.recv(&mut buf, false));

// Receive data
let data_packet = match server.socket.recv_from(&mut buf) {
Expand Down Expand Up @@ -1812,7 +1882,7 @@ mod test {
});

let mut buf = [0u8; BUF_SIZE];
server.recv(&mut buf).unwrap();
server.recv(&mut buf, false).unwrap();
// After establishing a new connection, the server's ids are a mirror of the client's.
assert_eq!(server.receiver_connection_id, server.sender_connection_id + 1);

Expand Down Expand Up @@ -1921,7 +1991,7 @@ mod test {
});

let mut buf = [0u8; BUF_SIZE];
iotry!(server.recv(&mut buf));
iotry!(server.recv(&mut buf, false));
// After establishing a new connection, the server's ids are a mirror of the client's.
assert_eq!(server.receiver_connection_id, server.sender_connection_id + 1);

Expand Down Expand Up @@ -2271,7 +2341,7 @@ mod test {
let mut buf = [0; BUF_SIZE];

// Accept connection
iotry!(server.recv(&mut buf));
iotry!(server.recv(&mut buf, false));

// Send FIN without acknowledging packets received
let mut packet = Packet::new();
Expand Down Expand Up @@ -2357,7 +2427,7 @@ mod test {

// Wait for a connection to be established
let mut buf = [0; 1024];
iotry!(server.recv(&mut buf));
iotry!(server.recv(&mut buf, false));

// `peer_addr` should succeed and be equal to the client's address
assert!(server.peer_addr().is_ok());
Expand Down Expand Up @@ -2426,7 +2496,7 @@ mod test {

// Try to receive ACKs, time out too many times on flush, and fail with `TimedOut`
let mut buf = [0; BUF_SIZE];
match server.recv(&mut buf) {
match server.recv(&mut buf, false) {
Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
x => panic!("Expected Err(TimedOut), got {:?}", x),
}
Expand Down