diff --git a/Cargo.lock b/Cargo.lock index b5397b1..0f402be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -653,11 +653,14 @@ name = "bot" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "clap", "deltachat", "env_logger", "eui48", "log", + "prost", + "prost-build", "russh", "serde", "serde_yaml", @@ -1866,6 +1869,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + [[package]] name = "elliptic-curve" version = "0.13.8" @@ -3390,6 +3399,15 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.17" @@ -3748,6 +3766,12 @@ dependencies = [ "pxfm", ] +[[package]] +name = "multimap" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" + [[package]] name = "mutate_once" version = "0.1.2" @@ -4438,6 +4462,17 @@ dependencies = [ "sha2 0.10.9", ] +[[package]] +name = "petgraph" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" +dependencies = [ + "fixedbitset", + "hashbrown 0.15.5", + "indexmap", +] + [[package]] name = "pgp" version = "0.19.0" @@ -4928,6 +4963,57 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" +dependencies = [ + "heck", + "itertools", + "log", + "multimap", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn", + "tempfile", +] + +[[package]] +name = "prost-derive" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-types" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" +dependencies = [ + "prost", +] + [[package]] name = "pxfm" version = "0.1.28" diff --git a/Cargo.toml b/Cargo.toml index d7e2258..821c5a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,12 +5,17 @@ edition = "2024" [dependencies] anyhow = "1.0.102" +async-trait = "0.1.89" clap = { version = "4", features = [ "derive" ] } deltachat = { path = "./chatmail-core" } env_logger = "0.11.9" eui48 = { version = "1.1.0", features = [ "serde" ] } log = { version = "0.4.29" } +prost = "0.14.3" russh = { version = "0.60.0" } serde = { version = "1", features = [ "derive" ] } serde_yaml = { version = "0.9" } tokio = { version = "1.50.0", features = ["full"] } + +[build-dependencies] +prost-build = "0.14.3" diff --git a/bot.example.yml b/bot.example.yml index 16b84f6..f05a767 100644 --- a/bot.example.yml +++ b/bot.example.yml @@ -16,3 +16,7 @@ deltaChat: email: "0000000000000000@email.com" password: "kek-a-kek" avatar: "avatar.png" + +plugins: + - name: additional-commands-plus + enabled: true diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..6520a2d --- /dev/null +++ b/build.rs @@ -0,0 +1,5 @@ +use std::io::Result; + +fn main() -> Result<()> { + prost_build::compile_protos(&["protobuf/plugin.proto"], &["protobuf/"]) +} diff --git a/protobuf/plugin.proto b/protobuf/plugin.proto new file mode 100644 index 0000000..34033c8 --- /dev/null +++ b/protobuf/plugin.proto @@ -0,0 +1,53 @@ +syntax = "proto3"; +package deltachat_remotecontrol_bot.plugin; + +message Request { + uint32 request_id = 1; + oneof req { + PluginInitializeRequest initialize_req = 10; + PluginCommandListRequest command_list_req = 11; + PluginExecuteRequest execute_req = 12; + } +} + +message PluginInitializeRequest { + string config = 1; +} + +message PluginCommandListRequest {} + +message PluginExecuteRequest { + string command_id = 1; + repeated string arg_vector = 5; +} + +message Response { + uint32 request_id = 1; + oneof res { + PluginInitializeResponse initialize_res = 10; + PluginCommandListResponse command_list_res = 11; + PluginExecuteResponse execute_res = 12; + } +} + +message PluginInitializeResponse { + string unique_name = 1; + string name = 2; + string version = 3; + string authors = 4; +} + +message PluginCommandListResponse { + repeated CommandSpec commands = 1; +} + +message CommandSpec { + string name = 1; + repeated string aliases = 2; + string usage = 3; + string description = 4; +} + +message PluginExecuteResponse { + +} diff --git a/src/config.rs b/src/config.rs index 9606ca0..ec5427b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,6 +13,8 @@ pub struct BotConfig { pub machines: HashMap, #[serde(rename = "deltaChat")] pub delta_chat: BotDeltaChatConfig, + #[serde(default)] + pub plugins: Vec, } #[derive(Deserialize, Debug)] @@ -47,6 +49,17 @@ pub struct BotDeltaChatConfig { pub avatar: Option, } +#[derive(Deserialize, Debug)] +pub struct PluginConfig { + pub name: String, + #[serde(default = "default_plugin_enabled")] + pub enabled: bool, +} + +fn default_plugin_enabled() -> bool { + true +} + #[derive(Debug)] pub enum ConfigError { Io(io::Error), diff --git a/src/main.rs b/src/main.rs index 4754d1d..9b94b8c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,21 @@ mod commands; mod config; +mod paths; +mod plugin; mod ssh; +mod proto { + pub(crate) mod deltachat_remotecontrol_bot { + pub(crate) mod plugin { + include!(concat!( + env!("OUT_DIR"), + "/deltachat_remotecontrol_bot.plugin.rs" + )); + } + } + pub(crate) use deltachat_remotecontrol_bot::plugin::*; +} + use anyhow::{Context as _, Result as AnyhowResult}; use clap::Parser; use config::BotConfig; @@ -14,20 +28,25 @@ use deltachat::{ securejoin, stock_str::StockStrings, }; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::{path::PathBuf, sync::Arc}; use tokio::sync::Mutex; -const CONFIG_FILENAME: &str = "bot.yml"; +use crate::{ + paths::{data_path, default_config_paths}, + plugin::{LoadedPlugin, PluginCommand, try_load_plugin}, +}; + const APP_NAME: &str = "deltachat-remotecontrol-bot"; -const APP_CONFIG_DIR: &str = APP_NAME; -const APP_DATA_DIR: &str = APP_NAME; const BOT_DISPLAY_NAME: &str = "🤖Remote🖲️"; const AUTH_REQUIRED: bool = false; pub struct BotContext { authed_contacts: HashSet, config: BotConfig, + plugins: HashMap>>, // maps plugin's unique name (id) to plugin object + plugin_commands: HashMap>, // maps command name to command object + plugin_cmd_aliases: HashMap>, // maps command alias to command object } impl BotContext { @@ -35,6 +54,9 @@ impl BotContext { BotContext { authed_contacts: HashSet::new(), config, + plugins: HashMap::new(), + plugin_commands: HashMap::new(), + plugin_cmd_aliases: HashMap::new(), } } } @@ -51,40 +73,7 @@ struct Args { config: Option, } -fn default_config_paths() -> Vec { - let mut paths = vec![PathBuf::from(CONFIG_FILENAME)]; - - if let Ok(config_home) = std::env::var("XDG_CONFIG_HOME") - .map(PathBuf::from) - .or(std::env::var("HOME").map(|home| PathBuf::from(home).join(".config"))) - { - paths.push(config_home.join(APP_CONFIG_DIR).join(CONFIG_FILENAME)); - } - - paths.push( - PathBuf::from("/etc") - .join(APP_CONFIG_DIR) - .join(CONFIG_FILENAME), - ); - - paths -} - -fn data_path() -> PathBuf { - std::env::var("XDG_DATA_HOME") - .map(|dh| PathBuf::from(dh).join(APP_DATA_DIR)) - .or_else(|_| { - std::env::var("HOME").map(|h| { - PathBuf::from(h) - .join(".local") - .join("share") - .join(APP_DATA_DIR) - }) - }) - .unwrap_or(PathBuf::from(".")) -} - -async fn run_bot(cfg: config::BotConfig) -> AnyhowResult<()> { +async fn run_bot(bot_context: Arc>) -> AnyhowResult<()> { let dchat_db_dir = data_path(); std::fs::create_dir_all(&dchat_db_dir) .with_context(|| format!("Failed to create data directory {}", dchat_db_dir.display()))?; @@ -100,12 +89,14 @@ async fn run_bot(cfg: config::BotConfig) -> AnyhowResult<()> { .await .context("Failed to open Delta Chat client DB")?; + let ctx_lock = bot_context.lock().await; + if !dchat_ctx.is_configured().await? { dchat_ctx - .set_config(Config::Addr, Some(&cfg.delta_chat.email)) + .set_config(Config::Addr, Some(&ctx_lock.config.delta_chat.email)) .await?; dchat_ctx - .set_config(Config::MailPw, Some(&cfg.delta_chat.password)) + .set_config(Config::MailPw, Some(&ctx_lock.config.delta_chat.password)) .await?; dchat_ctx.set_config(Config::Bot, Some("1")).await?; dchat_ctx @@ -116,7 +107,7 @@ async fn run_bot(cfg: config::BotConfig) -> AnyhowResult<()> { dchat_ctx .set_config(Config::Displayname, Some(BOT_DISPLAY_NAME)) .await?; - if let Some(mut avatar_path) = cfg.delta_chat.avatar.clone() { + if let Some(mut avatar_path) = ctx_lock.config.delta_chat.avatar.clone() { if avatar_path.is_relative() { avatar_path = data_path().join(&avatar_path); } @@ -127,6 +118,7 @@ async fn run_bot(cfg: config::BotConfig) -> AnyhowResult<()> { } } } + drop(ctx_lock); dchat_ctx.start_io().await; @@ -141,8 +133,6 @@ async fn run_bot(cfg: config::BotConfig) -> AnyhowResult<()> { let dchat_ctx = Arc::new(Mutex::new(dchat_ctx)); - let bot_context = Arc::new(Mutex::new(BotContext::new(cfg))); - let ev_emitter = dchat_ctx.lock().await.get_event_emitter(); while let Some(ev) = ev_emitter.recv().await { @@ -232,5 +222,23 @@ async fn main() { } }; - run_bot(config).await.expect("error"); + let requested_plugins: Vec = config + .plugins + .iter() + .filter(|p| p.enabled) + .map(|p| p.name.clone()) + .collect(); + + let bot_context = Arc::new(Mutex::new(BotContext::new(config))); + + if requested_plugins.len() > 0 { + log::info!("Loading plugins ({})", requested_plugins.join(", ")); + for plugin in &requested_plugins { + if let Err(e) = try_load_plugin(Arc::clone(&bot_context), plugin.clone()).await { + log::error!("Failed to load plugin \"{plugin}\": {e}"); + } + } + } + + run_bot(bot_context).await.expect("error"); } diff --git a/src/paths.rs b/src/paths.rs new file mode 100644 index 0000000..556b186 --- /dev/null +++ b/src/paths.rs @@ -0,0 +1,44 @@ +use std::path::PathBuf; + +use crate::APP_NAME; + +const APP_CONFIG_DIR: &str = APP_NAME; +const APP_DATA_DIR: &str = APP_NAME; +const CONFIG_FILENAME: &str = "bot.yml"; + +pub(crate) fn default_config_paths() -> Vec { + let mut paths = vec![PathBuf::from(CONFIG_FILENAME)]; + + if let Ok(config_home) = std::env::var("XDG_CONFIG_HOME") + .map(PathBuf::from) + .or(std::env::var("HOME").map(|home| PathBuf::from(home).join(".config"))) + { + paths.push(config_home.join(APP_CONFIG_DIR).join(CONFIG_FILENAME)); + } + + paths.push( + PathBuf::from("/etc") + .join(APP_CONFIG_DIR) + .join(CONFIG_FILENAME), + ); + + paths +} + +pub(crate) fn data_path() -> PathBuf { + std::env::var("XDG_DATA_HOME") + .map(|dh| PathBuf::from(dh).join(APP_DATA_DIR)) + .or_else(|_| { + std::env::var("HOME").map(|h| { + PathBuf::from(h) + .join(".local") + .join("share") + .join(APP_DATA_DIR) + }) + }) + .unwrap_or(PathBuf::from(".")) +} + +pub(crate) fn plugins_path() -> PathBuf { + data_path().join("plugins") +} diff --git a/src/plugin/mod.rs b/src/plugin/mod.rs new file mode 100644 index 0000000..317bbde --- /dev/null +++ b/src/plugin/mod.rs @@ -0,0 +1,163 @@ +mod stdio; + +use anyhow::{Context as _, Result as AnyhowResult, bail}; +use async_trait::async_trait; +use prost::{DecodeError, Message}; +use std::{ + collections::HashMap, + error::Error, + fmt::{Debug, Display}, + ops::DerefMut, + process::Stdio, + sync::Arc, +}; + +use tokio::{ + io::{self, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader}, + process::{Child, ChildStdout}, + sync::{Mutex, oneshot}, +}; + +use crate::{BotContext, paths::plugins_path, proto}; + +pub(crate) struct PluginCommand { + pub plugin_id: String, + pub name: String, + pub aliases: Vec, + pub usage: String, + pub description: String, +} + +#[derive(Default)] +pub(crate) struct LoadedPlugin { + pub plugin_id: String, + pub name: String, + pub version: String, + pub authors: String, + pub commands: Vec>, + pub connection: Option>, +} + +#[derive(Debug)] +pub(crate) enum PluginRequestType { + Initialize, + CommandList, +} + +impl From for i32 { + fn from(value: PluginRequestType) -> Self { + match value { + PluginRequestType::Initialize => 1, + PluginRequestType::CommandList => 2, + } + } +} + +impl Display for PluginRequestType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ::fmt(self, f) + } +} + +#[async_trait] +pub(crate) trait PluginConnection: Send + Sync { + async fn initialize_plugin( + &self, + config: String, + ) -> Result; + + async fn request_plugin_command_list( + &self, + ) -> Result; +} + +#[derive(Debug)] +pub(crate) enum PluginConnectionError { + SendRequest(String, PluginRequestType, Box), + ReadResponse(String, PluginRequestType, Box), + DecodeResponse(String, PluginRequestType, Option), + InvalidMessageLength, +} + +impl Display for PluginConnectionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::SendRequest(plugin_name, req_type, e) => f.write_fmt(format_args!( + "Can't send request ({req_type}) to plugin {plugin_name}: {e}" + )), + Self::ReadResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!( + "Can't read response ({resp_type}) from plugin {plugin_name}: {e}" + )), + Self::DecodeResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!( + "Can't decode response ({resp_type}) from plugin {plugin_name}{}", + e.as_ref().map(|e| format!(": {e}")).unwrap_or_default() + )), + Self::InvalidMessageLength => f.write_str("Plugin response length is invalid"), + } + } +} + +impl Error for PluginConnectionError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::SendRequest(_, _, e) => Some(e.as_ref()), + Self::ReadResponse(_, _, e) => Some(e.as_ref()), + Self::DecodeResponse(_, _, e) => e.as_ref().map(|e| e as &dyn Error), + Self::InvalidMessageLength => None, + } + } +} + +pub(crate) async fn try_load_plugin( + ctx: Arc>, + unique_name: String, +) -> AnyhowResult<()> { + let plugin_dir = plugins_path().join(&unique_name); + if ctx.lock().await.plugins.contains_key(&unique_name) { + bail!("plugin unique name is not unique"); + } + if !std::fs::metadata(&plugin_dir)?.is_dir() { + bail!("Plugin directory doesn't exist"); + } + let plugin_executable_path = plugin_dir.join("plugin_run"); + log::debug!("Starting plugin executable {:?}", &plugin_executable_path); + let mut cmd = tokio::process::Command::new(plugin_executable_path); + cmd.stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()); + cmd.current_dir(&plugin_dir); + + // TODO добавить какие-нибудь перемнные среды + + let plugin_process = cmd.spawn().context("Failed to start the plugin")?; + let plugin = stdio::initialize_stdio_plugin(plugin_process, unique_name.clone()).await?; + + let mut ctx_lock = ctx.lock().await; + for cmd in plugin.lock().await.commands.iter().cloned() { + log::debug!("adding command /{} of plugin {}", &cmd.name, &unique_name); + if ctx_lock.plugin_commands.contains_key(&cmd.name) { + bail!("duplicate command specification"); + } + ctx_lock + .plugin_commands + .insert(cmd.name.clone(), Arc::clone(&cmd)); + ctx_lock + .plugin_cmd_aliases + .insert(cmd.name.clone(), Arc::clone(&cmd)); + for alias in cmd.aliases.iter() { + log::debug!( + "adding command alias /{} -> /{} of plugin {}", + alias, + &cmd.name, + &unique_name + ); + ctx_lock + .plugin_cmd_aliases + .insert(alias.to_owned(), Arc::clone(&cmd)); + } + } + drop(ctx_lock); + + ctx.lock().await.plugins.insert(unique_name, plugin); + Ok(()) +} diff --git a/src/plugin/stdio.rs b/src/plugin/stdio.rs new file mode 100644 index 0000000..f7d94bb --- /dev/null +++ b/src/plugin/stdio.rs @@ -0,0 +1,300 @@ +use std::{collections::HashMap, error::Error, ops::DerefMut, sync::Arc, time::Duration}; + +use anyhow::{Context as _, Result as AnyhowResult, bail}; +use async_trait::async_trait; +use prost::Message; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader}, + process::{Child, ChildStdout}, + sync::{Mutex, oneshot}, + time::error::Elapsed, +}; + +use crate::{ + plugin::{ + LoadedPlugin, PluginCommand, PluginConnection, PluginConnectionError, PluginRequestType, + }, + proto, +}; + +pub(super) async fn initialize_stdio_plugin( + process: Child, + unique_name: String, +) -> AnyhowResult>> { + let plugin = Arc::new(Mutex::new(LoadedPlugin::default())); + log::info!("Connecting to plugin {} using standard I/O", &unique_name); + plugin.lock().await.plugin_id = unique_name; + let connection = Arc::new(StdioPluginConnection::new(Arc::clone(&plugin), process)); + Arc::clone(&connection).run_stdio_loops(); + + let plugin_info = connection.initialize_plugin(String::new()).await?; + log::debug!("received plugin identification: {:?}", plugin_info); + let mut plugin_lock = plugin.lock().await; + plugin_lock.name = plugin_info.name; + plugin_lock.plugin_id = plugin_info.unique_name.clone(); + plugin_lock.version = plugin_info.version; + plugin_lock.authors = plugin_info.authors; + drop(plugin_lock); + + let command_list = connection.request_plugin_command_list().await?; + let mut plugin_lock = plugin.lock().await; + plugin_lock.commands = command_list + .commands + .into_iter() + .map(|cmd| { + Arc::new(PluginCommand { + name: cmd.name, + plugin_id: plugin_info.unique_name.clone(), + aliases: cmd.aliases, + usage: cmd.usage, + description: cmd.description, + }) + }) + .collect(); + + plugin_lock.connection = Some(connection); + drop(plugin_lock); + + Ok(plugin) +} + +struct StdioPluginConnection { + plugin: Arc>, + process: Mutex, + buffered_stdout: Mutex>, + next_request_id: Mutex, + pending_requests: Mutex>>, +} + +impl StdioPluginConnection { + pub fn new(plugin: Arc>, mut process: Child) -> StdioPluginConnection { + let stdout = process.stdout.take().unwrap(); + let conn = StdioPluginConnection { + plugin, + process: Mutex::new(process), + buffered_stdout: Mutex::new(BufReader::new(stdout)), + next_request_id: Mutex::new(0), + pending_requests: Mutex::new(HashMap::new()), + }; + conn + } + + fn run_stdio_loops(self: Arc) { + tokio::spawn(self.stdout_reader_loop()); + } + + async fn stdout_reader_loop(self: Arc) { + loop { + let frame = + match Self::read_length_delimited(self.buffered_stdout.lock().await.deref_mut()) + .await + { + Ok(frame) => frame, + Err(e) => { + log::error!( + "Error while reading STDOUT of stdio plugin {}: {e}", + &self.plugin.lock().await.plugin_id + ); + break; + } + }; + let response = match proto::Response::decode(frame.as_slice()) { + Ok(response) => response, + Err(_) => { + log::error!( + "Invalid response received from stdio plugin {}", + &self.plugin.lock().await.plugin_id + ); + continue; + } + }; + match self + .pending_requests + .lock() + .await + .remove(&response.request_id) + { + Some(sender) => match sender.send(response) { + Ok(()) => {} + Err(response) => { + log::warn!( + "Dropping response with request_id {} from plugin {}", + response.request_id, + self.plugin.lock().await.plugin_id + ); + } + }, + None => { + continue; + } + } + } + + self.pending_requests.lock().await.clear(); + } + + async fn read_length_delimited( + reader: &mut R, + ) -> Result, Box> + where + R: AsyncRead + Unpin, + { + const MAX_FRAME_LENGTH: usize = 0x40000000; + let mut length: usize = 0; + let mut bit = 0; + let mut byte = [0u8; 1]; + let mut complete = false; + for _ in 0..10 { + reader.read_exact(&mut byte).await?; + length += ((byte[0] & 0x7f) as usize).unbounded_shl(bit); + bit += 7; + if (byte[0] & 0x80) == 0 { + complete = true; + break; + } + } + + if !complete || length > MAX_FRAME_LENGTH { + return Err(Box::new(PluginConnectionError::InvalidMessageLength)); + } + + let mut message = vec![0u8; length]; + reader.read_exact(&mut message).await?; + Ok(message) + } + + async fn await_response_to( + &self, + request_id: u32, + timeout: Duration, + ) -> Result> { + let (tx, rx) = oneshot::channel(); + self.pending_requests.lock().await.insert(request_id, tx); + + match tokio::time::timeout(timeout, rx).await { + Ok(Ok(response)) => Ok(response), + Ok(Err(e)) => Err(Box::new(e)), + Err(e) => Err(Box::new(e)), + } + } +} + +#[async_trait] +impl PluginConnection for StdioPluginConnection { + async fn initialize_plugin( + &self, + config: String, + ) -> Result { + let request_id = { + let mut r = self.next_request_id.lock().await; + let id = *r; + *r += 1; + id + }; + let request = proto::Request { + request_id, + req: Some(proto::request::Req::InitializeReq( + proto::PluginInitializeRequest { config }, + )), + } + .encode_length_delimited_to_vec(); + + if let Err(e) = self + .process + .lock() + .await + .stdin + .as_mut() + .unwrap() + .write_all(&request) + .await + { + return Err(PluginConnectionError::SendRequest( + self.plugin.lock().await.plugin_id.clone(), + PluginRequestType::Initialize, + Box::new(e), + )); + } + + let response = match self + .await_response_to(request_id, Duration::from_secs(10)) + .await + { + Ok(response) => response, + Err(e) => { + return Err(PluginConnectionError::ReadResponse( + self.plugin.lock().await.plugin_id.clone(), + PluginRequestType::Initialize, + e, + )); + } + }; + + match response.res { + Some(proto::response::Res::InitializeRes(resp)) => Ok(resp), + _ => Err(PluginConnectionError::DecodeResponse( + self.plugin.lock().await.plugin_id.clone(), + PluginRequestType::Initialize, + None, + )), + } + } + + async fn request_plugin_command_list( + &self, + ) -> Result { + let request_id = { + let mut r = self.next_request_id.lock().await; + let id = *r; + *r += 1; + id + }; + let request = proto::Request { + request_id, + req: Some(proto::request::Req::CommandListReq( + proto::PluginCommandListRequest {}, + )), + } + .encode_length_delimited_to_vec(); + + if let Err(e) = self + .process + .lock() + .await + .stdin + .as_mut() + .unwrap() + .write_all(&request) + .await + { + return Err(PluginConnectionError::SendRequest( + self.plugin.lock().await.plugin_id.clone(), + PluginRequestType::CommandList, + Box::new(e), + )); + } + + let response = match self + .await_response_to(request_id, Duration::from_secs(10)) + .await + { + Ok(response) => response, + Err(e) => { + return Err(PluginConnectionError::ReadResponse( + self.plugin.lock().await.plugin_id.clone(), + PluginRequestType::CommandList, + e, + )); + } + }; + + match response.res { + Some(proto::response::Res::CommandListRes(resp)) => Ok(resp), + _ => Err(PluginConnectionError::DecodeResponse( + self.plugin.lock().await.plugin_id.clone(), + PluginRequestType::CommandList, + None, + )), + } + } +}