start implementing a plugin system

start implementing a plugin system using standard I/O as a transport
This commit is contained in:
2026-05-10 00:34:30 +03:00
parent dda56f1694
commit 21564dda30
10 changed files with 725 additions and 44 deletions

163
src/plugin/mod.rs Normal file
View File

@@ -0,0 +1,163 @@
mod stdio;
use anyhow::{Context as _, Result as AnyhowResult, bail};
use async_trait::async_trait;
use prost::{DecodeError, Message};
use std::{
collections::HashMap,
error::Error,
fmt::{Debug, Display},
ops::DerefMut,
process::Stdio,
sync::Arc,
};
use tokio::{
io::{self, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader},
process::{Child, ChildStdout},
sync::{Mutex, oneshot},
};
use crate::{BotContext, paths::plugins_path, proto};
pub(crate) struct PluginCommand {
pub plugin_id: String,
pub name: String,
pub aliases: Vec<String>,
pub usage: String,
pub description: String,
}
#[derive(Default)]
pub(crate) struct LoadedPlugin {
pub plugin_id: String,
pub name: String,
pub version: String,
pub authors: String,
pub commands: Vec<Arc<PluginCommand>>,
pub connection: Option<Arc<dyn PluginConnection>>,
}
#[derive(Debug)]
pub(crate) enum PluginRequestType {
Initialize,
CommandList,
}
impl From<PluginRequestType> for i32 {
fn from(value: PluginRequestType) -> Self {
match value {
PluginRequestType::Initialize => 1,
PluginRequestType::CommandList => 2,
}
}
}
impl Display for PluginRequestType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
<PluginRequestType as Debug>::fmt(self, f)
}
}
#[async_trait]
pub(crate) trait PluginConnection: Send + Sync {
async fn initialize_plugin(
&self,
config: String,
) -> Result<proto::PluginInitializeResponse, PluginConnectionError>;
async fn request_plugin_command_list(
&self,
) -> Result<proto::PluginCommandListResponse, PluginConnectionError>;
}
#[derive(Debug)]
pub(crate) enum PluginConnectionError {
SendRequest(String, PluginRequestType, Box<dyn Error + Send + Sync>),
ReadResponse(String, PluginRequestType, Box<dyn Error + Send + Sync>),
DecodeResponse(String, PluginRequestType, Option<DecodeError>),
InvalidMessageLength,
}
impl Display for PluginConnectionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SendRequest(plugin_name, req_type, e) => f.write_fmt(format_args!(
"Can't send request ({req_type}) to plugin {plugin_name}: {e}"
)),
Self::ReadResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!(
"Can't read response ({resp_type}) from plugin {plugin_name}: {e}"
)),
Self::DecodeResponse(plugin_name, resp_type, e) => f.write_fmt(format_args!(
"Can't decode response ({resp_type}) from plugin {plugin_name}{}",
e.as_ref().map(|e| format!(": {e}")).unwrap_or_default()
)),
Self::InvalidMessageLength => f.write_str("Plugin response length is invalid"),
}
}
}
impl Error for PluginConnectionError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
Self::SendRequest(_, _, e) => Some(e.as_ref()),
Self::ReadResponse(_, _, e) => Some(e.as_ref()),
Self::DecodeResponse(_, _, e) => e.as_ref().map(|e| e as &dyn Error),
Self::InvalidMessageLength => None,
}
}
}
pub(crate) async fn try_load_plugin(
ctx: Arc<Mutex<BotContext>>,
unique_name: String,
) -> AnyhowResult<()> {
let plugin_dir = plugins_path().join(&unique_name);
if ctx.lock().await.plugins.contains_key(&unique_name) {
bail!("plugin unique name is not unique");
}
if !std::fs::metadata(&plugin_dir)?.is_dir() {
bail!("Plugin directory doesn't exist");
}
let plugin_executable_path = plugin_dir.join("plugin_run");
log::debug!("Starting plugin executable {:?}", &plugin_executable_path);
let mut cmd = tokio::process::Command::new(plugin_executable_path);
cmd.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit());
cmd.current_dir(&plugin_dir);
// TODO добавить какие-нибудь перемнные среды
let plugin_process = cmd.spawn().context("Failed to start the plugin")?;
let plugin = stdio::initialize_stdio_plugin(plugin_process, unique_name.clone()).await?;
let mut ctx_lock = ctx.lock().await;
for cmd in plugin.lock().await.commands.iter().cloned() {
log::debug!("adding command /{} of plugin {}", &cmd.name, &unique_name);
if ctx_lock.plugin_commands.contains_key(&cmd.name) {
bail!("duplicate command specification");
}
ctx_lock
.plugin_commands
.insert(cmd.name.clone(), Arc::clone(&cmd));
ctx_lock
.plugin_cmd_aliases
.insert(cmd.name.clone(), Arc::clone(&cmd));
for alias in cmd.aliases.iter() {
log::debug!(
"adding command alias /{} -> /{} of plugin {}",
alias,
&cmd.name,
&unique_name
);
ctx_lock
.plugin_cmd_aliases
.insert(alias.to_owned(), Arc::clone(&cmd));
}
}
drop(ctx_lock);
ctx.lock().await.plugins.insert(unique_name, plugin);
Ok(())
}

300
src/plugin/stdio.rs Normal file
View File

@@ -0,0 +1,300 @@
use std::{collections::HashMap, error::Error, ops::DerefMut, sync::Arc, time::Duration};
use anyhow::{Context as _, Result as AnyhowResult, bail};
use async_trait::async_trait;
use prost::Message;
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader},
process::{Child, ChildStdout},
sync::{Mutex, oneshot},
time::error::Elapsed,
};
use crate::{
plugin::{
LoadedPlugin, PluginCommand, PluginConnection, PluginConnectionError, PluginRequestType,
},
proto,
};
pub(super) async fn initialize_stdio_plugin(
process: Child,
unique_name: String,
) -> AnyhowResult<Arc<Mutex<LoadedPlugin>>> {
let plugin = Arc::new(Mutex::new(LoadedPlugin::default()));
log::info!("Connecting to plugin {} using standard I/O", &unique_name);
plugin.lock().await.plugin_id = unique_name;
let connection = Arc::new(StdioPluginConnection::new(Arc::clone(&plugin), process));
Arc::clone(&connection).run_stdio_loops();
let plugin_info = connection.initialize_plugin(String::new()).await?;
log::debug!("received plugin identification: {:?}", plugin_info);
let mut plugin_lock = plugin.lock().await;
plugin_lock.name = plugin_info.name;
plugin_lock.plugin_id = plugin_info.unique_name.clone();
plugin_lock.version = plugin_info.version;
plugin_lock.authors = plugin_info.authors;
drop(plugin_lock);
let command_list = connection.request_plugin_command_list().await?;
let mut plugin_lock = plugin.lock().await;
plugin_lock.commands = command_list
.commands
.into_iter()
.map(|cmd| {
Arc::new(PluginCommand {
name: cmd.name,
plugin_id: plugin_info.unique_name.clone(),
aliases: cmd.aliases,
usage: cmd.usage,
description: cmd.description,
})
})
.collect();
plugin_lock.connection = Some(connection);
drop(plugin_lock);
Ok(plugin)
}
struct StdioPluginConnection {
plugin: Arc<Mutex<LoadedPlugin>>,
process: Mutex<Child>,
buffered_stdout: Mutex<BufReader<ChildStdout>>,
next_request_id: Mutex<u32>,
pending_requests: Mutex<HashMap<u32, oneshot::Sender<proto::Response>>>,
}
impl StdioPluginConnection {
pub fn new(plugin: Arc<Mutex<LoadedPlugin>>, mut process: Child) -> StdioPluginConnection {
let stdout = process.stdout.take().unwrap();
let conn = StdioPluginConnection {
plugin,
process: Mutex::new(process),
buffered_stdout: Mutex::new(BufReader::new(stdout)),
next_request_id: Mutex::new(0),
pending_requests: Mutex::new(HashMap::new()),
};
conn
}
fn run_stdio_loops(self: Arc<Self>) {
tokio::spawn(self.stdout_reader_loop());
}
async fn stdout_reader_loop(self: Arc<Self>) {
loop {
let frame =
match Self::read_length_delimited(self.buffered_stdout.lock().await.deref_mut())
.await
{
Ok(frame) => frame,
Err(e) => {
log::error!(
"Error while reading STDOUT of stdio plugin {}: {e}",
&self.plugin.lock().await.plugin_id
);
break;
}
};
let response = match proto::Response::decode(frame.as_slice()) {
Ok(response) => response,
Err(_) => {
log::error!(
"Invalid response received from stdio plugin {}",
&self.plugin.lock().await.plugin_id
);
continue;
}
};
match self
.pending_requests
.lock()
.await
.remove(&response.request_id)
{
Some(sender) => match sender.send(response) {
Ok(()) => {}
Err(response) => {
log::warn!(
"Dropping response with request_id {} from plugin {}",
response.request_id,
self.plugin.lock().await.plugin_id
);
}
},
None => {
continue;
}
}
}
self.pending_requests.lock().await.clear();
}
async fn read_length_delimited<R>(
reader: &mut R,
) -> Result<Vec<u8>, Box<dyn Error + Send + Sync>>
where
R: AsyncRead + Unpin,
{
const MAX_FRAME_LENGTH: usize = 0x40000000;
let mut length: usize = 0;
let mut bit = 0;
let mut byte = [0u8; 1];
let mut complete = false;
for _ in 0..10 {
reader.read_exact(&mut byte).await?;
length += ((byte[0] & 0x7f) as usize).unbounded_shl(bit);
bit += 7;
if (byte[0] & 0x80) == 0 {
complete = true;
break;
}
}
if !complete || length > MAX_FRAME_LENGTH {
return Err(Box::new(PluginConnectionError::InvalidMessageLength));
}
let mut message = vec![0u8; length];
reader.read_exact(&mut message).await?;
Ok(message)
}
async fn await_response_to(
&self,
request_id: u32,
timeout: Duration,
) -> Result<proto::Response, Box<dyn Error + Send + Sync>> {
let (tx, rx) = oneshot::channel();
self.pending_requests.lock().await.insert(request_id, tx);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(e)) => Err(Box::new(e)),
Err(e) => Err(Box::new(e)),
}
}
}
#[async_trait]
impl PluginConnection for StdioPluginConnection {
async fn initialize_plugin(
&self,
config: String,
) -> Result<proto::PluginInitializeResponse, PluginConnectionError> {
let request_id = {
let mut r = self.next_request_id.lock().await;
let id = *r;
*r += 1;
id
};
let request = proto::Request {
request_id,
req: Some(proto::request::Req::InitializeReq(
proto::PluginInitializeRequest { config },
)),
}
.encode_length_delimited_to_vec();
if let Err(e) = self
.process
.lock()
.await
.stdin
.as_mut()
.unwrap()
.write_all(&request)
.await
{
return Err(PluginConnectionError::SendRequest(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Initialize,
Box::new(e),
));
}
let response = match self
.await_response_to(request_id, Duration::from_secs(10))
.await
{
Ok(response) => response,
Err(e) => {
return Err(PluginConnectionError::ReadResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Initialize,
e,
));
}
};
match response.res {
Some(proto::response::Res::InitializeRes(resp)) => Ok(resp),
_ => Err(PluginConnectionError::DecodeResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::Initialize,
None,
)),
}
}
async fn request_plugin_command_list(
&self,
) -> Result<proto::PluginCommandListResponse, PluginConnectionError> {
let request_id = {
let mut r = self.next_request_id.lock().await;
let id = *r;
*r += 1;
id
};
let request = proto::Request {
request_id,
req: Some(proto::request::Req::CommandListReq(
proto::PluginCommandListRequest {},
)),
}
.encode_length_delimited_to_vec();
if let Err(e) = self
.process
.lock()
.await
.stdin
.as_mut()
.unwrap()
.write_all(&request)
.await
{
return Err(PluginConnectionError::SendRequest(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::CommandList,
Box::new(e),
));
}
let response = match self
.await_response_to(request_id, Duration::from_secs(10))
.await
{
Ok(response) => response,
Err(e) => {
return Err(PluginConnectionError::ReadResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::CommandList,
e,
));
}
};
match response.res {
Some(proto::response::Res::CommandListRes(resp)) => Ok(resp),
_ => Err(PluginConnectionError::DecodeResponse(
self.plugin.lock().await.plugin_id.clone(),
PluginRequestType::CommandList,
None,
)),
}
}
}