From be63e18ebf2d6ddf7a02a8c39839d7a51b843fe4 Mon Sep 17 00:00:00 2001 From: adbenitez Date: Sat, 10 Dec 2022 19:22:57 -0500 Subject: [PATCH] improve hook filters --- deltachat-rpc-client/examples/echobot.py | 4 +- .../examples/echobot_advanced.py | 13 +++--- .../src/deltachat_rpc_client/client.py | 46 ++++++++++++++++--- .../src/deltachat_rpc_client/const.py | 2 + .../src/deltachat_rpc_client/events.py | 32 ++++++++++--- deltachat-rpc-client/tests/test_something.py | 15 ++++-- 6 files changed, 87 insertions(+), 25 deletions(-) diff --git a/deltachat-rpc-client/examples/echobot.py b/deltachat-rpc-client/examples/echobot.py index 792dbae86..c77b422d2 100755 --- a/deltachat-rpc-client/examples/echobot.py +++ b/deltachat-rpc-client/examples/echobot.py @@ -17,8 +17,8 @@ async def log_event(event): @hooks.on(events.NewMessage) -async def echo(msg): - await msg.chat.send_text(msg.text) +async def echo(event): + await event.chat.send_text(event.text) if __name__ == "__main__": diff --git a/deltachat-rpc-client/examples/echobot_advanced.py b/deltachat-rpc-client/examples/echobot_advanced.py index 88ddd1303..28b1750b6 100644 --- a/deltachat-rpc-client/examples/echobot_advanced.py +++ b/deltachat-rpc-client/examples/echobot_advanced.py @@ -25,14 +25,15 @@ async def log_error(event): logging.error(event.msg) -@hooks.on(events.NewMessage(r".+", func=lambda msg: not msg.text.startswith("/"))) -async def echo(msg): - await msg.chat.send_text(msg.text) +@hooks.on(events.NewMessage(func=lambda e: not e.command)) +async def echo(event): + if event.text or event.file: + await event.chat.send_message(text=event.text, file=event.file) -@hooks.on(events.NewMessage(r"/help")) -async def help_command(msg): - await msg.chat.send_text("Send me any text message and I will echo it back") +@hooks.on(events.NewMessage(command="/help")) +async def help_command(event): + await event.chat.send_text("Send me any message and I will echo it back") async def main(): diff --git a/deltachat-rpc-client/src/deltachat_rpc_client/client.py b/deltachat-rpc-client/src/deltachat_rpc_client/client.py index 2a749eefe..6abcc8a59 100644 --- a/deltachat-rpc-client/src/deltachat_rpc_client/client.py +++ b/deltachat-rpc-client/src/deltachat_rpc_client/client.py @@ -4,8 +4,8 @@ from typing import Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union from deltachat_rpc_client.account import Account -from .const import EventType -from .events import EventFilter, NewInfoMessage, NewMessage, RawEvent +from .const import COMMAND_PREFIX, EventType +from .events import EventFilter, NewMessage, RawEvent from .utils import AttrDict @@ -79,16 +79,48 @@ class Client: self.logger.exception(ex) def _should_process_messages(self) -> bool: - return any(issubclass(filter_type, NewMessage) for filter_type in self._hooks) + return NewMessage in self._hooks + + async def _parse_command(self, snapshot: AttrDict) -> None: + cmds = [ + hook[1].command + for hook in self._hooks.get(NewMessage, []) + if hook[1].command + ] + parts = snapshot.text.split(maxsplit=1) + payload = parts[1] if len(parts) > 1 else "" + cmd = parts.pop(0) + + if "@" in cmd: + suffix = "@" + (await self.account.self_contact.get_snapshot()).address + if cmd.endswith(suffix): + cmd = cmd[: -len(suffix)] + else: + return + + parts = cmd.split("_") + _payload = payload + while parts: + _cmd = "_".join(parts) + if _cmd in cmds: + break + _payload = (parts.pop() + " " + _payload).rstrip() + + if parts: + cmd = _cmd + payload = _payload + + snapshot["command"] = cmd + snapshot["payload"] = payload async def _process_messages(self) -> None: if self._should_process_messages(): for message in await self.account.get_fresh_messages_in_arrival_order(): snapshot = await message.get_snapshot() - if snapshot.is_info: - await self._on_event(snapshot, NewInfoMessage) - else: - await self._on_event(snapshot, NewMessage) + snapshot["command"], snapshot["payload"] = "", "" + if not snapshot.is_info and snapshot.text.startswith(COMMAND_PREFIX): + await self._parse_command(snapshot) + await self._on_event(snapshot, NewMessage) await snapshot.message.mark_seen() diff --git a/deltachat-rpc-client/src/deltachat_rpc_client/const.py b/deltachat-rpc-client/src/deltachat_rpc_client/const.py index c8bb925fc..b2a4e7f12 100644 --- a/deltachat-rpc-client/src/deltachat_rpc_client/const.py +++ b/deltachat-rpc-client/src/deltachat_rpc_client/const.py @@ -1,5 +1,7 @@ from enum import Enum, IntEnum +COMMAND_PREFIX = "/" + class ContactFlag(IntEnum): VERIFIED_ONLY = 0x01 diff --git a/deltachat-rpc-client/src/deltachat_rpc_client/events.py b/deltachat-rpc-client/src/deltachat_rpc_client/events.py index 606fe9896..7654bdf8b 100644 --- a/deltachat-rpc-client/src/deltachat_rpc_client/events.py +++ b/deltachat-rpc-client/src/deltachat_rpc_client/events.py @@ -91,8 +91,15 @@ class RawEvent(EventFilter): class NewMessage(EventFilter): """Matches whenever a new message arrives. - Warning: registering a handler for this event or any subclass will cause the messages + Warning: registering a handler for this event will cause the messages to be marked as read. Its usage is mainly intended for bots. + + :param pattern: if set, this Pattern will be used to filter the message by its text + content. + :param command: If set, only match messages with the given command (ex. /help). + :param is_info: If set to True only match info/system messages, if set to False + only match messages that are not info/system messages. If omitted + info/system messages as well as normal messages will be matched. """ def __init__( @@ -103,9 +110,15 @@ class NewMessage(EventFilter): Callable[[str], bool], re.Pattern, ] = None, + command: Optional[str] = None, + is_info: Optional[bool] = None, func: Optional[Callable[[AttrDict], bool]] = None, ) -> None: super().__init__(func=func) + self.is_info = is_info + if command is not None and not isinstance(command, str): + raise TypeError("Invalid command") + self.command = command if isinstance(pattern, str): pattern = re.compile(pattern) if isinstance(pattern, re.Pattern): @@ -119,11 +132,20 @@ class NewMessage(EventFilter): return hash((self.pattern, self.func)) def __eq__(self, other) -> bool: - if type(other) is self.__class__: # noqa - return (self.pattern, self.func) == (other.pattern, other.func) + if isinstance(other, NewMessage): + return (self.pattern, self.command, self.is_info, self.func) == ( + other.pattern, + other.command, + other.is_info, + other.func, + ) return False async def filter(self, event: AttrDict) -> bool: + if self.is_info is not None and self.is_info != event.is_info: + return False + if self.command and self.command != event.command: + return False if self.pattern: match = self.pattern(event.text) if inspect.isawaitable(match): @@ -133,10 +155,6 @@ class NewMessage(EventFilter): return await super()._call_func(event) -class NewInfoMessage(NewMessage): - """Matches whenever a new info/system message arrives.""" - - class HookCollection: """ Helper class to collect event hooks that can later be added to a Delta Chat client. diff --git a/deltachat-rpc-client/tests/test_something.py b/deltachat-rpc-client/tests/test_something.py index 6d8136058..7c265f1ec 100644 --- a/deltachat-rpc-client/tests/test_something.py +++ b/deltachat-rpc-client/tests/test_something.py @@ -235,12 +235,21 @@ async def test_bot(acfactory) -> None: res = [] bot.add_hook(callback, events.NewMessage(r"hello")) - snapshot1 = AttrDict(text="hello") - snapshot2 = AttrDict(text="hello, world") - snapshot3 = AttrDict(text="hey!") + bot.add_hook(callback, events.NewMessage(command="/help")) + snapshot1 = AttrDict(text="hello", command=None) + snapshot2 = AttrDict(text="hello, world", command=None) + snapshot3 = AttrDict(text="hey!", command=None) for snapshot in [snapshot1, snapshot2, snapshot3]: await bot._on_event(snapshot, events.NewMessage) assert len(res) == 2 assert snapshot1 in res assert snapshot2 in res assert snapshot3 not in res + + res = [] + bot.remove_hook(callback, events.NewMessage(r"hello")) + snapshot4 = AttrDict(command="/help") + await bot._on_event(snapshot, events.NewMessage) + await bot._on_event(snapshot4, events.NewMessage) + assert len(res) == 1 + assert snapshot4 in res