fix message loss with requests with multiple responses

This commit is contained in:
2026-05-14 00:04:51 +03:00
parent 21556aeb04
commit 5d2f0a5011
4 changed files with 144 additions and 83 deletions

View File

@@ -7,78 +7,107 @@ pub(crate) fn yaml_to_json(yaml: &Value) -> String {
}
fn write_yaml_to_json(yaml: &Value, out: &mut String) {
eprintln!("write_yaml_to_json {:?}", yaml);
match yaml {
Value::Null => {
eprintln!("write null");
out.push_str("null");
}
Value::Bool(b) => {
eprintln!("write true/false");
out.push_str(if *b { "true" } else { "false" });
}
Value::Number(n) => {
eprintln!("write {n}");
out.push_str(&n.to_string());
}
Value::String(s) => {
eprintln!("write \"");
out.push('"');
for c in s.encode_utf16() {
match c {
0x005C | 0x0022 => {
eprintln!("write \\{}", char::from_u32(c as u32).unwrap());
out.push('\\');
out.push(char::from_u32(c as u32).unwrap());
}
0x0008 => {
eprintln!("write \\b");
out.push('\\');
out.push('b');
}
0x0009 => {
eprintln!("write \\t");
out.push('\\');
out.push('t');
}
0x000a => {
eprintln!("write \\n");
out.push('\\');
out.push('n');
}
0x000C => {
eprintln!("write \\f");
out.push('\\');
out.push('f');
}
0x0000..=0x001F | 0x0080..=0xFFFF => {
eprintln!("write \\u{c:04x}");
out.push_str(&format!("\\u{c:04x}"));
}
_ => {
eprintln!("write {}", char::from_u32(c as u32).unwrap());
out.push(char::from_u32(c as u32).unwrap());
}
}
}
out.push_str(s); // TODO escape
eprintln!("write \"");
out.push('"');
}
Value::Tagged(_) => {}
Value::Sequence(list) => {
eprintln!("write [");
out.push('[');
if !list.is_empty() {
write_yaml_to_json(&list[0], out);
for i in 1..list.len() {
eprintln!("write ,");
out.push(',');
write_yaml_to_json(&list[i], out);
}
}
eprintln!("write ]");
out.push(']');
}
Value::Mapping(map) => {
eprintln!("write {{");
out.push('{');
let mut iter = map.iter();
if let Some(kv) = iter.next() {
write_yaml_to_json(&Value::String(yaml_to_json(kv.0)), out);
let key = match kv.0 {
Value::String(_) => kv.0,
_ => &Value::String(yaml_to_json(kv.0)),
};
write_yaml_to_json(key, out);
eprintln!("write :");
out.push(':');
write_yaml_to_json(kv.1, out);
}
for kv in iter {
eprintln!("write ,");
out.push(',');
write_yaml_to_json(&Value::String(yaml_to_json(kv.0)), out);
let key = match kv.0 {
Value::String(_) => kv.0,
_ => &Value::String(yaml_to_json(kv.0)),
};
write_yaml_to_json(key, out);
eprintln!("write :");
out.push(':');
write_yaml_to_json(kv.1, out);
}
eprintln!("write }}");
out.push('}');
}
}
eprintln!("written {:?}", yaml);
}

View File

@@ -224,6 +224,7 @@ async fn handle_message(
Ok(mut replies) => {
let mut last_message_id = None;
while let Some(reply) = replies.recv().await {
log::trace!("reply recv'ed: Some({:?})", reply);
match reply {
Ok(reply) => {
log::debug!(
@@ -274,6 +275,7 @@ async fn handle_message(
}
}
}
log::debug!("{} :: /{} command execution finished", &plugin_name, &cmd_name);
}
Err(e) => {
log::error!(

View File

@@ -74,6 +74,7 @@ pub(crate) trait PluginConnection: Send + Sync + Debug {
#[derive(Debug)]
pub(crate) enum PluginConnectionError {
SendRequest(String, PluginRequestType, Box<dyn Error + Send + Sync>),
Timeout(String, PluginRequestType),
ReadResponse(String, PluginRequestType, Box<dyn Error + Send + Sync>),
DecodeResponse(String, PluginRequestType, Option<DecodeError>),
InvalidMessageLength,
@@ -85,6 +86,9 @@ impl Display for PluginConnectionError {
Self::SendRequest(plugin_name, req_type, e) => f.write_fmt(format_args!(
"Can't send request ({req_type}) to plugin {plugin_name}: {e}"
)),
Self::Timeout(plugin_name, req_type) => f.write_fmt(format_args!(
"Timed out waiting for response to {req_type} request to plugin {plugin_name}"
)),
Self::ReadResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!(
"Can't read response ({resp_type}) from plugin {plugin_name}: {e}"
)),
@@ -101,6 +105,7 @@ impl Error for PluginConnectionError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
Self::SendRequest(_, _, e) => Some(e.as_ref()),
Self::Timeout(_, _) => None,
Self::ReadResponse(_, _, e) => Some(e.as_ref()),
Self::DecodeResponse(_, _, e) => e.as_ref().map(|e| e as &dyn Error),
Self::InvalidMessageLength => None,

View File

@@ -55,13 +55,19 @@ pub(super) async fn initialize_stdio_plugin(
Ok(plugin)
}
#[derive(Debug)]
enum ResponseListener {
Oneshot(oneshot::Sender<proto::Response>),
Multishot(mpsc::Sender<proto::Response>),
}
#[derive(Debug)]
struct StdioPluginConnection {
plugin: Arc<Mutex<LoadedPlugin>>,
process: Mutex<Child>,
buffered_stdout: Mutex<BufReader<ChildStdout>>,
next_request_id: Mutex<u32>,
pending_requests: Mutex<HashMap<u32, oneshot::Sender<proto::Response>>>,
pending_requests: Mutex<HashMap<u32, ResponseListener>>,
}
impl StdioPluginConnection {
@@ -103,18 +109,29 @@ impl StdioPluginConnection {
continue;
}
};
match self
.pending_requests
.lock()
.await
.remove(&response.request_id)
{
Some(sender) => match sender.send(response) {
let mut pending_req_lock = self.pending_requests.lock().await;
match pending_req_lock.get(&response.request_id) {
Some(ResponseListener::Oneshot(_)) => {
// We must remove to be able to send the response to the oneshot channel
if let Some(ResponseListener::Oneshot(sender)) = pending_req_lock.remove(&response.request_id) {
match sender.send(response) {
Ok(()) => {}
Err(response) => {
log::warn!(
"Dropping response with request_id {} from plugin {}",
response.request_id,
self.plugin.lock().await.plugin_id
);
}
}
}
}
Some(ResponseListener::Multishot(sender)) => match sender.send(response).await {
Ok(()) => {}
Err(response) => {
Err(e) => {
log::warn!(
"Dropping response with request_id {} from plugin {}",
response.request_id,
e.0.request_id,
self.plugin.lock().await.plugin_id
);
}
@@ -155,21 +172,6 @@ impl StdioPluginConnection {
reader.read_exact(&mut message).await?;
Ok(message)
}
async fn await_response_to(
&self,
request_id: u32,
timeout: Duration,
) -> Result<proto::Response, Box<dyn Error + Send + Sync>> {
let (tx, rx) = oneshot::channel();
self.pending_requests.lock().await.insert(request_id, tx);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(e)) => Err(Box::new(e)),
Err(e) => Err(Box::new(e)),
}
}
}
#[async_trait]
@@ -178,20 +180,31 @@ impl PluginConnection for StdioPluginConnection {
&self,
config: String,
) -> Result<proto::PluginInitializeResponse, PluginConnectionError> {
// Generate request id
let request_id = {
let mut r = self.next_request_id.lock().await;
let id = *r;
*r += 1;
id
};
// Encode request
let request = proto::Request {
request_id,
req: Some(proto::request::Req::InitializeReq(
proto::PluginInitializeRequest { config },
)),
req: Some(proto::request::Req::InitializeReq(proto::PluginInitializeRequest {
config,
})),
}
.encode_length_delimited_to_vec();
// Add response listener
let (tx, rx) = oneshot::channel();
self.pending_requests
.lock()
.await
.insert(request_id, ResponseListener::Oneshot(tx));
// Send request
if let Err(e) = self
.process
.lock()
@@ -209,16 +222,20 @@ impl PluginConnection for StdioPluginConnection {
));
}
let response = match self
.await_response_to(request_id, Duration::from_secs(10))
.await
{
Ok(response) => response,
Err(e) => {
// Receive response
let response = match tokio::time::timeout(Duration::from_secs(10), rx).await {
Ok(Ok(response)) => response,
Ok(Err(e)) => {
return Err(PluginConnectionError::ReadResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Initialize,
e,
Box::new(e),
));
}
Err(_) => {
return Err(PluginConnectionError::Timeout(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Initialize,
));
}
};
@@ -233,23 +250,30 @@ impl PluginConnection for StdioPluginConnection {
}
}
async fn request_plugin_command_list(
&self,
) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
async fn request_plugin_command_list(&self) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
// Generate request id
let request_id = {
let mut r = self.next_request_id.lock().await;
let id = *r;
*r += 1;
id
};
// Encode request
let request = proto::Request {
request_id,
req: Some(proto::request::Req::CommandListReq(
proto::PluginCommandListRequest {},
)),
req: Some(proto::request::Req::CommandListReq(proto::PluginCommandListRequest {})),
}
.encode_length_delimited_to_vec();
// Add response listener
let (tx, rx) = oneshot::channel();
self.pending_requests
.lock()
.await
.insert(request_id, ResponseListener::Oneshot(tx));
// Send request
if let Err(e) = self
.process
.lock()
@@ -267,16 +291,20 @@ impl PluginConnection for StdioPluginConnection {
));
}
let response = match self
.await_response_to(request_id, Duration::from_secs(10))
.await
{
Ok(response) => response,
Err(e) => {
// Receive response
let response = match tokio::time::timeout(Duration::from_secs(10), rx).await {
Ok(Ok(response)) => response,
Ok(Err(e)) => {
return Err(PluginConnectionError::ReadResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::CommandList,
e,
Box::new(e),
));
}
Err(_) => {
return Err(PluginConnectionError::Timeout(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::CommandList,
));
}
};
@@ -296,28 +324,34 @@ impl PluginConnection for StdioPluginConnection {
command_id: String,
issuer_id: String,
argv: Vec<String>,
) -> Result<
mpsc::Receiver<Result<proto::CommandReply, PluginConnectionError>>,
PluginConnectionError,
> {
) -> Result<mpsc::Receiver<Result<proto::CommandReply, PluginConnectionError>>, PluginConnectionError> {
// Generate request id
let request_id = {
let mut r = self.next_request_id.lock().await;
let id = *r;
*r += 1;
id
};
// Encode request
let request = proto::Request {
request_id,
req: Some(proto::request::Req::ExecuteReq(
proto::PluginExecuteRequest {
command_id,
issuer_id,
arg_vector: argv,
},
)),
req: Some(proto::request::Req::ExecuteReq(proto::PluginExecuteRequest {
command_id,
issuer_id,
arg_vector: argv,
})),
}
.encode_length_delimited_to_vec();
// Add response listener
let (tx, mut rx) = mpsc::channel(4);
self.pending_requests
.lock()
.await
.insert(request_id, ResponseListener::Multishot(tx));
// Send request
if let Err(e) = self
.process
.lock()
@@ -335,26 +369,14 @@ impl PluginConnection for StdioPluginConnection {
));
}
let (tx, rx) = mpsc::channel(4);
// Receive responses
let (tx_out, rx_out) = mpsc::channel(4);
tokio::spawn(async move {
loop {
let response = match self
.await_response_to(request_id, Duration::from_secs(600))
.await
{
Ok(response) => response,
Err(e) => {
if let Err(e) = tx
.send(Err(PluginConnectionError::ReadResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Execute,
e,
)))
.await
{
log::error!("Cannot send error notification to another task: {e}");
}
let response = match rx.recv().await {
Some(response) => response,
None => {
log::error!("Response channel for request {request_id} closed unexpectedly");
break;
}
};
@@ -365,7 +387,8 @@ impl PluginConnection for StdioPluginConnection {
Some(proto::command_reply::Reply::End(_)) => true,
_ => false,
};
if let Err(e) = tx.send(Ok(reply)).await {
log::trace!("reply sent: Ok({:?})", &reply);
if let Err(e) = tx_out.send(Ok(reply)).await {
log::error!("Cannot send command reply to another task: {e}");
}
if end {
@@ -373,7 +396,7 @@ impl PluginConnection for StdioPluginConnection {
}
}
_ => {
if let Err(e) = tx
if let Err(e) = tx_out
.send(Err(PluginConnectionError::DecodeResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Execute,
@@ -387,8 +410,10 @@ impl PluginConnection for StdioPluginConnection {
}
}
}
// Remove response listener
self.pending_requests.lock().await.remove_entry(&request_id);
});
Ok(rx)
Ok(rx_out)
}
}