Skip to content

Commit

Permalink
try to make io retry generic
Browse files Browse the repository at this point in the history
  • Loading branch information
raffber committed Sep 26, 2024
1 parent e643a4b commit 1806835
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 36 deletions.
44 changes: 44 additions & 0 deletions comsrv/src/iotask.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::future::Future;
use std::time::Duration;

/// This module implements a very simple actor interface to which a request can be sent and a
/// response is returned.
///
Expand All @@ -8,6 +11,7 @@ use anyhow::anyhow;
use async_trait::async_trait;
use tokio::sync::{mpsc, oneshot};
use tokio::task;
use tokio::time::sleep;

/// Trait constraining the `Request` and `Response` associated types of `IoHandler`.
pub trait Message: 'static + Send {}
Expand Down Expand Up @@ -113,3 +117,43 @@ impl<T: 'static + IoHandler> IoTask<T> {
rx.await.map_err(|_| Error::internal(anyhow!("Channel disconnected")))?
}
}

pub async fn io_retry<Ret, Client, FutureHandle, FutureConnect, Handle, Connect>(
client: &mut Option<Client>,
handle: Handle,
connect: Connect,
) -> crate::Result<Ret>
where
Handle: Fn(&mut Client) -> FutureHandle,
FutureHandle: Future<Output = crate::Result<Ret>>,
Connect: Fn() -> FutureConnect,
FutureConnect: Future<Output = crate::Result<Client>>,
{
let mut c = if let Some(client) = client.take() {
client
} else {
connect().await?
};

let ret = handle(&mut c).await;
match ret {
Ok(ret) => {
client.replace(c);
Ok(ret)
}
Err(err) => {
drop(c);
if err.should_retry() {
sleep(Duration::from_millis(100)).await;
let mut c = connect().await?;
let ret = handle(&mut c).await;
if ret.is_ok() {
client.replace(c);
}
Ok(ret?)
} else {
Err(err)
}
}
}
}
65 changes: 29 additions & 36 deletions comsrv/src/transport/vxi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::time::Instant;
use tokio::task::{self, JoinHandle};
use tokio::time::{sleep, Duration};

use crate::iotask::{IoContext, IoHandler, IoTask};
use crate::iotask::{io_retry, IoContext, IoHandler, IoTask};
use crate::{protocol::scpi, Error};
use anyhow::anyhow;
use comsrv_protocol::{ScpiRequest, ScpiResponse};
Expand Down Expand Up @@ -99,14 +99,6 @@ impl Handler {
None
}

async fn connect(&self) -> crate::Result<CoreClient> {
let fut = CoreClient::connect(self.addr);
let ret = tokio::time::timeout(DEFAULT_CONNECTION_TIMEOUT, fut)
.await
.map_err(|_| crate::Error::protocol_timeout())?;
ret.map_err(map_error)
}

async fn handle_request_timeout(
client: &mut CoreClient,
req: ScpiRequest,
Expand Down Expand Up @@ -165,6 +157,24 @@ impl Handler {
}
}

struct Connector {
addr: IpAddr,
}

impl Connector {
fn new(addr: IpAddr) -> Self {
Self { addr }
}

async fn connect(&self) -> crate::Result<CoreClient> {
let fut = CoreClient::connect(self.addr);
let ret = tokio::time::timeout(DEFAULT_CONNECTION_TIMEOUT, fut)
.await
.map_err(|_| crate::Error::protocol_timeout())?;
ret.map_err(map_error)
}
}

#[async_trait]
impl IoHandler for Handler {
type Request = Request;
Expand All @@ -174,37 +184,20 @@ impl IoHandler for Handler {
if let Some(x) = self.drop_check(&req) {
return x;
}
let mut client = if let Some(client) = self.client.take() {
client
} else {
self.connect().await?
};
match req {
Request::Scpi { scpi: req, timeout } => {
let timeout = timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
let ret = Self::handle_request_timeout(&mut client, req.clone(), timeout).await;
match ret {
Ok(ret) => {
self.client.replace(client);
self.spawn_drop_check(ctx);
Ok(Response::Scpi(ret))
}
Err(err) => {
drop(client);
if err.should_retry() {
sleep(Duration::from_millis(100)).await;
let mut client = self.connect().await?;
let ret = Self::handle_request_timeout(&mut client, req, timeout).await;
if ret.is_ok() {
self.client.replace(client);
self.spawn_drop_check(ctx);
}
Ok(Response::Scpi(ret?))
} else {
Err(err)
}
}
let connector = Connector::new(self.addr);
let ret = io_retry(
&mut self.client,
|client| Self::handle_request_timeout(client, req.clone(), timeout),
|| connector.connect(),
)
.await?;
if self.client.is_some() {
self.spawn_drop_check(ctx);
}
Ok(Response::Scpi(ret))
}
Request::DropCheck => Ok(Response::Done),
}
Expand Down

0 comments on commit 1806835

Please sign in to comment.