diff --git a/src/socket.rs b/src/socket.rs index ea5237e87..acdbd5cfb 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -2,7 +2,7 @@ use std::cmp::{min, max}; use std::collections::VecDeque; use std::net::{ToSocketAddrs, SocketAddr, UdpSocket}; use std::io::{Result, Error, ErrorKind}; -use util::{now_microseconds, ewma, abs_diff}; +use util::{now_microseconds, ewma, abs_diff, Sequence}; use packet::{Packet, PacketType, Encodable, Decodable, ExtensionType, HEADER_SIZE}; use rand::{self, Rng}; use time::SteadyTime; @@ -201,6 +201,10 @@ pub struct UtpSocket { last_congestion_update: SteadyTime, retries: u32, + + /// The first 'State' packet we sent if we are a server (it may + /// need to be resent if the network dropped it). + state_packet: Option, } impl UtpSocket { @@ -253,6 +257,7 @@ impl UtpSocket { user_read_timeout: 0, last_congestion_update: SteadyTime::now(), retries: 0, + state_packet: None, } } @@ -305,11 +310,11 @@ impl UtpSocket { packet.set_connection_id(socket.receiver_connection_id); packet.set_seq_nr(socket.seq_nr); - let mut len = 0; let mut buf = [0; BUF_SIZE]; - let mut syn_timeout = socket.congestion_timeout; - for _ in 0..MAX_SYN_RETRIES { + let mut syn_retries = 0; + + while syn_retries < MAX_SYN_RETRIES { packet.set_timestamp_microseconds(now_microseconds()); // Send packet @@ -323,29 +328,34 @@ impl UtpSocket { .set_read_timeout(Some(Duration::from_millis(syn_timeout))) .expect("Error setting read timeout"); match socket.socket.recv_from(&mut buf) { - Ok((read, src)) => { - socket.connected_to = src; - len = read; - break; - } + Ok((read, addr)) => { + let packet = try!(Packet::from_bytes(&buf[..read]).or(Err(SocketError::InvalidPacket))); + + socket.connected_to = addr; + + if packet.get_type() != PacketType::State { + // The network might have dropped the `State` packet + // from the peer, so we need to ask for it again. + syn_retries += 1; + continue; + } + + try!(socket.handle_packet(&packet, addr)); + + return Ok(socket); + }, Err(ref e) if (e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut) => { debug!("Timed out, retrying"); syn_timeout *= 2; + syn_retries += 1; continue; } Err(e) => return Err(e), }; } - let addr = socket.connected_to; - let packet = try!(Packet::from_bytes(&buf[..len]).or(Err(SocketError::InvalidPacket))); - debug!("received {:?}", packet); - try!(socket.handle_packet(&packet, addr)); - - debug!("connected to: {}", socket.connected_to); - - Ok(socket) + Err(Error::from(SocketError::ConnectionTimedOut)) } /// If you have already prepared UDP sockets at each end (e.g. you're doing @@ -678,9 +688,10 @@ impl UtpSocket { // Insert data packet into the incoming buffer if it isn't a duplicate of a previously // discarded packet - if packet.get_type() == PacketType::Data && - packet.seq_nr().wrapping_sub(self.last_dropped) > 0 { - self.insert_into_buffer(packet); + if packet.get_type() == PacketType::Data { + if Sequence::less(self.last_dropped, packet.seq_nr()) { + self.insert_into_buffer(packet); + } } // Flush incoming buffer if possible @@ -798,7 +809,8 @@ impl UtpSocket { if !self.incoming_buffer.is_empty() && (self.ack_nr == self.incoming_buffer[0].seq_nr() || - self.ack_nr + 1 == self.incoming_buffer[0].seq_nr()) { + self.ack_nr.wrapping_add(1) == self.incoming_buffer[0].seq_nr()) + { let flushed = unsafe_copy(&self.incoming_buffer[0].payload[..], buf); if flushed == self.incoming_buffer[0].payload.len() { @@ -1071,8 +1083,15 @@ impl UtpSocket { fn handle_packet(&mut self, packet: &Packet, src: SocketAddr) -> Result> { debug!("({:?}, {:?})", self.state, packet.get_type()); + let is_data_or_fin = packet.get_type() == PacketType::Data + || packet.get_type() == PacketType::Fin; + // Acknowledge only if the packet strictly follows the previous one - if packet.seq_nr().wrapping_sub(self.ack_nr) == 1 { + // and only if it is a payload packet. The restriction on PacketType + // is due to all other (non Data) packets are assigned seq_nr the + // same as the next Data packet, thus we could acknowledge what + // we have not received yet. + if is_data_or_fin && packet.seq_nr().wrapping_sub(self.ack_nr) == 1 { self.ack_nr = packet.seq_nr(); } @@ -1103,21 +1122,31 @@ impl UtpSocket { self.state = SocketState::Connected; self.last_dropped = self.ack_nr; - Ok(Some(self.prepare_reply(packet, PacketType::State))) + self.state_packet = Some(self.prepare_reply(packet, PacketType::State)); + + // Advance the self.seq_nr (the sequence number of the next packet), + // this is because the other end will use the `seq_nr` of this state + // packet as his `self.last_acked` + self.seq_nr = self.seq_nr.wrapping_add(1); + + Ok(self.state_packet.clone()) } (SocketState::Connected, PacketType::Syn) if self.connected_to == src => { // The other end might have sent another Syn packet because // a reply to the first one did not arrive within a timeout // caused by network congestion. - Ok(None) + Ok(self.state_packet.clone()) + } + (_, PacketType::Syn) => { + Ok(Some(self.prepare_reply(packet, PacketType::Reset))) } - (_, PacketType::Syn) => Ok(Some(self.prepare_reply(packet, PacketType::Reset))), (SocketState::SynSent, PacketType::State) => { self.connected_to = src; self.ack_nr = packet.seq_nr(); self.seq_nr += 1; self.state = SocketState::Connected; self.last_acked = packet.ack_nr(); + self.last_dropped = packet.seq_nr(); self.last_acked_timestamp = now_microseconds(); Ok(None) } @@ -1315,7 +1344,7 @@ impl UtpSocket { // already resent. if !self.send_window.is_empty() && self.duplicate_ack_count == 3 && !packet.extensions.iter().any(|ext| ext.get_type() == ExtensionType::SelectiveAck) { - self.resend_lost_packet(packet.ack_nr() + 1); + self.resend_lost_packet(packet.ack_nr().wrapping_add(1)); } // Packet lost, halve the congestion window @@ -1484,10 +1513,6 @@ mod test { ($e:expr) => (match $e { Ok(e) => e, Err(e) => panic!("{:?}", e) }) } - macro_rules! mutetry { - ($e:expr) => (match $e { Ok(e) => e, Err(e) => println!("{:?}", e) }) - } - fn next_test_port() -> u16 { use std::sync::atomic::{AtomicUsize, ATOMIC_USIZE_INIT, Ordering}; static NEXT_OFFSET: AtomicUsize = ATOMIC_USIZE_INIT; @@ -1689,7 +1714,7 @@ mod test { let sender_seq_nr = rx.recv().unwrap(); let ack_nr = client.ack_nr; assert!(ack_nr != 0); - assert!(ack_nr == sender_seq_nr); + assert!(ack_nr.wrapping_add(1) == sender_seq_nr); assert!(client.close().is_ok()); // The reply to both connect (SYN) and close (FIN) should be @@ -1773,8 +1798,11 @@ mod test { assert!(response.ack_nr() == packet.seq_nr()); // Responses with no payload should not increase the sequence number + // unless it's the State packet sent to acknowledge the Syn packet as + // explained at + // assert!(response.payload.is_empty()); - assert!(response.seq_nr() == old_response.seq_nr()); + assert!(response.seq_nr() == old_response.seq_nr().wrapping_add(1)); // } // fn test_connection_teardown() { @@ -2521,8 +2549,8 @@ mod test { }); match UtpSocket::connect(server_addr) { - Err(ref e) if e.kind() == ErrorKind::ConnectionRefused => (), // OK - Err(e) => panic!("Expected ErrorKind::ConnectionRefused, got {:?}", e), + Err(ref e) if e.kind() == ErrorKind::TimedOut => (), // OK + Err(e) => panic!("Expected ErrorKind::TimedOut, got {:?}", e), Ok(_) => panic!("Expected Err, got Ok"), } @@ -2874,7 +2902,7 @@ mod test { assert!(child.join().is_ok()); } - const NETWORK_NODE_COUNT: usize = 40; + const NETWORK_NODE_COUNT: usize = 20; const NETWORK_MSG_COUNT: usize = 5; fn test_network(exchange: fn(&mut UtpSocket) -> ()) { @@ -2893,31 +2921,37 @@ mod test { } fn run(&mut self, exchange: fn(&mut UtpSocket) -> (), peer_addrs: Vec) { + let connect_cnt = peer_addrs.len(); + let connect_join_handle = spawn(move || { let mut send_jhs = Vec::>::new(); for peer_addr in peer_addrs { - let mut socket = iotry!(UtpSocket::connect(peer_addr)); - send_jhs.push(spawn(move || exchange(&mut socket))); + send_jhs.push(spawn(move || { + let mut socket = iotry!(UtpSocket::connect(peer_addr)); + exchange(&mut socket); + })); } for jh in send_jhs { - mutetry!(jh.join()); + iotry!(jh.join()); } }); let mut recv_jhs = Vec::>::new(); - for _ in 0..NODE_COUNT - 1 { + for _ in 0..NODE_COUNT-1-connect_cnt { let mut socket = iotry!(self.listener.accept()).0; - recv_jhs.push(spawn(move || exchange(&mut socket))); + recv_jhs.push(spawn(move || { + exchange(&mut socket); + })); } for jh in recv_jhs { - mutetry!(jh.join()); + iotry!(jh.join()); } - mutetry!(connect_join_handle.join()); + iotry!(connect_join_handle.join()); } } @@ -2938,9 +2972,7 @@ mod test { let mut addrs = Vec::::new(); for ai in 0..listening_addrs.len() { - if ai == ni { - continue; - } + if ai <= ni { continue } addrs.push(listening_addrs[ai].clone()); } @@ -2952,47 +2984,56 @@ mod test { } for handle in join_handles { - mutetry!(handle.join()); + iotry!(handle.join()); } } #[test] fn test_network_no_timeout() { - static MSG_COUNT: usize = NETWORK_MSG_COUNT; - static TX_BUF: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + static MSG_COUNT: usize = NETWORK_MSG_COUNT; + + fn make_buf(i: usize) -> [u8; 10] { + let mut buf = [0; 10]; + for j in 0..10 { + buf[j] = (i + j) as u8; + } + buf + } fn sequential_exchange(socket: &mut UtpSocket) { let mut i = 0; + let from = socket.socket.local_addr().map(|addr| addr.port()).unwrap_or(0); + let to = socket.connected_to.port(); + while i < MSG_COUNT { - assert_eq!(iotry!(socket.send_to(&TX_BUF)), TX_BUF.len()); + let tx_buf = make_buf(i); + assert_eq!(iotry!(socket.send_to(&tx_buf)), tx_buf.len()); let mut buf = [0; 10]; + match socket.recv_from(&mut buf) { Ok((cnt, _)) => { - if socket.state == SocketState::Connected { - assert_eq!(cnt, 10); - assert_eq!(buf, TX_BUF); - } else { - println!("socket is in an invliad state of {:?} from {:?} to {:?}", - socket.state, - socket.socket.local_addr(), - socket.connected_to); + if cnt == 0 { + if socket.state != SocketState::Connected { + panic!("socket is in an invalid state \"{:?}\" from {:?} to {:?}", + socket.state, from, to); + } } - } - Err(ref err) if err.kind() == ErrorKind::NotConnected && i == MSG_COUNT - 1 => { - // This is OK as it can happen on a congested network. - println!("connection not established due to a congested network"); - break; - } - Err(_err) => { - println!("failed in sending from {:?} to {:?}", - socket.socket.local_addr(), - socket.connected_to); - // panic!("Recv error {:?}", err); + assert_eq!(cnt, 10); + if buf != make_buf(i) { + panic!("expected {:?} but received {:?} in recv step {}", + make_buf(i), + buf, + i); + } + }, + Err(err) => { + panic!("Recv error {:?}; from {:?} to {:?}", err, from, to); } } i += 1; } } + for i in 0..100 { println!("------ Testing Network iteration {}", i); test_network(sequential_exchange); @@ -3001,18 +3042,31 @@ mod test { #[test] fn test_network_with_timeout() { - static MSG_COUNT: usize = NETWORK_MSG_COUNT; - static TX_BUF: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + static MSG_COUNT: usize = NETWORK_MSG_COUNT; + + fn make_buf(i: usize) -> [u8; 10] { + let mut buf = [0; 10]; + for j in 0..10 { + buf[j] = (i + j) as u8; + } + buf + } fn timeout_exchange(socket: &mut UtpSocket) { socket.set_read_timeout(Some(50)); let mut recv_cnt = 0; let mut send_cnt = 0; + + let from = socket.socket.local_addr().map(|addr| addr.port()).unwrap_or(0); + let to = socket.connected_to.port(); + loop { if send_cnt < MSG_COUNT { - match socket.send_to(&TX_BUF) { + let tx_buf = make_buf(send_cnt); + + match socket.send_to(&tx_buf) { Ok(cnt) => { - assert_eq!(cnt, TX_BUF.len()); + assert_eq!(cnt, tx_buf.len()); send_cnt += 1; } Err(ref e) if e.kind() == ErrorKind::TimedOut => {} @@ -3022,36 +3076,87 @@ mod test { } } if recv_cnt < MSG_COUNT { + let exp_buf = make_buf(recv_cnt); + let mut buf = [0; 10]; match socket.recv_from(&mut buf) { Ok((cnt, _)) => { - recv_cnt += 1; if cnt == 0 { - // Zero size msg will be returned if the socket is in Closed state - println!("received a message size of zero"); - continue; + if socket.state != SocketState::Connected { + panic!("socket is in an invalid state \"{:?}\" \ + from {:?} to {:?} in receive #{}", + socket.state, from, to, recv_cnt); + } } else { - assert_eq!(cnt, TX_BUF.len()); - assert_eq!(buf, TX_BUF); + assert_eq!(cnt, exp_buf.len()); + assert_eq!(buf, exp_buf); + recv_cnt += 1; } - } - Err(ref e) if e.kind() == ErrorKind::TimedOut => {} - Err(ref e) if e.kind() == ErrorKind::NotConnected && - send_cnt == MSG_COUNT => { - break; - } + }, + Err(ref e) if e.kind() == ErrorKind::TimedOut => { + }, Err(e) => { panic!("{:?} recv_cnt={} send_cnt={}", e, recv_cnt, send_cnt); } } } + if send_cnt == MSG_COUNT && recv_cnt == MSG_COUNT { break; } } } - test_network(timeout_exchange); + + for i in 0..100 { + println!("------ Testing Network iteration {}", i); + test_network(timeout_exchange); + } + } + + #[test] + fn test_send_client_to_server() { + let listener = iotry!(UtpListener::bind("127.0.0.1:0")); + let server_addr = iotry!(listener.local_addr()); + + static TX_BUF: [u8; 10] = [0,1,2,3,4,5,6,7,8,9]; + + let client_t = thread::spawn(move || { + let mut client = iotry!(UtpSocket::connect(server_addr)); + assert_eq!(iotry!(client.send_to(&TX_BUF)), TX_BUF.len()); + }); + + let mut server = iotry!(listener.accept()).0; + + let mut buf = [0; 10]; + iotry!(server.recv_from(&mut buf)); + assert_eq!(buf, TX_BUF); + + assert!(client_t.join().is_ok()); + } + + // Test data exchange + #[test] + fn test_send_server_to_client() { + let listener = iotry!(UtpListener::bind("127.0.0.1:0")); + let server_addr = iotry!(listener.local_addr()); + + static TX_BUF: [u8; 10] = [0,1,2,3,4,5,6,7,8,9]; + + let client_t = thread::spawn(move || { + let mut client = iotry!(UtpSocket::connect(server_addr)); + let mut buf = [0; 10]; + iotry!(client.recv_from(&mut buf)); + assert_eq!(buf, TX_BUF); + }); + + let mut server = iotry!(listener.accept()).0; + + assert_eq!(iotry!(server.send_to(&TX_BUF)), TX_BUF.len()); + let fr = server.flush(); + assert!(fr.is_ok()); + + assert!(client_t.join().is_ok()); } // Test data exchange @@ -3076,6 +3181,7 @@ mod test { let mut buf = [0; 10]; iotry!(server.recv_from(&mut buf)); assert_eq!(buf, TX_BUF); + let _ = server.flush(); assert!(client_t.join().is_ok()); } diff --git a/src/util.rs b/src/util.rs index cc736825d..74a915351 100644 --- a/src/util.rs +++ b/src/util.rs @@ -24,6 +24,17 @@ pub fn abs_diff(a: u32, b: u32) -> u32 { } } +pub struct Sequence; + +impl Sequence { + pub fn less(a: u16, b: u16) -> bool { + const MIDDLE_VALUE: u16 = ::std::u16::MAX / 2; + + ((b > a) && (b - a <= MIDDLE_VALUE)) || + ((b < a) && (a - b > MIDDLE_VALUE)) + } +} + #[cfg(test)] mod test { use super::ewma;