diff --git a/src/socket.rs b/src/socket.rs index 140dcc289..6d87b0c76 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -2,10 +2,12 @@ use std::cmp::{min, max}; use std::collections::VecDeque; use std::net::{ToSocketAddrs, SocketAddr, UdpSocket}; use std::io::{Result, Error, ErrorKind}; +use std::error::Error as ErrorTrait; use util::{now_microseconds, ewma}; use packet::{Packet, PacketType, Encodable, Decodable, ExtensionType, HEADER_SIZE}; use rand::{self, Rng}; use with_read_timeout::WithReadTimeout; +use std::fmt; // For simplicity's sake, let us assume no packet will ever exceed the // Ethernet maximum transfer unit of 1500 bytes. @@ -34,23 +36,45 @@ pub enum SocketError { ConnectionClosed, ConnectionReset, ConnectionTimedOut, + UserTimedOut, InvalidAddress, InvalidPacket, InvalidReply, } +impl fmt::Display for SocketError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", *self) + } +} + +impl ErrorTrait for SocketError { + fn description(&self) -> &str { + use self::SocketError::*; + match *self { + ConnectionClosed => "The socket is closed", + ConnectionReset => "Connection reset by remote peer", + ConnectionTimedOut => "Connection timed out", + UserTimedOut => "User timeout reached", + InvalidAddress => "Invalid address", + InvalidPacket => "Error parsing packet", + InvalidReply => "The remote peer sent an invalid reply", + } + } +} + impl From for Error { fn from(error: SocketError) -> Error { use self::SocketError::*; - let (kind, message) = match error { - ConnectionClosed => (ErrorKind::NotConnected, "The socket is closed"), - ConnectionReset => (ErrorKind::ConnectionReset, "Connection reset by remote peer"), - ConnectionTimedOut => (ErrorKind::TimedOut, "Connection timed out"), - InvalidAddress => (ErrorKind::InvalidInput, "Invalid address"), - InvalidPacket => (ErrorKind::Other, "Error parsing packet"), - InvalidReply => (ErrorKind::ConnectionRefused, "The remote peer sent an invalid reply"), + let kind = match error { + ConnectionClosed => ErrorKind::NotConnected, + ConnectionReset => ErrorKind::ConnectionReset, + ConnectionTimedOut | UserTimedOut => ErrorKind::TimedOut, + InvalidAddress => ErrorKind::InvalidInput, + InvalidPacket => ErrorKind::Other, + InvalidReply => ErrorKind::ConnectionRefused, }; - Error::new(kind, message) + Error::new(kind, error) } } @@ -185,6 +209,9 @@ pub struct UtpSocket { /// Maximum retransmission retries pub max_retransmission_retries: u32, + + /// Used by `set_read_timeout`. + user_read_timeout: i64, } impl UtpSocket { @@ -231,6 +258,7 @@ impl UtpSocket { congestion_timeout: INITIAL_CONGESTION_TIMEOUT, cwnd: INIT_CWND * MSS, max_retransmission_retries: MAX_RETRANSMISSION_RETRIES, + user_read_timeout: 0, } } @@ -335,7 +363,7 @@ impl UtpSocket { // Receive JAKE let mut buf = [0; BUF_SIZE]; while self.state != SocketState::Closed { - try!(self.recv(&mut buf)); + try!(self.recv(&mut buf, false)); } Ok(()) @@ -364,7 +392,8 @@ impl UtpSocket { return Ok((0, self.connected_to)); } - match self.recv(buf) { + let user_read_timeout = self.user_read_timeout; + match self.recv(buf, user_read_timeout != 0) { Ok((0, _src)) => continue, Ok(x) => return Ok(x), Err(e) => return Err(e) @@ -373,11 +402,32 @@ impl UtpSocket { } } - fn recv(&mut self, buf: &mut[u8]) -> Result<(usize, SocketAddr)> { + /// Changes read operations to block for at most the specified number of + /// milliseconds. + pub fn set_read_timeout(&mut self, user_timeout: Option) { + self.user_read_timeout = match user_timeout { + Some(t) => { + if t > 0 { + t + } else { + 0 + } + }, + None => 0 + } + } + + fn recv(&mut self, buf: &mut[u8], use_user_timeout: bool) + -> Result<(usize, SocketAddr)> { let mut b = [0; BUF_SIZE + HEADER_SIZE]; let now = now_microseconds(); let (read, src); let mut retries = 0; + let user_timeout = if use_user_timeout { + self.user_read_timeout + } else { + 0 + }; // Try to receive a packet and handle timeouts loop { @@ -387,17 +437,32 @@ impl UtpSocket { return Err(Error::from(SocketError::ConnectionTimedOut)); } - let timeout = if self.state != SocketState::New { + let congestion_timeout = if self.state != SocketState::New { debug!("setting read timeout of {} ms", self.congestion_timeout); self.congestion_timeout as i64 } else { 0 }; + let timeout = if user_timeout != 0 { + if congestion_timeout != 0 { + use std::cmp::min; + min(congestion_timeout, user_timeout) + } else { + user_timeout + } + } else { + congestion_timeout + }; + + if user_timeout != 0 + && ((now_microseconds() - now) / 1000) as i64 >= user_timeout { + return Err(Error::from(SocketError::UserTimedOut)); + } match self.socket.recv_timeout(&mut b, timeout) { Ok((r, s)) => { read = r; src = s; break }, Err(ref e) if (e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut) => { debug!("recv_from timed out"); - try!(self.handle_receive_timeout()); + try!(self.handle_receive_timeout(user_timeout != 0)); }, Err(e) => return Err(e), }; @@ -438,8 +503,11 @@ impl UtpSocket { Ok((read, src)) } - fn handle_receive_timeout(&mut self) -> Result<()> { - self.congestion_timeout = self.congestion_timeout * 2; + fn handle_receive_timeout(&mut self, keep_current_timeout: bool) + -> Result<()> { + if !keep_current_timeout { + self.congestion_timeout *= 2 + } self.cwnd = MSS; // There are three possible cases here: @@ -605,7 +673,7 @@ impl UtpSocket { let mut buf = [0u8; BUF_SIZE]; while !self.send_window.is_empty() { debug!("packets in send window: {}", self.send_window.len()); - try!(self.recv(&mut buf)); + try!(self.recv(&mut buf, false)); } Ok(()) @@ -637,7 +705,7 @@ impl UtpSocket { debug!("self.duplicate_ack_count: {}", self.duplicate_ack_count); debug!("now_microseconds() - now = {}", now_microseconds() - now); let mut buf = [0; BUF_SIZE]; - try!(self.recv(&mut buf)); + try!(self.recv(&mut buf, false)); } debug!("out: now_microseconds() - now = {}", now_microseconds() - now); @@ -1355,7 +1423,7 @@ mod test { thread::spawn(move || { // Make the server listen for incoming connections let mut buf = [0u8; BUF_SIZE]; - let _resp = server.recv(&mut buf); + let _resp = server.recv(&mut buf, false); tx.send(server.seq_nr).unwrap(); // Close the connection @@ -1719,7 +1787,7 @@ mod test { let mut buf = [0; BUF_SIZE]; // Expect SYN - iotry!(server.recv(&mut buf)); + iotry!(server.recv(&mut buf, false)); // Receive data let data_packet = match server.socket.recv_from(&mut buf) { @@ -1792,7 +1860,7 @@ mod test { }); let mut buf = [0u8; BUF_SIZE]; - server.recv(&mut buf).unwrap(); + server.recv(&mut buf, false).unwrap(); // After establishing a new connection, the server's ids are a mirror of the client's. assert_eq!(server.receiver_connection_id, server.sender_connection_id + 1); @@ -1899,7 +1967,7 @@ mod test { }); let mut buf = [0u8; BUF_SIZE]; - iotry!(server.recv(&mut buf)); + iotry!(server.recv(&mut buf, false)); // After establishing a new connection, the server's ids are a mirror of the client's. assert_eq!(server.receiver_connection_id, server.sender_connection_id + 1); @@ -2233,7 +2301,7 @@ mod test { let mut buf = [0; BUF_SIZE]; // Accept connection - iotry!(server.recv(&mut buf)); + iotry!(server.recv(&mut buf, false)); // Send FIN without acknowledging packets received let mut packet = Packet::new(); @@ -2348,7 +2416,7 @@ mod test { // Try to receive ACKs, time out too many times on flush, and fail with `TimedOut` let mut buf = [0; BUF_SIZE]; - match server.recv(&mut buf) { + match server.recv(&mut buf, false) { Err(ref e) if e.kind() == ErrorKind::TimedOut => (), x => panic!("Expected Err(TimedOut), got {:?}", x), }