fix message loss with requests with multiple responses
This commit is contained in:
@@ -74,6 +74,7 @@ pub(crate) trait PluginConnection: Send + Sync + Debug {
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum PluginConnectionError {
|
||||
SendRequest(String, PluginRequestType, Box<dyn Error + Send + Sync>),
|
||||
Timeout(String, PluginRequestType),
|
||||
ReadResponse(String, PluginRequestType, Box<dyn Error + Send + Sync>),
|
||||
DecodeResponse(String, PluginRequestType, Option<DecodeError>),
|
||||
InvalidMessageLength,
|
||||
@@ -85,6 +86,9 @@ impl Display for PluginConnectionError {
|
||||
Self::SendRequest(plugin_name, req_type, e) => f.write_fmt(format_args!(
|
||||
"Can't send request ({req_type}) to plugin {plugin_name}: {e}"
|
||||
)),
|
||||
Self::Timeout(plugin_name, req_type) => f.write_fmt(format_args!(
|
||||
"Timed out waiting for response to {req_type} request to plugin {plugin_name}"
|
||||
)),
|
||||
Self::ReadResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!(
|
||||
"Can't read response ({resp_type}) from plugin {plugin_name}: {e}"
|
||||
)),
|
||||
@@ -101,6 +105,7 @@ impl Error for PluginConnectionError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
match self {
|
||||
Self::SendRequest(_, _, e) => Some(e.as_ref()),
|
||||
Self::Timeout(_, _) => None,
|
||||
Self::ReadResponse(_, _, e) => Some(e.as_ref()),
|
||||
Self::DecodeResponse(_, _, e) => e.as_ref().map(|e| e as &dyn Error),
|
||||
Self::InvalidMessageLength => None,
|
||||
|
||||
@@ -55,13 +55,19 @@ pub(super) async fn initialize_stdio_plugin(
|
||||
Ok(plugin)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ResponseListener {
|
||||
Oneshot(oneshot::Sender<proto::Response>),
|
||||
Multishot(mpsc::Sender<proto::Response>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct StdioPluginConnection {
|
||||
plugin: Arc<Mutex<LoadedPlugin>>,
|
||||
process: Mutex<Child>,
|
||||
buffered_stdout: Mutex<BufReader<ChildStdout>>,
|
||||
next_request_id: Mutex<u32>,
|
||||
pending_requests: Mutex<HashMap<u32, oneshot::Sender<proto::Response>>>,
|
||||
pending_requests: Mutex<HashMap<u32, ResponseListener>>,
|
||||
}
|
||||
|
||||
impl StdioPluginConnection {
|
||||
@@ -103,18 +109,29 @@ impl StdioPluginConnection {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match self
|
||||
.pending_requests
|
||||
.lock()
|
||||
.await
|
||||
.remove(&response.request_id)
|
||||
{
|
||||
Some(sender) => match sender.send(response) {
|
||||
let mut pending_req_lock = self.pending_requests.lock().await;
|
||||
match pending_req_lock.get(&response.request_id) {
|
||||
Some(ResponseListener::Oneshot(_)) => {
|
||||
// We must remove to be able to send the response to the oneshot channel
|
||||
if let Some(ResponseListener::Oneshot(sender)) = pending_req_lock.remove(&response.request_id) {
|
||||
match sender.send(response) {
|
||||
Ok(()) => {}
|
||||
Err(response) => {
|
||||
log::warn!(
|
||||
"Dropping response with request_id {} from plugin {}",
|
||||
response.request_id,
|
||||
self.plugin.lock().await.plugin_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(ResponseListener::Multishot(sender)) => match sender.send(response).await {
|
||||
Ok(()) => {}
|
||||
Err(response) => {
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"Dropping response with request_id {} from plugin {}",
|
||||
response.request_id,
|
||||
e.0.request_id,
|
||||
self.plugin.lock().await.plugin_id
|
||||
);
|
||||
}
|
||||
@@ -155,21 +172,6 @@ impl StdioPluginConnection {
|
||||
reader.read_exact(&mut message).await?;
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
async fn await_response_to(
|
||||
&self,
|
||||
request_id: u32,
|
||||
timeout: Duration,
|
||||
) -> Result<proto::Response, Box<dyn Error + Send + Sync>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.pending_requests.lock().await.insert(request_id, tx);
|
||||
|
||||
match tokio::time::timeout(timeout, rx).await {
|
||||
Ok(Ok(response)) => Ok(response),
|
||||
Ok(Err(e)) => Err(Box::new(e)),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -178,20 +180,31 @@ impl PluginConnection for StdioPluginConnection {
|
||||
&self,
|
||||
config: String,
|
||||
) -> Result<proto::PluginInitializeResponse, PluginConnectionError> {
|
||||
// Generate request id
|
||||
let request_id = {
|
||||
let mut r = self.next_request_id.lock().await;
|
||||
let id = *r;
|
||||
*r += 1;
|
||||
id
|
||||
};
|
||||
|
||||
// Encode request
|
||||
let request = proto::Request {
|
||||
request_id,
|
||||
req: Some(proto::request::Req::InitializeReq(
|
||||
proto::PluginInitializeRequest { config },
|
||||
)),
|
||||
req: Some(proto::request::Req::InitializeReq(proto::PluginInitializeRequest {
|
||||
config,
|
||||
})),
|
||||
}
|
||||
.encode_length_delimited_to_vec();
|
||||
|
||||
// Add response listener
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.pending_requests
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id, ResponseListener::Oneshot(tx));
|
||||
|
||||
// Send request
|
||||
if let Err(e) = self
|
||||
.process
|
||||
.lock()
|
||||
@@ -209,16 +222,20 @@ impl PluginConnection for StdioPluginConnection {
|
||||
));
|
||||
}
|
||||
|
||||
let response = match self
|
||||
.await_response_to(request_id, Duration::from_secs(10))
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
// Receive response
|
||||
let response = match tokio::time::timeout(Duration::from_secs(10), rx).await {
|
||||
Ok(Ok(response)) => response,
|
||||
Ok(Err(e)) => {
|
||||
return Err(PluginConnectionError::ReadResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::Initialize,
|
||||
e,
|
||||
Box::new(e),
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(PluginConnectionError::Timeout(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::Initialize,
|
||||
));
|
||||
}
|
||||
};
|
||||
@@ -233,23 +250,30 @@ impl PluginConnection for StdioPluginConnection {
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_plugin_command_list(
|
||||
&self,
|
||||
) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
|
||||
async fn request_plugin_command_list(&self) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
|
||||
// Generate request id
|
||||
let request_id = {
|
||||
let mut r = self.next_request_id.lock().await;
|
||||
let id = *r;
|
||||
*r += 1;
|
||||
id
|
||||
};
|
||||
|
||||
// Encode request
|
||||
let request = proto::Request {
|
||||
request_id,
|
||||
req: Some(proto::request::Req::CommandListReq(
|
||||
proto::PluginCommandListRequest {},
|
||||
)),
|
||||
req: Some(proto::request::Req::CommandListReq(proto::PluginCommandListRequest {})),
|
||||
}
|
||||
.encode_length_delimited_to_vec();
|
||||
|
||||
// Add response listener
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.pending_requests
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id, ResponseListener::Oneshot(tx));
|
||||
|
||||
// Send request
|
||||
if let Err(e) = self
|
||||
.process
|
||||
.lock()
|
||||
@@ -267,16 +291,20 @@ impl PluginConnection for StdioPluginConnection {
|
||||
));
|
||||
}
|
||||
|
||||
let response = match self
|
||||
.await_response_to(request_id, Duration::from_secs(10))
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
// Receive response
|
||||
let response = match tokio::time::timeout(Duration::from_secs(10), rx).await {
|
||||
Ok(Ok(response)) => response,
|
||||
Ok(Err(e)) => {
|
||||
return Err(PluginConnectionError::ReadResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::CommandList,
|
||||
e,
|
||||
Box::new(e),
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(PluginConnectionError::Timeout(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::CommandList,
|
||||
));
|
||||
}
|
||||
};
|
||||
@@ -296,28 +324,34 @@ impl PluginConnection for StdioPluginConnection {
|
||||
command_id: String,
|
||||
issuer_id: String,
|
||||
argv: Vec<String>,
|
||||
) -> Result<
|
||||
mpsc::Receiver<Result<proto::CommandReply, PluginConnectionError>>,
|
||||
PluginConnectionError,
|
||||
> {
|
||||
) -> Result<mpsc::Receiver<Result<proto::CommandReply, PluginConnectionError>>, PluginConnectionError> {
|
||||
// Generate request id
|
||||
let request_id = {
|
||||
let mut r = self.next_request_id.lock().await;
|
||||
let id = *r;
|
||||
*r += 1;
|
||||
id
|
||||
};
|
||||
|
||||
// Encode request
|
||||
let request = proto::Request {
|
||||
request_id,
|
||||
req: Some(proto::request::Req::ExecuteReq(
|
||||
proto::PluginExecuteRequest {
|
||||
command_id,
|
||||
issuer_id,
|
||||
arg_vector: argv,
|
||||
},
|
||||
)),
|
||||
req: Some(proto::request::Req::ExecuteReq(proto::PluginExecuteRequest {
|
||||
command_id,
|
||||
issuer_id,
|
||||
arg_vector: argv,
|
||||
})),
|
||||
}
|
||||
.encode_length_delimited_to_vec();
|
||||
|
||||
// Add response listener
|
||||
let (tx, mut rx) = mpsc::channel(4);
|
||||
self.pending_requests
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id, ResponseListener::Multishot(tx));
|
||||
|
||||
// Send request
|
||||
if let Err(e) = self
|
||||
.process
|
||||
.lock()
|
||||
@@ -335,26 +369,14 @@ impl PluginConnection for StdioPluginConnection {
|
||||
));
|
||||
}
|
||||
|
||||
let (tx, rx) = mpsc::channel(4);
|
||||
|
||||
// Receive responses
|
||||
let (tx_out, rx_out) = mpsc::channel(4);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let response = match self
|
||||
.await_response_to(request_id, Duration::from_secs(600))
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
if let Err(e) = tx
|
||||
.send(Err(PluginConnectionError::ReadResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::Execute,
|
||||
e,
|
||||
)))
|
||||
.await
|
||||
{
|
||||
log::error!("Cannot send error notification to another task: {e}");
|
||||
}
|
||||
let response = match rx.recv().await {
|
||||
Some(response) => response,
|
||||
None => {
|
||||
log::error!("Response channel for request {request_id} closed unexpectedly");
|
||||
break;
|
||||
}
|
||||
};
|
||||
@@ -365,7 +387,8 @@ impl PluginConnection for StdioPluginConnection {
|
||||
Some(proto::command_reply::Reply::End(_)) => true,
|
||||
_ => false,
|
||||
};
|
||||
if let Err(e) = tx.send(Ok(reply)).await {
|
||||
log::trace!("reply sent: Ok({:?})", &reply);
|
||||
if let Err(e) = tx_out.send(Ok(reply)).await {
|
||||
log::error!("Cannot send command reply to another task: {e}");
|
||||
}
|
||||
if end {
|
||||
@@ -373,7 +396,7 @@ impl PluginConnection for StdioPluginConnection {
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if let Err(e) = tx
|
||||
if let Err(e) = tx_out
|
||||
.send(Err(PluginConnectionError::DecodeResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::Execute,
|
||||
@@ -387,8 +410,10 @@ impl PluginConnection for StdioPluginConnection {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Remove response listener
|
||||
self.pending_requests.lock().await.remove_entry(&request_id);
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
Ok(rx_out)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user