start implementing a plugin system
start implementing a plugin system using standard I/O as a transport
This commit is contained in:
163
src/plugin/mod.rs
Normal file
163
src/plugin/mod.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
mod stdio;
|
||||
|
||||
use anyhow::{Context as _, Result as AnyhowResult, bail};
|
||||
use async_trait::async_trait;
|
||||
use prost::{DecodeError, Message};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
error::Error,
|
||||
fmt::{Debug, Display},
|
||||
ops::DerefMut,
|
||||
process::Stdio,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use tokio::{
|
||||
io::{self, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader},
|
||||
process::{Child, ChildStdout},
|
||||
sync::{Mutex, oneshot},
|
||||
};
|
||||
|
||||
use crate::{BotContext, paths::plugins_path, proto};
|
||||
|
||||
pub(crate) struct PluginCommand {
|
||||
pub plugin_id: String,
|
||||
pub name: String,
|
||||
pub aliases: Vec<String>,
|
||||
pub usage: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct LoadedPlugin {
|
||||
pub plugin_id: String,
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
pub authors: String,
|
||||
pub commands: Vec<Arc<PluginCommand>>,
|
||||
pub connection: Option<Arc<dyn PluginConnection>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum PluginRequestType {
|
||||
Initialize,
|
||||
CommandList,
|
||||
}
|
||||
|
||||
impl From<PluginRequestType> for i32 {
|
||||
fn from(value: PluginRequestType) -> Self {
|
||||
match value {
|
||||
PluginRequestType::Initialize => 1,
|
||||
PluginRequestType::CommandList => 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for PluginRequestType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
<PluginRequestType as Debug>::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait PluginConnection: Send + Sync {
|
||||
async fn initialize_plugin(
|
||||
&self,
|
||||
config: String,
|
||||
) -> Result<proto::PluginInitializeResponse, PluginConnectionError>;
|
||||
|
||||
async fn request_plugin_command_list(
|
||||
&self,
|
||||
) -> Result<proto::PluginCommandListResponse, PluginConnectionError>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum PluginConnectionError {
|
||||
SendRequest(String, PluginRequestType, Box<dyn Error + Send + Sync>),
|
||||
ReadResponse(String, PluginRequestType, Box<dyn Error + Send + Sync>),
|
||||
DecodeResponse(String, PluginRequestType, Option<DecodeError>),
|
||||
InvalidMessageLength,
|
||||
}
|
||||
|
||||
impl Display for PluginConnectionError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::SendRequest(plugin_name, req_type, e) => f.write_fmt(format_args!(
|
||||
"Can't send request ({req_type}) to plugin {plugin_name}: {e}"
|
||||
)),
|
||||
Self::ReadResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!(
|
||||
"Can't read response ({resp_type}) from plugin {plugin_name}: {e}"
|
||||
)),
|
||||
Self::DecodeResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!(
|
||||
"Can't decode response ({resp_type}) from plugin {plugin_name}{}",
|
||||
e.as_ref().map(|e| format!(": {e}")).unwrap_or_default()
|
||||
)),
|
||||
Self::InvalidMessageLength => f.write_str("Plugin response length is invalid"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for PluginConnectionError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
match self {
|
||||
Self::SendRequest(_, _, e) => Some(e.as_ref()),
|
||||
Self::ReadResponse(_, _, e) => Some(e.as_ref()),
|
||||
Self::DecodeResponse(_, _, e) => e.as_ref().map(|e| e as &dyn Error),
|
||||
Self::InvalidMessageLength => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn try_load_plugin(
|
||||
ctx: Arc<Mutex<BotContext>>,
|
||||
unique_name: String,
|
||||
) -> AnyhowResult<()> {
|
||||
let plugin_dir = plugins_path().join(&unique_name);
|
||||
if ctx.lock().await.plugins.contains_key(&unique_name) {
|
||||
bail!("plugin unique name is not unique");
|
||||
}
|
||||
if !std::fs::metadata(&plugin_dir)?.is_dir() {
|
||||
bail!("Plugin directory doesn't exist");
|
||||
}
|
||||
let plugin_executable_path = plugin_dir.join("plugin_run");
|
||||
log::debug!("Starting plugin executable {:?}", &plugin_executable_path);
|
||||
let mut cmd = tokio::process::Command::new(plugin_executable_path);
|
||||
cmd.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::inherit());
|
||||
cmd.current_dir(&plugin_dir);
|
||||
|
||||
// TODO добавить какие-нибудь перемнные среды
|
||||
|
||||
let plugin_process = cmd.spawn().context("Failed to start the plugin")?;
|
||||
let plugin = stdio::initialize_stdio_plugin(plugin_process, unique_name.clone()).await?;
|
||||
|
||||
let mut ctx_lock = ctx.lock().await;
|
||||
for cmd in plugin.lock().await.commands.iter().cloned() {
|
||||
log::debug!("adding command /{} of plugin {}", &cmd.name, &unique_name);
|
||||
if ctx_lock.plugin_commands.contains_key(&cmd.name) {
|
||||
bail!("duplicate command specification");
|
||||
}
|
||||
ctx_lock
|
||||
.plugin_commands
|
||||
.insert(cmd.name.clone(), Arc::clone(&cmd));
|
||||
ctx_lock
|
||||
.plugin_cmd_aliases
|
||||
.insert(cmd.name.clone(), Arc::clone(&cmd));
|
||||
for alias in cmd.aliases.iter() {
|
||||
log::debug!(
|
||||
"adding command alias /{} -> /{} of plugin {}",
|
||||
alias,
|
||||
&cmd.name,
|
||||
&unique_name
|
||||
);
|
||||
ctx_lock
|
||||
.plugin_cmd_aliases
|
||||
.insert(alias.to_owned(), Arc::clone(&cmd));
|
||||
}
|
||||
}
|
||||
drop(ctx_lock);
|
||||
|
||||
ctx.lock().await.plugins.insert(unique_name, plugin);
|
||||
Ok(())
|
||||
}
|
||||
300
src/plugin/stdio.rs
Normal file
300
src/plugin/stdio.rs
Normal file
@@ -0,0 +1,300 @@
|
||||
use std::{collections::HashMap, error::Error, ops::DerefMut, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::{Context as _, Result as AnyhowResult, bail};
|
||||
use async_trait::async_trait;
|
||||
use prost::Message;
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader},
|
||||
process::{Child, ChildStdout},
|
||||
sync::{Mutex, oneshot},
|
||||
time::error::Elapsed,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
plugin::{
|
||||
LoadedPlugin, PluginCommand, PluginConnection, PluginConnectionError, PluginRequestType,
|
||||
},
|
||||
proto,
|
||||
};
|
||||
|
||||
pub(super) async fn initialize_stdio_plugin(
|
||||
process: Child,
|
||||
unique_name: String,
|
||||
) -> AnyhowResult<Arc<Mutex<LoadedPlugin>>> {
|
||||
let plugin = Arc::new(Mutex::new(LoadedPlugin::default()));
|
||||
log::info!("Connecting to plugin {} using standard I/O", &unique_name);
|
||||
plugin.lock().await.plugin_id = unique_name;
|
||||
let connection = Arc::new(StdioPluginConnection::new(Arc::clone(&plugin), process));
|
||||
Arc::clone(&connection).run_stdio_loops();
|
||||
|
||||
let plugin_info = connection.initialize_plugin(String::new()).await?;
|
||||
log::debug!("received plugin identification: {:?}", plugin_info);
|
||||
let mut plugin_lock = plugin.lock().await;
|
||||
plugin_lock.name = plugin_info.name;
|
||||
plugin_lock.plugin_id = plugin_info.unique_name.clone();
|
||||
plugin_lock.version = plugin_info.version;
|
||||
plugin_lock.authors = plugin_info.authors;
|
||||
drop(plugin_lock);
|
||||
|
||||
let command_list = connection.request_plugin_command_list().await?;
|
||||
let mut plugin_lock = plugin.lock().await;
|
||||
plugin_lock.commands = command_list
|
||||
.commands
|
||||
.into_iter()
|
||||
.map(|cmd| {
|
||||
Arc::new(PluginCommand {
|
||||
name: cmd.name,
|
||||
plugin_id: plugin_info.unique_name.clone(),
|
||||
aliases: cmd.aliases,
|
||||
usage: cmd.usage,
|
||||
description: cmd.description,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
plugin_lock.connection = Some(connection);
|
||||
drop(plugin_lock);
|
||||
|
||||
Ok(plugin)
|
||||
}
|
||||
|
||||
struct StdioPluginConnection {
|
||||
plugin: Arc<Mutex<LoadedPlugin>>,
|
||||
process: Mutex<Child>,
|
||||
buffered_stdout: Mutex<BufReader<ChildStdout>>,
|
||||
next_request_id: Mutex<u32>,
|
||||
pending_requests: Mutex<HashMap<u32, oneshot::Sender<proto::Response>>>,
|
||||
}
|
||||
|
||||
impl StdioPluginConnection {
|
||||
pub fn new(plugin: Arc<Mutex<LoadedPlugin>>, mut process: Child) -> StdioPluginConnection {
|
||||
let stdout = process.stdout.take().unwrap();
|
||||
let conn = StdioPluginConnection {
|
||||
plugin,
|
||||
process: Mutex::new(process),
|
||||
buffered_stdout: Mutex::new(BufReader::new(stdout)),
|
||||
next_request_id: Mutex::new(0),
|
||||
pending_requests: Mutex::new(HashMap::new()),
|
||||
};
|
||||
conn
|
||||
}
|
||||
|
||||
fn run_stdio_loops(self: Arc<Self>) {
|
||||
tokio::spawn(self.stdout_reader_loop());
|
||||
}
|
||||
|
||||
async fn stdout_reader_loop(self: Arc<Self>) {
|
||||
loop {
|
||||
let frame =
|
||||
match Self::read_length_delimited(self.buffered_stdout.lock().await.deref_mut())
|
||||
.await
|
||||
{
|
||||
Ok(frame) => frame,
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
"Error while reading STDOUT of stdio plugin {}: {e}",
|
||||
&self.plugin.lock().await.plugin_id
|
||||
);
|
||||
break;
|
||||
}
|
||||
};
|
||||
let response = match proto::Response::decode(frame.as_slice()) {
|
||||
Ok(response) => response,
|
||||
Err(_) => {
|
||||
log::error!(
|
||||
"Invalid response received from stdio plugin {}",
|
||||
&self.plugin.lock().await.plugin_id
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match self
|
||||
.pending_requests
|
||||
.lock()
|
||||
.await
|
||||
.remove(&response.request_id)
|
||||
{
|
||||
Some(sender) => match sender.send(response) {
|
||||
Ok(()) => {}
|
||||
Err(response) => {
|
||||
log::warn!(
|
||||
"Dropping response with request_id {} from plugin {}",
|
||||
response.request_id,
|
||||
self.plugin.lock().await.plugin_id
|
||||
);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.pending_requests.lock().await.clear();
|
||||
}
|
||||
|
||||
async fn read_length_delimited<R>(
|
||||
reader: &mut R,
|
||||
) -> Result<Vec<u8>, Box<dyn Error + Send + Sync>>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
const MAX_FRAME_LENGTH: usize = 0x40000000;
|
||||
let mut length: usize = 0;
|
||||
let mut bit = 0;
|
||||
let mut byte = [0u8; 1];
|
||||
let mut complete = false;
|
||||
for _ in 0..10 {
|
||||
reader.read_exact(&mut byte).await?;
|
||||
length += ((byte[0] & 0x7f) as usize).unbounded_shl(bit);
|
||||
bit += 7;
|
||||
if (byte[0] & 0x80) == 0 {
|
||||
complete = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !complete || length > MAX_FRAME_LENGTH {
|
||||
return Err(Box::new(PluginConnectionError::InvalidMessageLength));
|
||||
}
|
||||
|
||||
let mut message = vec![0u8; length];
|
||||
reader.read_exact(&mut message).await?;
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
async fn await_response_to(
|
||||
&self,
|
||||
request_id: u32,
|
||||
timeout: Duration,
|
||||
) -> Result<proto::Response, Box<dyn Error + Send + Sync>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.pending_requests.lock().await.insert(request_id, tx);
|
||||
|
||||
match tokio::time::timeout(timeout, rx).await {
|
||||
Ok(Ok(response)) => Ok(response),
|
||||
Ok(Err(e)) => Err(Box::new(e)),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl PluginConnection for StdioPluginConnection {
|
||||
async fn initialize_plugin(
|
||||
&self,
|
||||
config: String,
|
||||
) -> Result<proto::PluginInitializeResponse, PluginConnectionError> {
|
||||
let request_id = {
|
||||
let mut r = self.next_request_id.lock().await;
|
||||
let id = *r;
|
||||
*r += 1;
|
||||
id
|
||||
};
|
||||
let request = proto::Request {
|
||||
request_id,
|
||||
req: Some(proto::request::Req::InitializeReq(
|
||||
proto::PluginInitializeRequest { config },
|
||||
)),
|
||||
}
|
||||
.encode_length_delimited_to_vec();
|
||||
|
||||
if let Err(e) = self
|
||||
.process
|
||||
.lock()
|
||||
.await
|
||||
.stdin
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.write_all(&request)
|
||||
.await
|
||||
{
|
||||
return Err(PluginConnectionError::SendRequest(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::Initialize,
|
||||
Box::new(e),
|
||||
));
|
||||
}
|
||||
|
||||
let response = match self
|
||||
.await_response_to(request_id, Duration::from_secs(10))
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
return Err(PluginConnectionError::ReadResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::Initialize,
|
||||
e,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
match response.res {
|
||||
Some(proto::response::Res::InitializeRes(resp)) => Ok(resp),
|
||||
_ => Err(PluginConnectionError::DecodeResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::Initialize,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_plugin_command_list(
|
||||
&self,
|
||||
) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
|
||||
let request_id = {
|
||||
let mut r = self.next_request_id.lock().await;
|
||||
let id = *r;
|
||||
*r += 1;
|
||||
id
|
||||
};
|
||||
let request = proto::Request {
|
||||
request_id,
|
||||
req: Some(proto::request::Req::CommandListReq(
|
||||
proto::PluginCommandListRequest {},
|
||||
)),
|
||||
}
|
||||
.encode_length_delimited_to_vec();
|
||||
|
||||
if let Err(e) = self
|
||||
.process
|
||||
.lock()
|
||||
.await
|
||||
.stdin
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.write_all(&request)
|
||||
.await
|
||||
{
|
||||
return Err(PluginConnectionError::SendRequest(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::CommandList,
|
||||
Box::new(e),
|
||||
));
|
||||
}
|
||||
|
||||
let response = match self
|
||||
.await_response_to(request_id, Duration::from_secs(10))
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
return Err(PluginConnectionError::ReadResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::CommandList,
|
||||
e,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
match response.res {
|
||||
Some(proto::response::Res::CommandListRes(resp)) => Ok(resp),
|
||||
_ => Err(PluginConnectionError::DecodeResponse(
|
||||
self.plugin.lock().await.plugin_id.clone(),
|
||||
PluginRequestType::CommandList,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user