improve hook filters

This commit is contained in:
adbenitez
2022-12-10 19:22:57 -05:00
parent 1f7ad78f40
commit be63e18ebf
6 changed files with 87 additions and 25 deletions

View File

@@ -17,8 +17,8 @@ async def log_event(event):
@hooks.on(events.NewMessage) @hooks.on(events.NewMessage)
async def echo(msg): async def echo(event):
await msg.chat.send_text(msg.text) await event.chat.send_text(event.text)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -25,14 +25,15 @@ async def log_error(event):
logging.error(event.msg) logging.error(event.msg)
@hooks.on(events.NewMessage(r".+", func=lambda msg: not msg.text.startswith("/"))) @hooks.on(events.NewMessage(func=lambda e: not e.command))
async def echo(msg): async def echo(event):
await msg.chat.send_text(msg.text) if event.text or event.file:
await event.chat.send_message(text=event.text, file=event.file)
@hooks.on(events.NewMessage(r"/help")) @hooks.on(events.NewMessage(command="/help"))
async def help_command(msg): async def help_command(event):
await msg.chat.send_text("Send me any text message and I will echo it back") await event.chat.send_text("Send me any message and I will echo it back")
async def main(): async def main():

View File

@@ -4,8 +4,8 @@ from typing import Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union
from deltachat_rpc_client.account import Account from deltachat_rpc_client.account import Account
from .const import EventType from .const import COMMAND_PREFIX, EventType
from .events import EventFilter, NewInfoMessage, NewMessage, RawEvent from .events import EventFilter, NewMessage, RawEvent
from .utils import AttrDict from .utils import AttrDict
@@ -79,16 +79,48 @@ class Client:
self.logger.exception(ex) self.logger.exception(ex)
def _should_process_messages(self) -> bool: def _should_process_messages(self) -> bool:
return any(issubclass(filter_type, NewMessage) for filter_type in self._hooks) return NewMessage in self._hooks
async def _parse_command(self, snapshot: AttrDict) -> None:
cmds = [
hook[1].command
for hook in self._hooks.get(NewMessage, [])
if hook[1].command
]
parts = snapshot.text.split(maxsplit=1)
payload = parts[1] if len(parts) > 1 else ""
cmd = parts.pop(0)
if "@" in cmd:
suffix = "@" + (await self.account.self_contact.get_snapshot()).address
if cmd.endswith(suffix):
cmd = cmd[: -len(suffix)]
else:
return
parts = cmd.split("_")
_payload = payload
while parts:
_cmd = "_".join(parts)
if _cmd in cmds:
break
_payload = (parts.pop() + " " + _payload).rstrip()
if parts:
cmd = _cmd
payload = _payload
snapshot["command"] = cmd
snapshot["payload"] = payload
async def _process_messages(self) -> None: async def _process_messages(self) -> None:
if self._should_process_messages(): if self._should_process_messages():
for message in await self.account.get_fresh_messages_in_arrival_order(): for message in await self.account.get_fresh_messages_in_arrival_order():
snapshot = await message.get_snapshot() snapshot = await message.get_snapshot()
if snapshot.is_info: snapshot["command"], snapshot["payload"] = "", ""
await self._on_event(snapshot, NewInfoMessage) if not snapshot.is_info and snapshot.text.startswith(COMMAND_PREFIX):
else: await self._parse_command(snapshot)
await self._on_event(snapshot, NewMessage) await self._on_event(snapshot, NewMessage)
await snapshot.message.mark_seen() await snapshot.message.mark_seen()

View File

@@ -1,5 +1,7 @@
from enum import Enum, IntEnum from enum import Enum, IntEnum
COMMAND_PREFIX = "/"
class ContactFlag(IntEnum): class ContactFlag(IntEnum):
VERIFIED_ONLY = 0x01 VERIFIED_ONLY = 0x01

View File

@@ -91,8 +91,15 @@ class RawEvent(EventFilter):
class NewMessage(EventFilter): class NewMessage(EventFilter):
"""Matches whenever a new message arrives. """Matches whenever a new message arrives.
Warning: registering a handler for this event or any subclass will cause the messages Warning: registering a handler for this event will cause the messages
to be marked as read. Its usage is mainly intended for bots. to be marked as read. Its usage is mainly intended for bots.
:param pattern: if set, this Pattern will be used to filter the message by its text
content.
:param command: If set, only match messages with the given command (ex. /help).
:param is_info: If set to True only match info/system messages, if set to False
only match messages that are not info/system messages. If omitted
info/system messages as well as normal messages will be matched.
""" """
def __init__( def __init__(
@@ -103,9 +110,15 @@ class NewMessage(EventFilter):
Callable[[str], bool], Callable[[str], bool],
re.Pattern, re.Pattern,
] = None, ] = None,
command: Optional[str] = None,
is_info: Optional[bool] = None,
func: Optional[Callable[[AttrDict], bool]] = None, func: Optional[Callable[[AttrDict], bool]] = None,
) -> None: ) -> None:
super().__init__(func=func) super().__init__(func=func)
self.is_info = is_info
if command is not None and not isinstance(command, str):
raise TypeError("Invalid command")
self.command = command
if isinstance(pattern, str): if isinstance(pattern, str):
pattern = re.compile(pattern) pattern = re.compile(pattern)
if isinstance(pattern, re.Pattern): if isinstance(pattern, re.Pattern):
@@ -119,11 +132,20 @@ class NewMessage(EventFilter):
return hash((self.pattern, self.func)) return hash((self.pattern, self.func))
def __eq__(self, other) -> bool: def __eq__(self, other) -> bool:
if type(other) is self.__class__: # noqa if isinstance(other, NewMessage):
return (self.pattern, self.func) == (other.pattern, other.func) return (self.pattern, self.command, self.is_info, self.func) == (
other.pattern,
other.command,
other.is_info,
other.func,
)
return False return False
async def filter(self, event: AttrDict) -> bool: async def filter(self, event: AttrDict) -> bool:
if self.is_info is not None and self.is_info != event.is_info:
return False
if self.command and self.command != event.command:
return False
if self.pattern: if self.pattern:
match = self.pattern(event.text) match = self.pattern(event.text)
if inspect.isawaitable(match): if inspect.isawaitable(match):
@@ -133,10 +155,6 @@ class NewMessage(EventFilter):
return await super()._call_func(event) return await super()._call_func(event)
class NewInfoMessage(NewMessage):
"""Matches whenever a new info/system message arrives."""
class HookCollection: class HookCollection:
""" """
Helper class to collect event hooks that can later be added to a Delta Chat client. Helper class to collect event hooks that can later be added to a Delta Chat client.

View File

@@ -235,12 +235,21 @@ async def test_bot(acfactory) -> None:
res = [] res = []
bot.add_hook(callback, events.NewMessage(r"hello")) bot.add_hook(callback, events.NewMessage(r"hello"))
snapshot1 = AttrDict(text="hello") bot.add_hook(callback, events.NewMessage(command="/help"))
snapshot2 = AttrDict(text="hello, world") snapshot1 = AttrDict(text="hello", command=None)
snapshot3 = AttrDict(text="hey!") snapshot2 = AttrDict(text="hello, world", command=None)
snapshot3 = AttrDict(text="hey!", command=None)
for snapshot in [snapshot1, snapshot2, snapshot3]: for snapshot in [snapshot1, snapshot2, snapshot3]:
await bot._on_event(snapshot, events.NewMessage) await bot._on_event(snapshot, events.NewMessage)
assert len(res) == 2 assert len(res) == 2
assert snapshot1 in res assert snapshot1 in res
assert snapshot2 in res assert snapshot2 in res
assert snapshot3 not in res assert snapshot3 not in res
res = []
bot.remove_hook(callback, events.NewMessage(r"hello"))
snapshot4 = AttrDict(command="/help")
await bot._on_event(snapshot, events.NewMessage)
await bot._on_event(snapshot4, events.NewMessage)
assert len(res) == 1
assert snapshot4 in res