Skip to content

Commit

Permalink
feat: add tcp_timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
spacemeowx2 authored and llc1123 committed Jun 9, 2021
1 parent e4ca77a commit b76f9c5
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ hex = "0.4.3"
redis = { version = "0.20.0", features = ["tokio-comp"] }
tokio-openssl = "0.6.1"
openssl = { version = "0.10.34", features = ["vendored"] }
pin-project-lite = "0.2.6"
4 changes: 4 additions & 0 deletions src/relay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::common::UdpStream;
use crate::inbound::{Inbound, InboundAccept, InboundRequest};
use crate::outbound::Outbound;
use crate::utils::count_stream::CountStream;
use crate::utils::timeout_stream::TimeoutStream;
use anyhow::{anyhow, bail, Context, Error, Result};
use futures::{SinkExt, StreamExt, TryStreamExt};
use log::error;
Expand All @@ -20,6 +21,7 @@ pub struct Relay<I, O> {
inbound: Arc<I>,
outbound: Arc<O>,
tcp_nodelay: bool,
pub tcp_timeout: Option<Duration>,
}

impl<I, O> Relay<I, O>
Expand Down Expand Up @@ -78,6 +80,7 @@ where
inbound: Arc::new(inbound),
outbound: Arc::new(outbound),
tcp_nodelay,
tcp_timeout: None,
}
}
pub async fn serve_trojan(&self, auth: Arc<dyn Auth>) -> Result<()> {
Expand All @@ -91,6 +94,7 @@ where
.context("Set TCP_NODELAY failed")?;

let (stream, sender) = CountStream::new2(stream);
let stream = TimeoutStream::new(stream, self.tcp_timeout);

let inbound = self.inbound.clone();
let outbound = self.outbound.clone();
Expand Down
5 changes: 3 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{sync::Arc, time::Duration};

use crate::{
auth::{Auth, AuthHub},
Expand All @@ -23,7 +23,8 @@ pub async fn start(config: Config) -> Result<()> {
let inbound = TrojanInbound::new(auth_hub.clone(), tls_context, config.trojan).await?;
let outbound = DirectOutbound::new();

let relay = Relay::new(listener, inbound, outbound, config.tls.tcp_nodelay);
let mut relay = Relay::new(listener, inbound, outbound, config.tls.tcp_nodelay);
relay.tcp_timeout = Some(Duration::from_secs(600));

info!("Service started.");

Expand Down
1 change: 1 addition & 0 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pub mod config;
pub mod count_stream;
pub mod logger;
pub mod peekable_stream;
pub mod timeout_stream;
92 changes: 92 additions & 0 deletions src/utils/timeout_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use futures::ready;
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
time::{sleep, Duration, Instant, Sleep},
};

pin_project! {
#[derive(Debug)]
pub struct TimeoutStream<S> {
#[pin]
s: S,
duration: Option<Duration>,
sleep: Pin<Box<Sleep>>,
}
}

impl<S> TimeoutStream<S> {
pub fn new(s: S, duration: Option<Duration>) -> TimeoutStream<S>
where
S: AsyncRead + AsyncWrite,
{
TimeoutStream {
s,
duration,
sleep: Box::pin(sleep(duration.unwrap_or(Duration::from_secs(1)))),
}
}
}

impl<S> AsyncRead for TimeoutStream<S>
where
S: AsyncRead + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.project();
match this.duration {
Some(duration) => match this.s.poll_read(cx, buf) {
Poll::Ready(r) => {
this.sleep.as_mut().reset(Instant::now() + *duration);
Poll::Ready(r)
}
Poll::Pending => {
ready!(this.sleep.as_mut().poll(cx));
Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
}
},
None => this.s.poll_read(cx, buf),
}
}
}

impl<S> AsyncWrite for TimeoutStream<S>
where
S: AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
match this.duration {
Some(duration) => match this.s.poll_write(cx, buf) {
Poll::Ready(r) => {
this.sleep.as_mut().reset(Instant::now() + *duration);
Poll::Ready(r)
}
Poll::Pending => {
ready!(this.sleep.as_mut().poll(cx));
Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
}
},
None => this.s.poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().s.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().s.poll_shutdown(cx)
}
}

0 comments on commit b76f9c5

Please sign in to comment.