fix message loss with requests with multiple responses

This commit is contained in:
2026-05-14 00:04:51 +03:00
parent 21556aeb04
commit 5d2f0a5011
4 changed files with 144 additions and 83 deletions

View File

@@ -74,6 +74,7 @@ pub(crate) trait PluginConnection: Send + Sync + Debug {
#[derive(Debug)]
pub(crate) enum PluginConnectionError {
SendRequest(String, PluginRequestType, Box<dyn Error + Send + Sync>),
Timeout(String, PluginRequestType),
ReadResponse(String, PluginRequestType, Box<dyn Error + Send + Sync>),
DecodeResponse(String, PluginRequestType, Option<DecodeError>),
InvalidMessageLength,
@@ -85,6 +86,9 @@ impl Display for PluginConnectionError {
Self::SendRequest(plugin_name, req_type, e) => f.write_fmt(format_args!(
"Can't send request ({req_type}) to plugin {plugin_name}: {e}"
)),
Self::Timeout(plugin_name, req_type) => f.write_fmt(format_args!(
"Timed out waiting for response to {req_type} request to plugin {plugin_name}"
)),
Self::ReadResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!(
"Can't read response ({resp_type}) from plugin {plugin_name}: {e}"
)),
@@ -101,6 +105,7 @@ impl Error for PluginConnectionError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
Self::SendRequest(_, _, e) => Some(e.as_ref()),
Self::Timeout(_, _) => None,
Self::ReadResponse(_, _, e) => Some(e.as_ref()),
Self::DecodeResponse(_, _, e) => e.as_ref().map(|e| e as &dyn Error),
Self::InvalidMessageLength => None,

View File

@@ -55,13 +55,19 @@ pub(super) async fn initialize_stdio_plugin(
Ok(plugin)
}
#[derive(Debug)]
enum ResponseListener {
Oneshot(oneshot::Sender<proto::Response>),
Multishot(mpsc::Sender<proto::Response>),
}
#[derive(Debug)]
struct StdioPluginConnection {
plugin: Arc<Mutex<LoadedPlugin>>,
process: Mutex<Child>,
buffered_stdout: Mutex<BufReader<ChildStdout>>,
next_request_id: Mutex<u32>,
pending_requests: Mutex<HashMap<u32, oneshot::Sender<proto::Response>>>,
pending_requests: Mutex<HashMap<u32, ResponseListener>>,
}
impl StdioPluginConnection {
@@ -103,18 +109,29 @@ impl StdioPluginConnection {
continue;
}
};
match self
.pending_requests
.lock()
.await
.remove(&response.request_id)
{
Some(sender) => match sender.send(response) {
let mut pending_req_lock = self.pending_requests.lock().await;
match pending_req_lock.get(&response.request_id) {
Some(ResponseListener::Oneshot(_)) => {
// We must remove to be able to send the response to the oneshot channel
if let Some(ResponseListener::Oneshot(sender)) = pending_req_lock.remove(&response.request_id) {
match sender.send(response) {
Ok(()) => {}
Err(response) => {
log::warn!(
"Dropping response with request_id {} from plugin {}",
response.request_id,
self.plugin.lock().await.plugin_id
);
}
}
}
}
Some(ResponseListener::Multishot(sender)) => match sender.send(response).await {
Ok(()) => {}
Err(response) => {
Err(e) => {
log::warn!(
"Dropping response with request_id {} from plugin {}",
response.request_id,
e.0.request_id,
self.plugin.lock().await.plugin_id
);
}
@@ -155,21 +172,6 @@ impl StdioPluginConnection {
reader.read_exact(&mut message).await?;
Ok(message)
}
async fn await_response_to(
&self,
request_id: u32,
timeout: Duration,
) -> Result<proto::Response, Box<dyn Error + Send + Sync>> {
let (tx, rx) = oneshot::channel();
self.pending_requests.lock().await.insert(request_id, tx);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(e)) => Err(Box::new(e)),
Err(e) => Err(Box::new(e)),
}
}
}
#[async_trait]
@@ -178,20 +180,31 @@ impl PluginConnection for StdioPluginConnection {
&self,
config: String,
) -> Result<proto::PluginInitializeResponse, PluginConnectionError> {
// Generate request id
let request_id = {
let mut r = self.next_request_id.lock().await;
let id = *r;
*r += 1;
id
};
// Encode request
let request = proto::Request {
request_id,
req: Some(proto::request::Req::InitializeReq(
proto::PluginInitializeRequest { config },
)),
req: Some(proto::request::Req::InitializeReq(proto::PluginInitializeRequest {
config,
})),
}
.encode_length_delimited_to_vec();
// Add response listener
let (tx, rx) = oneshot::channel();
self.pending_requests
.lock()
.await
.insert(request_id, ResponseListener::Oneshot(tx));
// Send request
if let Err(e) = self
.process
.lock()
@@ -209,16 +222,20 @@ impl PluginConnection for StdioPluginConnection {
));
}
let response = match self
.await_response_to(request_id, Duration::from_secs(10))
.await
{
Ok(response) => response,
Err(e) => {
// Receive response
let response = match tokio::time::timeout(Duration::from_secs(10), rx).await {
Ok(Ok(response)) => response,
Ok(Err(e)) => {
return Err(PluginConnectionError::ReadResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Initialize,
e,
Box::new(e),
));
}
Err(_) => {
return Err(PluginConnectionError::Timeout(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Initialize,
));
}
};
@@ -233,23 +250,30 @@ impl PluginConnection for StdioPluginConnection {
}
}
async fn request_plugin_command_list(
&self,
) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
async fn request_plugin_command_list(&self) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
// Generate request id
let request_id = {
let mut r = self.next_request_id.lock().await;
let id = *r;
*r += 1;
id
};
// Encode request
let request = proto::Request {
request_id,
req: Some(proto::request::Req::CommandListReq(
proto::PluginCommandListRequest {},
)),
req: Some(proto::request::Req::CommandListReq(proto::PluginCommandListRequest {})),
}
.encode_length_delimited_to_vec();
// Add response listener
let (tx, rx) = oneshot::channel();
self.pending_requests
.lock()
.await
.insert(request_id, ResponseListener::Oneshot(tx));
// Send request
if let Err(e) = self
.process
.lock()
@@ -267,16 +291,20 @@ impl PluginConnection for StdioPluginConnection {
));
}
let response = match self
.await_response_to(request_id, Duration::from_secs(10))
.await
{
Ok(response) => response,
Err(e) => {
// Receive response
let response = match tokio::time::timeout(Duration::from_secs(10), rx).await {
Ok(Ok(response)) => response,
Ok(Err(e)) => {
return Err(PluginConnectionError::ReadResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::CommandList,
e,
Box::new(e),
));
}
Err(_) => {
return Err(PluginConnectionError::Timeout(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::CommandList,
));
}
};
@@ -296,28 +324,34 @@ impl PluginConnection for StdioPluginConnection {
command_id: String,
issuer_id: String,
argv: Vec<String>,
) -> Result<
mpsc::Receiver<Result<proto::CommandReply, PluginConnectionError>>,
PluginConnectionError,
> {
) -> Result<mpsc::Receiver<Result<proto::CommandReply, PluginConnectionError>>, PluginConnectionError> {
// Generate request id
let request_id = {
let mut r = self.next_request_id.lock().await;
let id = *r;
*r += 1;
id
};
// Encode request
let request = proto::Request {
request_id,
req: Some(proto::request::Req::ExecuteReq(
proto::PluginExecuteRequest {
command_id,
issuer_id,
arg_vector: argv,
},
)),
req: Some(proto::request::Req::ExecuteReq(proto::PluginExecuteRequest {
command_id,
issuer_id,
arg_vector: argv,
})),
}
.encode_length_delimited_to_vec();
// Add response listener
let (tx, mut rx) = mpsc::channel(4);
self.pending_requests
.lock()
.await
.insert(request_id, ResponseListener::Multishot(tx));
// Send request
if let Err(e) = self
.process
.lock()
@@ -335,26 +369,14 @@ impl PluginConnection for StdioPluginConnection {
));
}
let (tx, rx) = mpsc::channel(4);
// Receive responses
let (tx_out, rx_out) = mpsc::channel(4);
tokio::spawn(async move {
loop {
let response = match self
.await_response_to(request_id, Duration::from_secs(600))
.await
{
Ok(response) => response,
Err(e) => {
if let Err(e) = tx
.send(Err(PluginConnectionError::ReadResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Execute,
e,
)))
.await
{
log::error!("Cannot send error notification to another task: {e}");
}
let response = match rx.recv().await {
Some(response) => response,
None => {
log::error!("Response channel for request {request_id} closed unexpectedly");
break;
}
};
@@ -365,7 +387,8 @@ impl PluginConnection for StdioPluginConnection {
Some(proto::command_reply::Reply::End(_)) => true,
_ => false,
};
if let Err(e) = tx.send(Ok(reply)).await {
log::trace!("reply sent: Ok({:?})", &reply);
if let Err(e) = tx_out.send(Ok(reply)).await {
log::error!("Cannot send command reply to another task: {e}");
}
if end {
@@ -373,7 +396,7 @@ impl PluginConnection for StdioPluginConnection {
}
}
_ => {
if let Err(e) = tx
if let Err(e) = tx_out
.send(Err(PluginConnectionError::DecodeResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Execute,
@@ -387,8 +410,10 @@ impl PluginConnection for StdioPluginConnection {
}
}
}
// Remove response listener
self.pending_requests.lock().await.remove_entry(&request_id);
});
Ok(rx)
Ok(rx_out)
}
}