fix message loss with requests with multiple responses
This commit is contained in:
@@ -7,78 +7,107 @@ pub(crate) fn yaml_to_json(yaml: &Value) -> String {
|
||||
}
|
||||
|
||||
fn write_yaml_to_json(yaml: &Value, out: &mut String) {
|
||||
eprintln!("write_yaml_to_json {:?}", yaml);
|
||||
match yaml {
|
||||
Value::Null => {
|
||||
eprintln!("write null");
|
||||
out.push_str("null");
|
||||
}
|
||||
Value::Bool(b) => {
|
||||
eprintln!("write true/false");
|
||||
out.push_str(if *b { "true" } else { "false" });
|
||||
}
|
||||
Value::Number(n) => {
|
||||
eprintln!("write {n}");
|
||||
out.push_str(&n.to_string());
|
||||
}
|
||||
Value::String(s) => {
|
||||
eprintln!("write \"");
|
||||
out.push('"');
|
||||
for c in s.encode_utf16() {
|
||||
match c {
|
||||
0x005C | 0x0022 => {
|
||||
eprintln!("write \\{}", char::from_u32(c as u32).unwrap());
|
||||
out.push('\\');
|
||||
out.push(char::from_u32(c as u32).unwrap());
|
||||
}
|
||||
0x0008 => {
|
||||
eprintln!("write \\b");
|
||||
out.push('\\');
|
||||
out.push('b');
|
||||
}
|
||||
0x0009 => {
|
||||
eprintln!("write \\t");
|
||||
out.push('\\');
|
||||
out.push('t');
|
||||
}
|
||||
0x000a => {
|
||||
eprintln!("write \\n");
|
||||
out.push('\\');
|
||||
out.push('n');
|
||||
}
|
||||
0x000C => {
|
||||
eprintln!("write \\f");
|
||||
out.push('\\');
|
||||
out.push('f');
|
||||
}
|
||||
0x0000..=0x001F | 0x0080..=0xFFFF => {
|
||||
eprintln!("write \\u{c:04x}");
|
||||
out.push_str(&format!("\\u{c:04x}"));
|
||||
}
|
||||
_ => {
|
||||
eprintln!("write {}", char::from_u32(c as u32).unwrap());
|
||||
out.push(char::from_u32(c as u32).unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
out.push_str(s); // TODO escape
|
||||
eprintln!("write \"");
|
||||
out.push('"');
|
||||
}
|
||||
Value::Tagged(_) => {}
|
||||
Value::Sequence(list) => {
|
||||
eprintln!("write [");
|
||||
out.push('[');
|
||||
if !list.is_empty() {
|
||||
write_yaml_to_json(&list[0], out);
|
||||
for i in 1..list.len() {
|
||||
eprintln!("write ,");
|
||||
out.push(',');
|
||||
write_yaml_to_json(&list[i], out);
|
||||
}
|
||||
}
|
||||
eprintln!("write ]");
|
||||
out.push(']');
|
||||
}
|
||||
Value::Mapping(map) => {
|
||||
eprintln!("write {{");
|
||||
out.push('{');
|
||||
let mut iter = map.iter();
|
||||
if let Some(kv) = iter.next() {
|
||||
write_yaml_to_json(&Value::String(yaml_to_json(kv.0)), out);
|
||||
let key = match kv.0 {
|
||||
Value::String(_) => kv.0,
|
||||
_ => &Value::String(yaml_to_json(kv.0)),
|
||||
};
|
||||
write_yaml_to_json(key, out);
|
||||
eprintln!("write :");
|
||||
out.push(':');
|
||||
write_yaml_to_json(kv.1, out);
|
||||
}
|
||||
for kv in iter {
|
||||
eprintln!("write ,");
|
||||
out.push(',');
|
||||
write_yaml_to_json(&Value::String(yaml_to_json(kv.0)), out);
|
||||
let key = match kv.0 {
|
||||
Value::String(_) => kv.0,
|
||||
_ => &Value::String(yaml_to_json(kv.0)),
|
||||
};
|
||||
write_yaml_to_json(key, out);
|
||||
eprintln!("write :");
|
||||
out.push(':');
|
||||
write_yaml_to_json(kv.1, out);
|
||||
}
|
||||
eprintln!("write }}");
|
||||
out.push('}');
|
||||
}
|
||||
}
|
||||
eprintln!("written {:?}", yaml);
|
||||
}
|
||||
|
||||
@@ -224,6 +224,7 @@ async fn handle_message(
|
||||
Ok(mut replies) => {
|
||||
let mut last_message_id = None;
|
||||
while let Some(reply) = replies.recv().await {
|
||||
log::trace!("reply recv'ed: Some({:?})", reply);
|
||||
match reply {
|
||||
Ok(reply) => {
|
||||
log::debug!(
|
||||
@@ -274,6 +275,7 @@ async fn handle_message(
|
||||
}
|
||||
}
|
||||
}
|
||||
log::debug!("{} :: /{} command execution finished", &plugin_name, &cmd_name);
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
|
||||
@@ -74,6 +74,7 @@ pub(crate) trait PluginConnection: Send + Sync + Debug {
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum PluginConnectionError {
|
||||
SendRequest(String, PluginRequestType, Box<dyn Error + Send + Sync>),
|
||||
Timeout(String, PluginRequestType),
|
||||
ReadResponse(String, PluginRequestType, Box<dyn Error + Send + Sync>),
|
||||
DecodeResponse(String, PluginRequestType, Option<DecodeError>),
|
||||
InvalidMessageLength,
|
||||
@@ -85,6 +86,9 @@ impl Display for PluginConnectionError {
|
||||
Self::SendRequest(plugin_name, req_type, e) => f.write_fmt(format_args!(
|
||||
"Can't send request ({req_type}) to plugin {plugin_name}: {e}"
|
||||
)),
|
||||
Self::Timeout(plugin_name, req_type) => f.write_fmt(format_args!(
|
||||
"Timed out waiting for response to {req_type} request to plugin {plugin_name}"
|
||||
)),
|
||||
Self::ReadResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!(
|
||||
"Can't read response ({resp_type}) from plugin {plugin_name}: {e}"
|
||||
)),
|
||||
@@ -101,6 +105,7 @@ impl Error for PluginConnectionError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
match self {
|
||||
Self::SendRequest(_, _, e) => Some(e.as_ref()),
|
||||
Self::Timeout(_, _) => None,
|
||||
Self::ReadResponse(_, _, e) => Some(e.as_ref()),
|
||||
Self::DecodeResponse(_, _, e) => e.as_ref().map(|e| e as &dyn Error),
|
||||
Self::InvalidMessageLength => None,
|
||||
|
||||
@@ -55,13 +55,19 @@ pub(super) async fn initialize_stdio_plugin(
|
||||
Ok(plugin)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ResponseListener {
|
||||
Oneshot(oneshot::Sender<proto::Response>),
|
||||
Multishot(mpsc::Sender<proto::Response>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct StdioPluginConnection {
|
||||
plugin: Arc<Mutex<LoadedPlugin>>,
|
||||
process: Mutex<Child>,
|
||||
buffered_stdout: Mutex<BufReader<ChildStdout>>,
|
||||
next_request_id: Mutex<u32>,
|
||||
pending_requests: Mutex<HashMap<u32, oneshot::Sender<proto::Response>>>,
|
||||
pending_requests: Mutex<HashMap<u32, ResponseListener>>,
|
||||
}
|
||||
|
||||
impl StdioPluginConnection {
|
||||
@@ -103,18 +109,29 @@ impl StdioPluginConnection {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match self
|
||||
.pending_requests
|
||||
.lock()
|
||||
.await
|
||||
.remove(&response.request_id)
|
||||
{
|
||||
Some(sender) => match sender.send(response) {
|
||||
let mut pending_req_lock = self.pending_requests.lock().await;
|
||||
match pending_req_lock.get(&response.request_id) {
|
||||
Some(ResponseListener::Oneshot(_)) => {
|
||||
// We must remove to be able to send the response to the oneshot channel
|
||||
if let Some(ResponseListener::Oneshot(sender)) = pending_req_lock.remove(&response.request_id) {
|
||||
match sender.send(response) {
|
||||
Ok(()) => {}
|
||||
Err(response) => {
|
||||
log::warn!(
|
||||
"Dropping response with request_id {} from plugin {}",
|
||||
response.request_id,
|
||||
self.plugin.lock().await.plugin_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(ResponseListener::Multishot(sender)) => match sender.send(response).await {
|
||||
Ok(()) => {}
|
||||
Err(response) => {
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"Dropping response with request_id {} from plugin {}",
|
||||
response.request_id,
|
||||
e.0.request_id,
|
||||
self.plugin.lock().await.plugin_id
|
||||
);
|
||||
}
|
||||
@@ -155,21 +172,6 @@ impl StdioPluginConnection {
|
||||
reader.read_exact(&mut message).await?;
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
async fn await_response_to(
|
||||
&self,
|
||||
request_id: u32,
|
||||
timeout: Duration,
|
||||
) -> Result<proto::Response, Box<dyn Error + Send + Sync>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.pending_requests.lock().await.insert(request_id, tx);
|
||||
|
||||
match tokio::time::timeout(timeout, rx).await {
|
||||
Ok(Ok(response)) => Ok(response),
|
||||
Ok(Err(e)) => Err(Box::new(e)),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -178,20 +180,31 @@ impl PluginConnection for StdioPluginConnection {
|
||||
&self,
|
||||
config: String,
|
||||
) -> Result<proto::PluginInitializeResponse, PluginConnectionError> {
|
||||
// Generate request id
|
||||
let request_id = {
|
||||
let mut r = self.next_request_id.lock().await;
|
||||
let id = *r;
|
||||
*r += 1;
|
||||
id
|
||||
};
|
||||
|
||||
// Encode request
|
||||
let request = proto::Request {
|
||||
request_id,
|
||||
req: Some(proto::request::Req::InitializeReq(
|
||||
proto::PluginInitializeRequest { config },
|
||||
)),
|
||||
req: Some(proto::request::Req::InitializeReq(proto::PluginInitializeRequest {
|
||||
config,
|
||||
})),
|
||||
}
|
||||
.encode_length_delimited_to_vec();
|
||||
|
||||
// Add response listener
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.pending_requests
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id, ResponseListener::Oneshot(tx));
|
||||
|
||||
// Send request
|
||||
if let Err(e) = self
|
||||
.process
|
||||
.lock()
|
||||
@@ -209,16 +222,20 @@ impl PluginConnection for StdioPluginConnection {
|
||||
));
|
||||
}
|
||||
|
||||
let response = match self
|
||||
.await_response_to(request_id, Duration::from_secs(10))
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
// Receive response
|
||||
let response = match tokio::time::timeout(Duration::from_secs(10), rx).await {
|
||||
Ok(Ok(response)) => response,
|
||||
Ok(Err(e)) => {
|
||||
return Err(PluginConnectionError::ReadResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::Initialize,
|
||||
e,
|
||||
Box::new(e),
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(PluginConnectionError::Timeout(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::Initialize,
|
||||
));
|
||||
}
|
||||
};
|
||||
@@ -233,23 +250,30 @@ impl PluginConnection for StdioPluginConnection {
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_plugin_command_list(
|
||||
&self,
|
||||
) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
|
||||
async fn request_plugin_command_list(&self) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
|
||||
// Generate request id
|
||||
let request_id = {
|
||||
let mut r = self.next_request_id.lock().await;
|
||||
let id = *r;
|
||||
*r += 1;
|
||||
id
|
||||
};
|
||||
|
||||
// Encode request
|
||||
let request = proto::Request {
|
||||
request_id,
|
||||
req: Some(proto::request::Req::CommandListReq(
|
||||
proto::PluginCommandListRequest {},
|
||||
)),
|
||||
req: Some(proto::request::Req::CommandListReq(proto::PluginCommandListRequest {})),
|
||||
}
|
||||
.encode_length_delimited_to_vec();
|
||||
|
||||
// Add response listener
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.pending_requests
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id, ResponseListener::Oneshot(tx));
|
||||
|
||||
// Send request
|
||||
if let Err(e) = self
|
||||
.process
|
||||
.lock()
|
||||
@@ -267,16 +291,20 @@ impl PluginConnection for StdioPluginConnection {
|
||||
));
|
||||
}
|
||||
|
||||
let response = match self
|
||||
.await_response_to(request_id, Duration::from_secs(10))
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
// Receive response
|
||||
let response = match tokio::time::timeout(Duration::from_secs(10), rx).await {
|
||||
Ok(Ok(response)) => response,
|
||||
Ok(Err(e)) => {
|
||||
return Err(PluginConnectionError::ReadResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::CommandList,
|
||||
e,
|
||||
Box::new(e),
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(PluginConnectionError::Timeout(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::CommandList,
|
||||
));
|
||||
}
|
||||
};
|
||||
@@ -296,28 +324,34 @@ impl PluginConnection for StdioPluginConnection {
|
||||
command_id: String,
|
||||
issuer_id: String,
|
||||
argv: Vec<String>,
|
||||
) -> Result<
|
||||
mpsc::Receiver<Result<proto::CommandReply, PluginConnectionError>>,
|
||||
PluginConnectionError,
|
||||
> {
|
||||
) -> Result<mpsc::Receiver<Result<proto::CommandReply, PluginConnectionError>>, PluginConnectionError> {
|
||||
// Generate request id
|
||||
let request_id = {
|
||||
let mut r = self.next_request_id.lock().await;
|
||||
let id = *r;
|
||||
*r += 1;
|
||||
id
|
||||
};
|
||||
|
||||
// Encode request
|
||||
let request = proto::Request {
|
||||
request_id,
|
||||
req: Some(proto::request::Req::ExecuteReq(
|
||||
proto::PluginExecuteRequest {
|
||||
command_id,
|
||||
issuer_id,
|
||||
arg_vector: argv,
|
||||
},
|
||||
)),
|
||||
req: Some(proto::request::Req::ExecuteReq(proto::PluginExecuteRequest {
|
||||
command_id,
|
||||
issuer_id,
|
||||
arg_vector: argv,
|
||||
})),
|
||||
}
|
||||
.encode_length_delimited_to_vec();
|
||||
|
||||
// Add response listener
|
||||
let (tx, mut rx) = mpsc::channel(4);
|
||||
self.pending_requests
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id, ResponseListener::Multishot(tx));
|
||||
|
||||
// Send request
|
||||
if let Err(e) = self
|
||||
.process
|
||||
.lock()
|
||||
@@ -335,26 +369,14 @@ impl PluginConnection for StdioPluginConnection {
|
||||
));
|
||||
}
|
||||
|
||||
let (tx, rx) = mpsc::channel(4);
|
||||
|
||||
// Receive responses
|
||||
let (tx_out, rx_out) = mpsc::channel(4);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let response = match self
|
||||
.await_response_to(request_id, Duration::from_secs(600))
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
if let Err(e) = tx
|
||||
.send(Err(PluginConnectionError::ReadResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::Execute,
|
||||
e,
|
||||
)))
|
||||
.await
|
||||
{
|
||||
log::error!("Cannot send error notification to another task: {e}");
|
||||
}
|
||||
let response = match rx.recv().await {
|
||||
Some(response) => response,
|
||||
None => {
|
||||
log::error!("Response channel for request {request_id} closed unexpectedly");
|
||||
break;
|
||||
}
|
||||
};
|
||||
@@ -365,7 +387,8 @@ impl PluginConnection for StdioPluginConnection {
|
||||
Some(proto::command_reply::Reply::End(_)) => true,
|
||||
_ => false,
|
||||
};
|
||||
if let Err(e) = tx.send(Ok(reply)).await {
|
||||
log::trace!("reply sent: Ok({:?})", &reply);
|
||||
if let Err(e) = tx_out.send(Ok(reply)).await {
|
||||
log::error!("Cannot send command reply to another task: {e}");
|
||||
}
|
||||
if end {
|
||||
@@ -373,7 +396,7 @@ impl PluginConnection for StdioPluginConnection {
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if let Err(e) = tx
|
||||
if let Err(e) = tx_out
|
||||
.send(Err(PluginConnectionError::DecodeResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::Execute,
|
||||
@@ -387,8 +410,10 @@ impl PluginConnection for StdioPluginConnection {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Remove response listener
|
||||
self.pending_requests.lock().await.remove_entry(&request_id);
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
Ok(rx_out)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user