fix message loss with requests with multiple responses

This commit is contained in:
2026-05-14 00:04:51 +03:00
parent 21556aeb04
commit 5d2f0a5011
4 changed files with 144 additions and 83 deletions

View File

@@ -7,78 +7,107 @@ pub(crate) fn yaml_to_json(yaml: &Value) -> String {
} }
fn write_yaml_to_json(yaml: &Value, out: &mut String) { fn write_yaml_to_json(yaml: &Value, out: &mut String) {
eprintln!("write_yaml_to_json {:?}", yaml);
match yaml { match yaml {
Value::Null => { Value::Null => {
eprintln!("write null");
out.push_str("null"); out.push_str("null");
} }
Value::Bool(b) => { Value::Bool(b) => {
eprintln!("write true/false");
out.push_str(if *b { "true" } else { "false" }); out.push_str(if *b { "true" } else { "false" });
} }
Value::Number(n) => { Value::Number(n) => {
eprintln!("write {n}");
out.push_str(&n.to_string()); out.push_str(&n.to_string());
} }
Value::String(s) => { Value::String(s) => {
eprintln!("write \"");
out.push('"'); out.push('"');
for c in s.encode_utf16() { for c in s.encode_utf16() {
match c { match c {
0x005C | 0x0022 => { 0x005C | 0x0022 => {
eprintln!("write \\{}", char::from_u32(c as u32).unwrap());
out.push('\\'); out.push('\\');
out.push(char::from_u32(c as u32).unwrap()); out.push(char::from_u32(c as u32).unwrap());
} }
0x0008 => { 0x0008 => {
eprintln!("write \\b");
out.push('\\'); out.push('\\');
out.push('b'); out.push('b');
} }
0x0009 => { 0x0009 => {
eprintln!("write \\t");
out.push('\\'); out.push('\\');
out.push('t'); out.push('t');
} }
0x000a => { 0x000a => {
eprintln!("write \\n");
out.push('\\'); out.push('\\');
out.push('n'); out.push('n');
} }
0x000C => { 0x000C => {
eprintln!("write \\f");
out.push('\\'); out.push('\\');
out.push('f'); out.push('f');
} }
0x0000..=0x001F | 0x0080..=0xFFFF => { 0x0000..=0x001F | 0x0080..=0xFFFF => {
eprintln!("write \\u{c:04x}");
out.push_str(&format!("\\u{c:04x}")); out.push_str(&format!("\\u{c:04x}"));
} }
_ => { _ => {
eprintln!("write {}", char::from_u32(c as u32).unwrap());
out.push(char::from_u32(c as u32).unwrap()); out.push(char::from_u32(c as u32).unwrap());
} }
} }
} }
out.push_str(s); // TODO escape eprintln!("write \"");
out.push('"'); out.push('"');
} }
Value::Tagged(_) => {} Value::Tagged(_) => {}
Value::Sequence(list) => { Value::Sequence(list) => {
eprintln!("write [");
out.push('['); out.push('[');
if !list.is_empty() { if !list.is_empty() {
write_yaml_to_json(&list[0], out); write_yaml_to_json(&list[0], out);
for i in 1..list.len() { for i in 1..list.len() {
eprintln!("write ,");
out.push(','); out.push(',');
write_yaml_to_json(&list[i], out); write_yaml_to_json(&list[i], out);
} }
} }
eprintln!("write ]");
out.push(']'); out.push(']');
} }
Value::Mapping(map) => { Value::Mapping(map) => {
eprintln!("write {{");
out.push('{'); out.push('{');
let mut iter = map.iter(); let mut iter = map.iter();
if let Some(kv) = iter.next() { if let Some(kv) = iter.next() {
write_yaml_to_json(&Value::String(yaml_to_json(kv.0)), out); let key = match kv.0 {
Value::String(_) => kv.0,
_ => &Value::String(yaml_to_json(kv.0)),
};
write_yaml_to_json(key, out);
eprintln!("write :");
out.push(':'); out.push(':');
write_yaml_to_json(kv.1, out); write_yaml_to_json(kv.1, out);
} }
for kv in iter { for kv in iter {
eprintln!("write ,");
out.push(','); out.push(',');
write_yaml_to_json(&Value::String(yaml_to_json(kv.0)), out); let key = match kv.0 {
Value::String(_) => kv.0,
_ => &Value::String(yaml_to_json(kv.0)),
};
write_yaml_to_json(key, out);
eprintln!("write :");
out.push(':'); out.push(':');
write_yaml_to_json(kv.1, out); write_yaml_to_json(kv.1, out);
} }
eprintln!("write }}");
out.push('}'); out.push('}');
} }
} }
eprintln!("written {:?}", yaml);
} }

View File

@@ -224,6 +224,7 @@ async fn handle_message(
Ok(mut replies) => { Ok(mut replies) => {
let mut last_message_id = None; let mut last_message_id = None;
while let Some(reply) = replies.recv().await { while let Some(reply) = replies.recv().await {
log::trace!("reply recv'ed: Some({:?})", reply);
match reply { match reply {
Ok(reply) => { Ok(reply) => {
log::debug!( log::debug!(
@@ -274,6 +275,7 @@ async fn handle_message(
} }
} }
} }
log::debug!("{} :: /{} command execution finished", &plugin_name, &cmd_name);
} }
Err(e) => { Err(e) => {
log::error!( log::error!(

View File

@@ -74,6 +74,7 @@ pub(crate) trait PluginConnection: Send + Sync + Debug {
#[derive(Debug)] #[derive(Debug)]
pub(crate) enum PluginConnectionError { pub(crate) enum PluginConnectionError {
SendRequest(String, PluginRequestType, Box<dyn Error + Send + Sync>), SendRequest(String, PluginRequestType, Box<dyn Error + Send + Sync>),
Timeout(String, PluginRequestType),
ReadResponse(String, PluginRequestType, Box<dyn Error + Send + Sync>), ReadResponse(String, PluginRequestType, Box<dyn Error + Send + Sync>),
DecodeResponse(String, PluginRequestType, Option<DecodeError>), DecodeResponse(String, PluginRequestType, Option<DecodeError>),
InvalidMessageLength, InvalidMessageLength,
@@ -85,6 +86,9 @@ impl Display for PluginConnectionError {
Self::SendRequest(plugin_name, req_type, e) => f.write_fmt(format_args!( Self::SendRequest(plugin_name, req_type, e) => f.write_fmt(format_args!(
"Can't send request ({req_type}) to plugin {plugin_name}: {e}" "Can't send request ({req_type}) to plugin {plugin_name}: {e}"
)), )),
Self::Timeout(plugin_name, req_type) => f.write_fmt(format_args!(
"Timed out waiting for response to {req_type} request to plugin {plugin_name}"
)),
Self::ReadResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!( Self::ReadResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!(
"Can't read response ({resp_type}) from plugin {plugin_name}: {e}" "Can't read response ({resp_type}) from plugin {plugin_name}: {e}"
)), )),
@@ -101,6 +105,7 @@ impl Error for PluginConnectionError {
fn source(&self) -> Option<&(dyn Error + 'static)> { fn source(&self) -> Option<&(dyn Error + 'static)> {
match self { match self {
Self::SendRequest(_, _, e) => Some(e.as_ref()), Self::SendRequest(_, _, e) => Some(e.as_ref()),
Self::Timeout(_, _) => None,
Self::ReadResponse(_, _, e) => Some(e.as_ref()), Self::ReadResponse(_, _, e) => Some(e.as_ref()),
Self::DecodeResponse(_, _, e) => e.as_ref().map(|e| e as &dyn Error), Self::DecodeResponse(_, _, e) => e.as_ref().map(|e| e as &dyn Error),
Self::InvalidMessageLength => None, Self::InvalidMessageLength => None,

View File

@@ -55,13 +55,19 @@ pub(super) async fn initialize_stdio_plugin(
Ok(plugin) Ok(plugin)
} }
#[derive(Debug)]
enum ResponseListener {
Oneshot(oneshot::Sender<proto::Response>),
Multishot(mpsc::Sender<proto::Response>),
}
#[derive(Debug)] #[derive(Debug)]
struct StdioPluginConnection { struct StdioPluginConnection {
plugin: Arc<Mutex<LoadedPlugin>>, plugin: Arc<Mutex<LoadedPlugin>>,
process: Mutex<Child>, process: Mutex<Child>,
buffered_stdout: Mutex<BufReader<ChildStdout>>, buffered_stdout: Mutex<BufReader<ChildStdout>>,
next_request_id: Mutex<u32>, next_request_id: Mutex<u32>,
pending_requests: Mutex<HashMap<u32, oneshot::Sender<proto::Response>>>, pending_requests: Mutex<HashMap<u32, ResponseListener>>,
} }
impl StdioPluginConnection { impl StdioPluginConnection {
@@ -103,13 +109,12 @@ impl StdioPluginConnection {
continue; continue;
} }
}; };
match self let mut pending_req_lock = self.pending_requests.lock().await;
.pending_requests match pending_req_lock.get(&response.request_id) {
.lock() Some(ResponseListener::Oneshot(_)) => {
.await // We must remove to be able to send the response to the oneshot channel
.remove(&response.request_id) if let Some(ResponseListener::Oneshot(sender)) = pending_req_lock.remove(&response.request_id) {
{ match sender.send(response) {
Some(sender) => match sender.send(response) {
Ok(()) => {} Ok(()) => {}
Err(response) => { Err(response) => {
log::warn!( log::warn!(
@@ -118,6 +123,18 @@ impl StdioPluginConnection {
self.plugin.lock().await.plugin_id self.plugin.lock().await.plugin_id
); );
} }
}
}
}
Some(ResponseListener::Multishot(sender)) => match sender.send(response).await {
Ok(()) => {}
Err(e) => {
log::warn!(
"Dropping response with request_id {} from plugin {}",
e.0.request_id,
self.plugin.lock().await.plugin_id
);
}
}, },
None => { None => {
continue; continue;
@@ -155,21 +172,6 @@ impl StdioPluginConnection {
reader.read_exact(&mut message).await?; reader.read_exact(&mut message).await?;
Ok(message) Ok(message)
} }
async fn await_response_to(
&self,
request_id: u32,
timeout: Duration,
) -> Result<proto::Response, Box<dyn Error + Send + Sync>> {
let (tx, rx) = oneshot::channel();
self.pending_requests.lock().await.insert(request_id, tx);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(e)) => Err(Box::new(e)),
Err(e) => Err(Box::new(e)),
}
}
} }
#[async_trait] #[async_trait]
@@ -178,20 +180,31 @@ impl PluginConnection for StdioPluginConnection {
&self, &self,
config: String, config: String,
) -> Result<proto::PluginInitializeResponse, PluginConnectionError> { ) -> Result<proto::PluginInitializeResponse, PluginConnectionError> {
// Generate request id
let request_id = { let request_id = {
let mut r = self.next_request_id.lock().await; let mut r = self.next_request_id.lock().await;
let id = *r; let id = *r;
*r += 1; *r += 1;
id id
}; };
// Encode request
let request = proto::Request { let request = proto::Request {
request_id, request_id,
req: Some(proto::request::Req::InitializeReq( req: Some(proto::request::Req::InitializeReq(proto::PluginInitializeRequest {
proto::PluginInitializeRequest { config }, config,
)), })),
} }
.encode_length_delimited_to_vec(); .encode_length_delimited_to_vec();
// Add response listener
let (tx, rx) = oneshot::channel();
self.pending_requests
.lock()
.await
.insert(request_id, ResponseListener::Oneshot(tx));
// Send request
if let Err(e) = self if let Err(e) = self
.process .process
.lock() .lock()
@@ -209,16 +222,20 @@ impl PluginConnection for StdioPluginConnection {
)); ));
} }
let response = match self // Receive response
.await_response_to(request_id, Duration::from_secs(10)) let response = match tokio::time::timeout(Duration::from_secs(10), rx).await {
.await Ok(Ok(response)) => response,
{ Ok(Err(e)) => {
Ok(response) => response,
Err(e) => {
return Err(PluginConnectionError::ReadResponse( return Err(PluginConnectionError::ReadResponse(
self.plugin.lock().await.plugin_id.clone(), self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Initialize, PluginRequestType::Initialize,
e, Box::new(e),
));
}
Err(_) => {
return Err(PluginConnectionError::Timeout(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Initialize,
)); ));
} }
}; };
@@ -233,23 +250,30 @@ impl PluginConnection for StdioPluginConnection {
} }
} }
async fn request_plugin_command_list( async fn request_plugin_command_list(&self) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
&self, // Generate request id
) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
let request_id = { let request_id = {
let mut r = self.next_request_id.lock().await; let mut r = self.next_request_id.lock().await;
let id = *r; let id = *r;
*r += 1; *r += 1;
id id
}; };
// Encode request
let request = proto::Request { let request = proto::Request {
request_id, request_id,
req: Some(proto::request::Req::CommandListReq( req: Some(proto::request::Req::CommandListReq(proto::PluginCommandListRequest {})),
proto::PluginCommandListRequest {},
)),
} }
.encode_length_delimited_to_vec(); .encode_length_delimited_to_vec();
// Add response listener
let (tx, rx) = oneshot::channel();
self.pending_requests
.lock()
.await
.insert(request_id, ResponseListener::Oneshot(tx));
// Send request
if let Err(e) = self if let Err(e) = self
.process .process
.lock() .lock()
@@ -267,16 +291,20 @@ impl PluginConnection for StdioPluginConnection {
)); ));
} }
let response = match self // Receive response
.await_response_to(request_id, Duration::from_secs(10)) let response = match tokio::time::timeout(Duration::from_secs(10), rx).await {
.await Ok(Ok(response)) => response,
{ Ok(Err(e)) => {
Ok(response) => response,
Err(e) => {
return Err(PluginConnectionError::ReadResponse( return Err(PluginConnectionError::ReadResponse(
self.plugin.lock().await.plugin_id.clone(), self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::CommandList, PluginRequestType::CommandList,
e, Box::new(e),
));
}
Err(_) => {
return Err(PluginConnectionError::Timeout(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::CommandList,
)); ));
} }
}; };
@@ -296,28 +324,34 @@ impl PluginConnection for StdioPluginConnection {
command_id: String, command_id: String,
issuer_id: String, issuer_id: String,
argv: Vec<String>, argv: Vec<String>,
) -> Result< ) -> Result<mpsc::Receiver<Result<proto::CommandReply, PluginConnectionError>>, PluginConnectionError> {
mpsc::Receiver<Result<proto::CommandReply, PluginConnectionError>>, // Generate request id
PluginConnectionError,
> {
let request_id = { let request_id = {
let mut r = self.next_request_id.lock().await; let mut r = self.next_request_id.lock().await;
let id = *r; let id = *r;
*r += 1; *r += 1;
id id
}; };
// Encode request
let request = proto::Request { let request = proto::Request {
request_id, request_id,
req: Some(proto::request::Req::ExecuteReq( req: Some(proto::request::Req::ExecuteReq(proto::PluginExecuteRequest {
proto::PluginExecuteRequest {
command_id, command_id,
issuer_id, issuer_id,
arg_vector: argv, arg_vector: argv,
}, })),
)),
} }
.encode_length_delimited_to_vec(); .encode_length_delimited_to_vec();
// Add response listener
let (tx, mut rx) = mpsc::channel(4);
self.pending_requests
.lock()
.await
.insert(request_id, ResponseListener::Multishot(tx));
// Send request
if let Err(e) = self if let Err(e) = self
.process .process
.lock() .lock()
@@ -335,26 +369,14 @@ impl PluginConnection for StdioPluginConnection {
)); ));
} }
let (tx, rx) = mpsc::channel(4); // Receive responses
let (tx_out, rx_out) = mpsc::channel(4);
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
let response = match self let response = match rx.recv().await {
.await_response_to(request_id, Duration::from_secs(600)) Some(response) => response,
.await None => {
{ log::error!("Response channel for request {request_id} closed unexpectedly");
Ok(response) => response,
Err(e) => {
if let Err(e) = tx
.send(Err(PluginConnectionError::ReadResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Execute,
e,
)))
.await
{
log::error!("Cannot send error notification to another task: {e}");
}
break; break;
} }
}; };
@@ -365,7 +387,8 @@ impl PluginConnection for StdioPluginConnection {
Some(proto::command_reply::Reply::End(_)) => true, Some(proto::command_reply::Reply::End(_)) => true,
_ => false, _ => false,
}; };
if let Err(e) = tx.send(Ok(reply)).await { log::trace!("reply sent: Ok({:?})", &reply);
if let Err(e) = tx_out.send(Ok(reply)).await {
log::error!("Cannot send command reply to another task: {e}"); log::error!("Cannot send command reply to another task: {e}");
} }
if end { if end {
@@ -373,7 +396,7 @@ impl PluginConnection for StdioPluginConnection {
} }
} }
_ => { _ => {
if let Err(e) = tx if let Err(e) = tx_out
.send(Err(PluginConnectionError::DecodeResponse( .send(Err(PluginConnectionError::DecodeResponse(
self.plugin.lock().await.plugin_id.clone(), self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Execute, PluginRequestType::Execute,
@@ -387,8 +410,10 @@ impl PluginConnection for StdioPluginConnection {
} }
} }
} }
// Remove response listener
self.pending_requests.lock().await.remove_entry(&request_id);
}); });
Ok(rx) Ok(rx_out)
} }
} }