diff --git a/src/config/util.rs b/src/config/util.rs index 0dc60d3..18b9090 100644 --- a/src/config/util.rs +++ b/src/config/util.rs @@ -7,78 +7,107 @@ pub(crate) fn yaml_to_json(yaml: &Value) -> String { } fn write_yaml_to_json(yaml: &Value, out: &mut String) { + eprintln!("write_yaml_to_json {:?}", yaml); match yaml { Value::Null => { + eprintln!("write null"); out.push_str("null"); } Value::Bool(b) => { + eprintln!("write true/false"); out.push_str(if *b { "true" } else { "false" }); } Value::Number(n) => { + eprintln!("write {n}"); out.push_str(&n.to_string()); } Value::String(s) => { + eprintln!("write \""); out.push('"'); for c in s.encode_utf16() { match c { 0x005C | 0x0022 => { + eprintln!("write \\{}", char::from_u32(c as u32).unwrap()); out.push('\\'); out.push(char::from_u32(c as u32).unwrap()); } 0x0008 => { + eprintln!("write \\b"); out.push('\\'); out.push('b'); } 0x0009 => { + eprintln!("write \\t"); out.push('\\'); out.push('t'); } 0x000a => { + eprintln!("write \\n"); out.push('\\'); out.push('n'); } 0x000C => { + eprintln!("write \\f"); out.push('\\'); out.push('f'); } 0x0000..=0x001F | 0x0080..=0xFFFF => { + eprintln!("write \\u{c:04x}"); out.push_str(&format!("\\u{c:04x}")); } _ => { + eprintln!("write {}", char::from_u32(c as u32).unwrap()); out.push(char::from_u32(c as u32).unwrap()); } } } - out.push_str(s); // TODO escape + eprintln!("write \""); out.push('"'); } Value::Tagged(_) => {} Value::Sequence(list) => { + eprintln!("write ["); out.push('['); if !list.is_empty() { write_yaml_to_json(&list[0], out); for i in 1..list.len() { + eprintln!("write ,"); out.push(','); write_yaml_to_json(&list[i], out); } } + eprintln!("write ]"); out.push(']'); } Value::Mapping(map) => { + eprintln!("write {{"); out.push('{'); let mut iter = map.iter(); if let Some(kv) = iter.next() { - write_yaml_to_json(&Value::String(yaml_to_json(kv.0)), out); + let key = match kv.0 { + Value::String(_) => kv.0, + _ => &Value::String(yaml_to_json(kv.0)), + }; + write_yaml_to_json(key, out); + eprintln!("write :"); out.push(':'); write_yaml_to_json(kv.1, out); } for kv in iter { + eprintln!("write ,"); out.push(','); - write_yaml_to_json(&Value::String(yaml_to_json(kv.0)), out); + let key = match kv.0 { + Value::String(_) => kv.0, + _ => &Value::String(yaml_to_json(kv.0)), + }; + write_yaml_to_json(key, out); + eprintln!("write :"); out.push(':'); write_yaml_to_json(kv.1, out); } + eprintln!("write }}"); out.push('}'); } } + eprintln!("written {:?}", yaml); } diff --git a/src/main.rs b/src/main.rs index 62e5236..b2677fd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -224,6 +224,7 @@ async fn handle_message( Ok(mut replies) => { let mut last_message_id = None; while let Some(reply) = replies.recv().await { + log::trace!("reply recv'ed: Some({:?})", reply); match reply { Ok(reply) => { log::debug!( @@ -274,6 +275,7 @@ async fn handle_message( } } } + log::debug!("{} :: /{} command execution finished", &plugin_name, &cmd_name); } Err(e) => { log::error!( diff --git a/src/plugin/mod.rs b/src/plugin/mod.rs index 574bb93..412de92 100644 --- a/src/plugin/mod.rs +++ b/src/plugin/mod.rs @@ -74,6 +74,7 @@ pub(crate) trait PluginConnection: Send + Sync + Debug { #[derive(Debug)] pub(crate) enum PluginConnectionError { SendRequest(String, PluginRequestType, Box), + Timeout(String, PluginRequestType), ReadResponse(String, PluginRequestType, Box), DecodeResponse(String, PluginRequestType, Option), InvalidMessageLength, @@ -85,6 +86,9 @@ impl Display for PluginConnectionError { Self::SendRequest(plugin_name, req_type, e) => f.write_fmt(format_args!( "Can't send request ({req_type}) to plugin {plugin_name}: {e}" )), + Self::Timeout(plugin_name, req_type) => f.write_fmt(format_args!( + "Timed out waiting for response to {req_type} request to plugin {plugin_name}" + )), Self::ReadResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!( "Can't read response ({resp_type}) from plugin {plugin_name}: {e}" )), @@ -101,6 +105,7 @@ impl Error for PluginConnectionError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { Self::SendRequest(_, _, e) => Some(e.as_ref()), + Self::Timeout(_, _) => None, Self::ReadResponse(_, _, e) => Some(e.as_ref()), Self::DecodeResponse(_, _, e) => e.as_ref().map(|e| e as &dyn Error), Self::InvalidMessageLength => None, diff --git a/src/plugin/stdio.rs b/src/plugin/stdio.rs index c781307..617a7d6 100644 --- a/src/plugin/stdio.rs +++ b/src/plugin/stdio.rs @@ -55,13 +55,19 @@ pub(super) async fn initialize_stdio_plugin( Ok(plugin) } +#[derive(Debug)] +enum ResponseListener { + Oneshot(oneshot::Sender), + Multishot(mpsc::Sender), +} + #[derive(Debug)] struct StdioPluginConnection { plugin: Arc>, process: Mutex, buffered_stdout: Mutex>, next_request_id: Mutex, - pending_requests: Mutex>>, + pending_requests: Mutex>, } impl StdioPluginConnection { @@ -103,18 +109,29 @@ impl StdioPluginConnection { continue; } }; - match self - .pending_requests - .lock() - .await - .remove(&response.request_id) - { - Some(sender) => match sender.send(response) { + let mut pending_req_lock = self.pending_requests.lock().await; + match pending_req_lock.get(&response.request_id) { + Some(ResponseListener::Oneshot(_)) => { + // We must remove to be able to send the response to the oneshot channel + if let Some(ResponseListener::Oneshot(sender)) = pending_req_lock.remove(&response.request_id) { + match sender.send(response) { + Ok(()) => {} + Err(response) => { + log::warn!( + "Dropping response with request_id {} from plugin {}", + response.request_id, + self.plugin.lock().await.plugin_id + ); + } + } + } + } + Some(ResponseListener::Multishot(sender)) => match sender.send(response).await { Ok(()) => {} - Err(response) => { + Err(e) => { log::warn!( "Dropping response with request_id {} from plugin {}", - response.request_id, + e.0.request_id, self.plugin.lock().await.plugin_id ); } @@ -155,21 +172,6 @@ impl StdioPluginConnection { reader.read_exact(&mut message).await?; Ok(message) } - - async fn await_response_to( - &self, - request_id: u32, - timeout: Duration, - ) -> Result> { - let (tx, rx) = oneshot::channel(); - self.pending_requests.lock().await.insert(request_id, tx); - - match tokio::time::timeout(timeout, rx).await { - Ok(Ok(response)) => Ok(response), - Ok(Err(e)) => Err(Box::new(e)), - Err(e) => Err(Box::new(e)), - } - } } #[async_trait] @@ -178,20 +180,31 @@ impl PluginConnection for StdioPluginConnection { &self, config: String, ) -> Result { + // Generate request id let request_id = { let mut r = self.next_request_id.lock().await; let id = *r; *r += 1; id }; + + // Encode request let request = proto::Request { request_id, - req: Some(proto::request::Req::InitializeReq( - proto::PluginInitializeRequest { config }, - )), + req: Some(proto::request::Req::InitializeReq(proto::PluginInitializeRequest { + config, + })), } .encode_length_delimited_to_vec(); + // Add response listener + let (tx, rx) = oneshot::channel(); + self.pending_requests + .lock() + .await + .insert(request_id, ResponseListener::Oneshot(tx)); + + // Send request if let Err(e) = self .process .lock() @@ -209,16 +222,20 @@ impl PluginConnection for StdioPluginConnection { )); } - let response = match self - .await_response_to(request_id, Duration::from_secs(10)) - .await - { - Ok(response) => response, - Err(e) => { + // Receive response + let response = match tokio::time::timeout(Duration::from_secs(10), rx).await { + Ok(Ok(response)) => response, + Ok(Err(e)) => { return Err(PluginConnectionError::ReadResponse( self.plugin.lock().await.plugin_id.clone(), PluginRequestType::Initialize, - e, + Box::new(e), + )); + } + Err(_) => { + return Err(PluginConnectionError::Timeout( + self.plugin.lock().await.plugin_id.clone(), + PluginRequestType::Initialize, )); } }; @@ -233,23 +250,30 @@ impl PluginConnection for StdioPluginConnection { } } - async fn request_plugin_command_list( - &self, - ) -> Result { + async fn request_plugin_command_list(&self) -> Result { + // Generate request id let request_id = { let mut r = self.next_request_id.lock().await; let id = *r; *r += 1; id }; + + // Encode request let request = proto::Request { request_id, - req: Some(proto::request::Req::CommandListReq( - proto::PluginCommandListRequest {}, - )), + req: Some(proto::request::Req::CommandListReq(proto::PluginCommandListRequest {})), } .encode_length_delimited_to_vec(); + // Add response listener + let (tx, rx) = oneshot::channel(); + self.pending_requests + .lock() + .await + .insert(request_id, ResponseListener::Oneshot(tx)); + + // Send request if let Err(e) = self .process .lock() @@ -267,16 +291,20 @@ impl PluginConnection for StdioPluginConnection { )); } - let response = match self - .await_response_to(request_id, Duration::from_secs(10)) - .await - { - Ok(response) => response, - Err(e) => { + // Receive response + let response = match tokio::time::timeout(Duration::from_secs(10), rx).await { + Ok(Ok(response)) => response, + Ok(Err(e)) => { return Err(PluginConnectionError::ReadResponse( self.plugin.lock().await.plugin_id.clone(), PluginRequestType::CommandList, - e, + Box::new(e), + )); + } + Err(_) => { + return Err(PluginConnectionError::Timeout( + self.plugin.lock().await.plugin_id.clone(), + PluginRequestType::CommandList, )); } }; @@ -296,28 +324,34 @@ impl PluginConnection for StdioPluginConnection { command_id: String, issuer_id: String, argv: Vec, - ) -> Result< - mpsc::Receiver>, - PluginConnectionError, - > { + ) -> Result>, PluginConnectionError> { + // Generate request id let request_id = { let mut r = self.next_request_id.lock().await; let id = *r; *r += 1; id }; + + // Encode request let request = proto::Request { request_id, - req: Some(proto::request::Req::ExecuteReq( - proto::PluginExecuteRequest { - command_id, - issuer_id, - arg_vector: argv, - }, - )), + req: Some(proto::request::Req::ExecuteReq(proto::PluginExecuteRequest { + command_id, + issuer_id, + arg_vector: argv, + })), } .encode_length_delimited_to_vec(); + // Add response listener + let (tx, mut rx) = mpsc::channel(4); + self.pending_requests + .lock() + .await + .insert(request_id, ResponseListener::Multishot(tx)); + + // Send request if let Err(e) = self .process .lock() @@ -335,26 +369,14 @@ impl PluginConnection for StdioPluginConnection { )); } - let (tx, rx) = mpsc::channel(4); - + // Receive responses + let (tx_out, rx_out) = mpsc::channel(4); tokio::spawn(async move { loop { - let response = match self - .await_response_to(request_id, Duration::from_secs(600)) - .await - { - Ok(response) => response, - Err(e) => { - if let Err(e) = tx - .send(Err(PluginConnectionError::ReadResponse( - self.plugin.lock().await.plugin_id.clone(), - PluginRequestType::Execute, - e, - ))) - .await - { - log::error!("Cannot send error notification to another task: {e}"); - } + let response = match rx.recv().await { + Some(response) => response, + None => { + log::error!("Response channel for request {request_id} closed unexpectedly"); break; } }; @@ -365,7 +387,8 @@ impl PluginConnection for StdioPluginConnection { Some(proto::command_reply::Reply::End(_)) => true, _ => false, }; - if let Err(e) = tx.send(Ok(reply)).await { + log::trace!("reply sent: Ok({:?})", &reply); + if let Err(e) = tx_out.send(Ok(reply)).await { log::error!("Cannot send command reply to another task: {e}"); } if end { @@ -373,7 +396,7 @@ impl PluginConnection for StdioPluginConnection { } } _ => { - if let Err(e) = tx + if let Err(e) = tx_out .send(Err(PluginConnectionError::DecodeResponse( self.plugin.lock().await.plugin_id.clone(), PluginRequestType::Execute, @@ -387,8 +410,10 @@ impl PluginConnection for StdioPluginConnection { } } } + // Remove response listener + self.pending_requests.lock().await.remove_entry(&request_id); }); - Ok(rx) + Ok(rx_out) } }