mirror of
https://github.com/chatmail/core.git
synced 2026-05-16 13:26:38 +03:00
improve hook filters
This commit is contained in:
@@ -17,8 +17,8 @@ async def log_event(event):
|
|||||||
|
|
||||||
|
|
||||||
@hooks.on(events.NewMessage)
|
@hooks.on(events.NewMessage)
|
||||||
async def echo(msg):
|
async def echo(event):
|
||||||
await msg.chat.send_text(msg.text)
|
await event.chat.send_text(event.text)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -25,14 +25,15 @@ async def log_error(event):
|
|||||||
logging.error(event.msg)
|
logging.error(event.msg)
|
||||||
|
|
||||||
|
|
||||||
@hooks.on(events.NewMessage(r".+", func=lambda msg: not msg.text.startswith("/")))
|
@hooks.on(events.NewMessage(func=lambda e: not e.command))
|
||||||
async def echo(msg):
|
async def echo(event):
|
||||||
await msg.chat.send_text(msg.text)
|
if event.text or event.file:
|
||||||
|
await event.chat.send_message(text=event.text, file=event.file)
|
||||||
|
|
||||||
|
|
||||||
@hooks.on(events.NewMessage(r"/help"))
|
@hooks.on(events.NewMessage(command="/help"))
|
||||||
async def help_command(msg):
|
async def help_command(event):
|
||||||
await msg.chat.send_text("Send me any text message and I will echo it back")
|
await event.chat.send_text("Send me any message and I will echo it back")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ from typing import Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union
|
|||||||
|
|
||||||
from deltachat_rpc_client.account import Account
|
from deltachat_rpc_client.account import Account
|
||||||
|
|
||||||
from .const import EventType
|
from .const import COMMAND_PREFIX, EventType
|
||||||
from .events import EventFilter, NewInfoMessage, NewMessage, RawEvent
|
from .events import EventFilter, NewMessage, RawEvent
|
||||||
from .utils import AttrDict
|
from .utils import AttrDict
|
||||||
|
|
||||||
|
|
||||||
@@ -79,16 +79,48 @@ class Client:
|
|||||||
self.logger.exception(ex)
|
self.logger.exception(ex)
|
||||||
|
|
||||||
def _should_process_messages(self) -> bool:
|
def _should_process_messages(self) -> bool:
|
||||||
return any(issubclass(filter_type, NewMessage) for filter_type in self._hooks)
|
return NewMessage in self._hooks
|
||||||
|
|
||||||
|
async def _parse_command(self, snapshot: AttrDict) -> None:
|
||||||
|
cmds = [
|
||||||
|
hook[1].command
|
||||||
|
for hook in self._hooks.get(NewMessage, [])
|
||||||
|
if hook[1].command
|
||||||
|
]
|
||||||
|
parts = snapshot.text.split(maxsplit=1)
|
||||||
|
payload = parts[1] if len(parts) > 1 else ""
|
||||||
|
cmd = parts.pop(0)
|
||||||
|
|
||||||
|
if "@" in cmd:
|
||||||
|
suffix = "@" + (await self.account.self_contact.get_snapshot()).address
|
||||||
|
if cmd.endswith(suffix):
|
||||||
|
cmd = cmd[: -len(suffix)]
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
parts = cmd.split("_")
|
||||||
|
_payload = payload
|
||||||
|
while parts:
|
||||||
|
_cmd = "_".join(parts)
|
||||||
|
if _cmd in cmds:
|
||||||
|
break
|
||||||
|
_payload = (parts.pop() + " " + _payload).rstrip()
|
||||||
|
|
||||||
|
if parts:
|
||||||
|
cmd = _cmd
|
||||||
|
payload = _payload
|
||||||
|
|
||||||
|
snapshot["command"] = cmd
|
||||||
|
snapshot["payload"] = payload
|
||||||
|
|
||||||
async def _process_messages(self) -> None:
|
async def _process_messages(self) -> None:
|
||||||
if self._should_process_messages():
|
if self._should_process_messages():
|
||||||
for message in await self.account.get_fresh_messages_in_arrival_order():
|
for message in await self.account.get_fresh_messages_in_arrival_order():
|
||||||
snapshot = await message.get_snapshot()
|
snapshot = await message.get_snapshot()
|
||||||
if snapshot.is_info:
|
snapshot["command"], snapshot["payload"] = "", ""
|
||||||
await self._on_event(snapshot, NewInfoMessage)
|
if not snapshot.is_info and snapshot.text.startswith(COMMAND_PREFIX):
|
||||||
else:
|
await self._parse_command(snapshot)
|
||||||
await self._on_event(snapshot, NewMessage)
|
await self._on_event(snapshot, NewMessage)
|
||||||
await snapshot.message.mark_seen()
|
await snapshot.message.mark_seen()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from enum import Enum, IntEnum
|
from enum import Enum, IntEnum
|
||||||
|
|
||||||
|
COMMAND_PREFIX = "/"
|
||||||
|
|
||||||
|
|
||||||
class ContactFlag(IntEnum):
|
class ContactFlag(IntEnum):
|
||||||
VERIFIED_ONLY = 0x01
|
VERIFIED_ONLY = 0x01
|
||||||
|
|||||||
@@ -91,8 +91,15 @@ class RawEvent(EventFilter):
|
|||||||
class NewMessage(EventFilter):
|
class NewMessage(EventFilter):
|
||||||
"""Matches whenever a new message arrives.
|
"""Matches whenever a new message arrives.
|
||||||
|
|
||||||
Warning: registering a handler for this event or any subclass will cause the messages
|
Warning: registering a handler for this event will cause the messages
|
||||||
to be marked as read. Its usage is mainly intended for bots.
|
to be marked as read. Its usage is mainly intended for bots.
|
||||||
|
|
||||||
|
:param pattern: if set, this Pattern will be used to filter the message by its text
|
||||||
|
content.
|
||||||
|
:param command: If set, only match messages with the given command (ex. /help).
|
||||||
|
:param is_info: If set to True only match info/system messages, if set to False
|
||||||
|
only match messages that are not info/system messages. If omitted
|
||||||
|
info/system messages as well as normal messages will be matched.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -103,9 +110,15 @@ class NewMessage(EventFilter):
|
|||||||
Callable[[str], bool],
|
Callable[[str], bool],
|
||||||
re.Pattern,
|
re.Pattern,
|
||||||
] = None,
|
] = None,
|
||||||
|
command: Optional[str] = None,
|
||||||
|
is_info: Optional[bool] = None,
|
||||||
func: Optional[Callable[[AttrDict], bool]] = None,
|
func: Optional[Callable[[AttrDict], bool]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(func=func)
|
super().__init__(func=func)
|
||||||
|
self.is_info = is_info
|
||||||
|
if command is not None and not isinstance(command, str):
|
||||||
|
raise TypeError("Invalid command")
|
||||||
|
self.command = command
|
||||||
if isinstance(pattern, str):
|
if isinstance(pattern, str):
|
||||||
pattern = re.compile(pattern)
|
pattern = re.compile(pattern)
|
||||||
if isinstance(pattern, re.Pattern):
|
if isinstance(pattern, re.Pattern):
|
||||||
@@ -119,11 +132,20 @@ class NewMessage(EventFilter):
|
|||||||
return hash((self.pattern, self.func))
|
return hash((self.pattern, self.func))
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
if type(other) is self.__class__: # noqa
|
if isinstance(other, NewMessage):
|
||||||
return (self.pattern, self.func) == (other.pattern, other.func)
|
return (self.pattern, self.command, self.is_info, self.func) == (
|
||||||
|
other.pattern,
|
||||||
|
other.command,
|
||||||
|
other.is_info,
|
||||||
|
other.func,
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def filter(self, event: AttrDict) -> bool:
|
async def filter(self, event: AttrDict) -> bool:
|
||||||
|
if self.is_info is not None and self.is_info != event.is_info:
|
||||||
|
return False
|
||||||
|
if self.command and self.command != event.command:
|
||||||
|
return False
|
||||||
if self.pattern:
|
if self.pattern:
|
||||||
match = self.pattern(event.text)
|
match = self.pattern(event.text)
|
||||||
if inspect.isawaitable(match):
|
if inspect.isawaitable(match):
|
||||||
@@ -133,10 +155,6 @@ class NewMessage(EventFilter):
|
|||||||
return await super()._call_func(event)
|
return await super()._call_func(event)
|
||||||
|
|
||||||
|
|
||||||
class NewInfoMessage(NewMessage):
|
|
||||||
"""Matches whenever a new info/system message arrives."""
|
|
||||||
|
|
||||||
|
|
||||||
class HookCollection:
|
class HookCollection:
|
||||||
"""
|
"""
|
||||||
Helper class to collect event hooks that can later be added to a Delta Chat client.
|
Helper class to collect event hooks that can later be added to a Delta Chat client.
|
||||||
|
|||||||
@@ -235,12 +235,21 @@ async def test_bot(acfactory) -> None:
|
|||||||
|
|
||||||
res = []
|
res = []
|
||||||
bot.add_hook(callback, events.NewMessage(r"hello"))
|
bot.add_hook(callback, events.NewMessage(r"hello"))
|
||||||
snapshot1 = AttrDict(text="hello")
|
bot.add_hook(callback, events.NewMessage(command="/help"))
|
||||||
snapshot2 = AttrDict(text="hello, world")
|
snapshot1 = AttrDict(text="hello", command=None)
|
||||||
snapshot3 = AttrDict(text="hey!")
|
snapshot2 = AttrDict(text="hello, world", command=None)
|
||||||
|
snapshot3 = AttrDict(text="hey!", command=None)
|
||||||
for snapshot in [snapshot1, snapshot2, snapshot3]:
|
for snapshot in [snapshot1, snapshot2, snapshot3]:
|
||||||
await bot._on_event(snapshot, events.NewMessage)
|
await bot._on_event(snapshot, events.NewMessage)
|
||||||
assert len(res) == 2
|
assert len(res) == 2
|
||||||
assert snapshot1 in res
|
assert snapshot1 in res
|
||||||
assert snapshot2 in res
|
assert snapshot2 in res
|
||||||
assert snapshot3 not in res
|
assert snapshot3 not in res
|
||||||
|
|
||||||
|
res = []
|
||||||
|
bot.remove_hook(callback, events.NewMessage(r"hello"))
|
||||||
|
snapshot4 = AttrDict(command="/help")
|
||||||
|
await bot._on_event(snapshot, events.NewMessage)
|
||||||
|
await bot._on_event(snapshot4, events.NewMessage)
|
||||||
|
assert len(res) == 1
|
||||||
|
assert snapshot4 in res
|
||||||
|
|||||||
Reference in New Issue
Block a user