Skip to content

Commit

Permalink
feat: use DuplexStream instead of UnixStream to communicate with …
Browse files Browse the repository at this point in the history
…workers (#320)

* perf: use `DuplexStream` instead of `UnixStream` to communicate with workers

* stamp: add an example for chunked transfer encoding

* perf: add `--tcp-nodelay` cli flag and disable nagle's algorithm by default

* stamp: typo
  • Loading branch information
nyannyacha authored Apr 19, 2024
1 parent 9b98ac7 commit 4e53e2a
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 115 deletions.
24 changes: 12 additions & 12 deletions crates/base/src/deno_runtime.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::inspector_server::Inspector;
use crate::rt_worker::rt;
use crate::rt_worker::supervisor::{CPUUsage, CPUUsageMetrics};
use crate::rt_worker::worker::UnixStreamEntry;
use crate::rt_worker::worker::DuplexStreamEntry;
use crate::utils::units::{bytes_to_display, mib_to_bytes};

use anyhow::{anyhow, bail, Context, Error};
Expand Down Expand Up @@ -33,7 +33,6 @@ use std::borrow::Cow;
use std::collections::HashMap;
use std::ffi::c_void;
use std::fmt;
use std::os::fd::RawFd;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
Expand Down Expand Up @@ -525,7 +524,7 @@ impl DenoRuntime {
}

if conf.is_main_worker() || conf.is_user_worker() {
op_state.put::<HashMap<RawFd, watch::Receiver<ConnSync>>>(HashMap::new());
op_state.put::<HashMap<usize, watch::Receiver<ConnSync>>>(HashMap::new());
}

if conf.is_user_worker() {
Expand Down Expand Up @@ -596,14 +595,15 @@ impl DenoRuntime {

pub async fn run(
&mut self,
unix_stream_rx: mpsc::UnboundedReceiver<UnixStreamEntry>,
duplex_stream_rx: mpsc::UnboundedReceiver<DuplexStreamEntry>,
maybe_cpu_usage_metrics_tx: Option<mpsc::UnboundedSender<CPUUsageMetrics>>,
name: Option<String>,
) -> (Result<(), Error>, i64) {
{
let op_state_rc = self.js_runtime.op_state();
let mut op_state = op_state_rc.borrow_mut();
op_state.put::<mpsc::UnboundedReceiver<UnixStreamEntry>>(unix_stream_rx);

op_state.put::<mpsc::UnboundedReceiver<DuplexStreamEntry>>(duplex_stream_rx);

if self.conf.is_main_worker() {
op_state.put::<mpsc::UnboundedSender<UserWorkerMsgs>>(
Expand Down Expand Up @@ -887,7 +887,7 @@ extern "C" fn mem_check_gc_prologue_callback_fn(
#[cfg(test)]
mod test {
use crate::deno_runtime::DenoRuntime;
use crate::rt_worker::worker::UnixStreamEntry;
use crate::rt_worker::worker::DuplexStreamEntry;
use deno_core::{FastString, ModuleCodeString, PollEventLoopOptions};
use sb_graph::emitter::EmitterFactory;
use sb_graph::{generate_binary_eszip, EszipPayloadKind};
Expand Down Expand Up @@ -1467,8 +1467,8 @@ mod test {
let mut user_rt =
create_basic_user_runtime("./test_cases/array_buffers", 20, 1000, &[]).await;

let (_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (result, _) = user_rt.run(unix_stream_rx, None, None).await;
let (_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (result, _) = user_rt.run(duplex_stream_rx, None, None).await;

assert!(result.is_ok(), "expected no errors");

Expand All @@ -1482,8 +1482,8 @@ mod test {
let mut user_rt =
create_basic_user_runtime("./test_cases/array_buffers", 15, 1000, &[]).await;

let (_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (result, _) = user_rt.run(unix_stream_rx, None, None).await;
let (_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (result, _) = user_rt.run(duplex_stream_rx, None, None).await;

match result {
Err(err) => {
Expand All @@ -1501,7 +1501,7 @@ mod test {
memory_limit_mb: u64,
worker_timeout_ms: u64,
) {
let (_unix_stream_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (_duplex_stream_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (callback_tx, mut callback_rx) = mpsc::unbounded_channel::<()>();
let mut user_rt =
create_basic_user_runtime(path, memory_limit_mb, worker_timeout_ms, static_patterns)
Expand All @@ -1518,7 +1518,7 @@ mod test {
});

let wait_fut = async move {
let (result, _) = user_rt.run(unix_stream_rx, None, None).await;
let (result, _) = user_rt.run(duplex_stream_rx, None, None).await;

assert_eq!(
result.unwrap_err().to_string(),
Expand Down
6 changes: 3 additions & 3 deletions crates/base/src/rt_worker/implementation/default_handler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::deno_runtime::DenoRuntime;
use crate::rt_worker::supervisor::CPUUsageMetrics;
use crate::rt_worker::worker::{HandleCreationType, UnixStreamEntry, Worker, WorkerHandler};
use crate::rt_worker::worker::{DuplexStreamEntry, HandleCreationType, Worker, WorkerHandler};
use anyhow::Error;
use event_worker::events::{BootFailureEvent, PseudoEvent, UncaughtExceptionEvent, WorkerEvents};
use log::error;
Expand All @@ -19,14 +19,14 @@ impl WorkerHandler for Worker {
fn handle_creation<'r>(
&self,
created_rt: &'r mut DenoRuntime,
unix_stream_rx: UnboundedReceiver<UnixStreamEntry>,
duplex_stream_rx: UnboundedReceiver<DuplexStreamEntry>,
termination_event_rx: Receiver<WorkerEvents>,
maybe_cpu_usage_metrics_tx: Option<UnboundedSender<CPUUsageMetrics>>,
name: Option<String>,
) -> HandleCreationType<'r> {
let run_worker_rt = async move {
match created_rt
.run(unix_stream_rx, maybe_cpu_usage_metrics_tx, name)
.run(duplex_stream_rx, maybe_cpu_usage_metrics_tx, name)
.await
{
// if the error is execution terminated, check termination event reason
Expand Down
18 changes: 9 additions & 9 deletions crates/base/src/rt_worker/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use sb_workers::context::{UserWorkerMsgs, WorkerContextInitOpts};
use std::any::Any;
use std::future::{pending, Future};
use std::pin::Pin;
use tokio::net::UnixStream;
use tokio::io;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot::{Receiver, Sender};
use tokio::sync::{oneshot, watch};
Expand All @@ -44,14 +44,14 @@ pub struct Worker {
}

pub type HandleCreationType<'r> = Pin<Box<dyn Future<Output = Result<WorkerEvents, Error>> + 'r>>;
pub type UnixStreamEntry = (UnixStream, Option<watch::Receiver<ConnSync>>);
pub type DuplexStreamEntry = (io::DuplexStream, Option<watch::Receiver<ConnSync>>);

pub trait WorkerHandler: Send {
fn handle_error(&self, error: Error) -> Result<WorkerEvents, Error>;
fn handle_creation<'r>(
&self,
created_rt: &'r mut DenoRuntime,
unix_stream_rx: UnboundedReceiver<UnixStreamEntry>,
duplex_stream_rx: UnboundedReceiver<DuplexStreamEntry>,
termination_event_rx: Receiver<WorkerEvents>,
maybe_cpu_metrics_tx: Option<UnboundedSender<CPUUsageMetrics>>,
name: Option<String>,
Expand Down Expand Up @@ -91,9 +91,9 @@ impl Worker {
pub fn start(
&self,
mut opts: WorkerContextInitOpts,
unix_stream_pair: (
UnboundedSender<UnixStreamEntry>,
UnboundedReceiver<UnixStreamEntry>,
duplex_stream_pair: (
UnboundedSender<DuplexStreamEntry>,
UnboundedReceiver<DuplexStreamEntry>,
),
booter_signal: Sender<Result<MetricSource, Error>>,
termination_token: Option<TerminationToken>,
Expand All @@ -104,7 +104,7 @@ impl Worker {
let event_metadata = self.event_metadata.clone();
let supervisor_policy = self.supervisor_policy;

let (unix_stream_tx, unix_stream_rx) = unix_stream_pair;
let (duplex_stream_tx, duplex_stream_rx) = duplex_stream_pair;
let events_msg_tx = self.events_msg_tx.clone();
let pool_msg_tx = self.pool_msg_tx.clone();

Expand Down Expand Up @@ -244,7 +244,7 @@ impl Worker {
let result = method_cloner
.handle_creation(
&mut runtime,
unix_stream_rx,
duplex_stream_rx,
termination_event_rx,
maybe_cpu_usage_metrics_tx,
Some(worker_name),
Expand Down Expand Up @@ -283,7 +283,7 @@ impl Worker {
}
};

drop(unix_stream_tx);
drop(duplex_stream_tx);

match result {
Ok(event) => {
Expand Down
31 changes: 14 additions & 17 deletions crates/base/src/rt_worker/worker_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ use std::io::ErrorKind;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::copy_bidirectional;
use tokio::net::{TcpStream, UnixStream};
use tokio::io::{self, copy_bidirectional};
use tokio::net::TcpStream;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, oneshot, watch, Mutex};
use tokio_rustls::server::TlsStream;
Expand All @@ -42,7 +42,7 @@ use uuid::Uuid;

use super::rt;
use super::supervisor::{self, CPUTimerParam, CPUUsageMetrics};
use super::worker::UnixStreamEntry;
use super::worker::DuplexStreamEntry;
use super::worker_pool::{SupervisorPolicy, WorkerPoolPolicy};

#[derive(Clone)]
Expand Down Expand Up @@ -93,28 +93,25 @@ impl TerminationToken {
}

async fn handle_request(
unix_stream_tx: mpsc::UnboundedSender<UnixStreamEntry>,
duplex_stream_tx: mpsc::UnboundedSender<DuplexStreamEntry>,
msg: WorkerRequestMsg,
) -> Result<(), Error> {
// create a unix socket pair
let (sender_stream, recv_stream) = UnixStream::pair()?;
let (ours, theirs) = io::duplex(1024);
let WorkerRequestMsg {
mut req,
res_tx,
conn_watch,
} = msg;

let _ = unix_stream_tx.send((recv_stream, conn_watch.clone()));
let _ = duplex_stream_tx.send((theirs, conn_watch.clone()));
let req_upgrade_type = get_upgrade_type(req.headers());
let req_upgrade = req_upgrade_type
.clone()
.and_then(|it| Some(it).zip(req.extensions_mut().remove::<OnUpgrade>()));

// send the HTTP request to the worker over Unix stream
let (mut request_sender, connection) = http1::Builder::new()
.writev(true)
.handshake(sender_stream)
.await?;
// send the HTTP request to the worker over duplex stream
let (mut request_sender, connection) =
http1::Builder::new().writev(true).handshake(ours).await?;

let (upgrade_tx, upgrade_rx) = oneshot::channel();

Expand Down Expand Up @@ -177,7 +174,7 @@ async fn handle_request(

async fn relay_upgraded_request_and_response(
downstream: OnUpgrade,
parts: http1::Parts<UnixStream>,
parts: http1::Parts<io::DuplexStream>,
) {
let mut upstream = Upgraded2::new(parts.io, parts.read_buf);
let mut downstream = downstream.await.expect("failed to upgrade request");
Expand All @@ -190,7 +187,7 @@ async fn relay_upgraded_request_and_response(
// `close_notify` before shutdown an upstream if downstream is a
// TLS stream.

// INVARIANT: `UnexpectedEof` due to shutdown `UnixStream` is
// INVARIANT: `UnexpectedEof` due to shutdown `DuplexStream` is
// only expected to occur in the context of `TlsStream`.
panic!("unhandleable unexpected eof");
};
Expand Down Expand Up @@ -516,7 +513,7 @@ pub async fn create_worker<Opt: Into<CreateWorkerArgs>>(
init_opts: Opt,
inspector: Option<Inspector>,
) -> Result<(MetricSource, mpsc::UnboundedSender<WorkerRequestMsg>), Error> {
let (unix_stream_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (duplex_stream_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (worker_boot_result_tx, worker_boot_result_rx) =
oneshot::channel::<Result<MetricSource, Error>>();

Expand All @@ -539,7 +536,7 @@ pub async fn create_worker<Opt: Into<CreateWorkerArgs>>(
if let Some(worker_struct_ref) = downcast_reference {
worker_struct_ref.start(
init_opts,
(unix_stream_tx.clone(), unix_stream_rx),
(duplex_stream_tx.clone(), duplex_stream_rx),
worker_boot_result_tx,
maybe_termination_token.clone(),
inspector,
Expand All @@ -549,7 +546,7 @@ pub async fn create_worker<Opt: Into<CreateWorkerArgs>>(
let (worker_req_tx, mut worker_req_rx) = mpsc::unbounded_channel::<WorkerRequestMsg>();

let worker_req_handle: tokio::task::JoinHandle<Result<(), Error>> = tokio::task::spawn({
let stream_tx = unix_stream_tx;
let stream_tx = duplex_stream_tx;
async move {
while let Some(msg) = worker_req_rx.recv().await {
tokio::task::spawn({
Expand Down
17 changes: 16 additions & 1 deletion crates/base/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ pub struct WorkerEntrypoints {
pub struct ServerFlags {
pub no_module_cache: bool,
pub allow_main_inspector: bool,
pub tcp_nodelay: bool,
pub graceful_exit_deadline_sec: u64,
}

Expand Down Expand Up @@ -446,7 +447,13 @@ impl Server {
}

let event_tx = can_receive_event.then_some(event_tx.clone());
let mut graceful_exit_deadline = flags.graceful_exit_deadline_sec;
let ServerFlags {
tcp_nodelay,
graceful_exit_deadline_sec,
..
} = flags;

let mut graceful_exit_deadline = graceful_exit_deadline_sec;

loop {
let main_worker_req_tx = self.main_worker_req_tx.clone();
Expand All @@ -457,6 +464,10 @@ impl Server {
msg = non_secure_listener.accept() => {
match msg {
Ok((stream, _)) => {
if tcp_nodelay {
let _ = stream.set_nodelay(true);
}

accept_stream(stream, main_worker_req_tx, event_tx, metric_src)
}
Err(e) => error!("socket error: {}", e)
Expand All @@ -473,6 +484,10 @@ impl Server {
} => {
match msg {
Ok((stream, _)) => {
if tcp_nodelay {
let _ = stream.get_ref().0.set_nodelay(true);
}

accept_stream(stream, main_worker_req_tx, event_tx, metric_src);
}
Err(e) => error!("socket error: {}", e)
Expand Down
14 changes: 13 additions & 1 deletion crates/cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use base::deno_runtime::MAYBE_DENO_VERSION;
use base::rt_worker::worker_pool::{SupervisorPolicy, WorkerPoolPolicy};
use base::server::{ServerFlags, Tls, WorkerEntrypoints};
use base::{DecoratorType, InspectorOption};
use clap::builder::{FalseyValueParser, TypedValueParser};
use clap::builder::{BoolishValueParser, FalseyValueParser, TypedValueParser};
use clap::{arg, crate_version, value_parser, ArgAction, ArgGroup, ArgMatches, Command};
use deno_core::url::Url;
use log::warn;
Expand Down Expand Up @@ -145,6 +145,13 @@ fn cli() -> Command {
.action(ArgAction::SetTrue)
)
.arg(arg!(--"static" <Path> "Glob pattern for static files to be included"))
.arg(arg!(--"tcp-nodelay" [BOOL] "Disables Nagle's algorithm")
.num_args(0..=1)
.value_parser(BoolishValueParser::new())
.require_equals(true)
.default_value("true")
.default_missing_value("true")
)
)
.subcommand(
Command::new("bundle")
Expand Down Expand Up @@ -269,6 +276,10 @@ fn main() -> Result<(), anyhow::Error> {
None
};

let tcp_nodelay =sub_matches.get_one::<bool>("tcp-nodelay")
.copied()
.unwrap();

start_server(
ip.as_str(),
port,
Expand Down Expand Up @@ -298,6 +309,7 @@ fn main() -> Result<(), anyhow::Error> {
ServerFlags {
no_module_cache,
allow_main_inspector,
tcp_nodelay,
graceful_exit_deadline_sec: graceful_exit_timeout.unwrap_or(0),
},
None,
Expand Down
Loading

0 comments on commit 4e53e2a

Please sign in to comment.