Source code for astra.events.filters

# -----------------------------------------------------------
# Astra - WhatsApp Client Framework
# Licensed under the Apache License 2.0.
# -----------------------------------------------------------

"""
This module provides the criteria (filters) used to match events
in the Astra framework.
"""

import re
from abc import ABC, abstractmethod
from typing import Any, List, Union, Callable, Optional

class Criterion(ABC):
    """
    Abstract base for every event-matching criterion.

    Criteria compose with Python's bitwise operators:

    - ``a & b`` builds an :class:`AndCriterion` (both must match)
    - ``a | b`` builds an :class:`OrCriterion` (either may match)
    - ``~a`` builds a :class:`NotCriterion` (negation)
    """

    @abstractmethod
    async def matches(self, event: Any) -> bool:
        """Return whether ``event`` satisfies this criterion."""
        ...

    def __and__(self, other: 'Criterion') -> 'AndCriterion':
        conjunction = AndCriterion(self, other)
        return conjunction

    def __or__(self, other: 'Criterion') -> 'OrCriterion':
        disjunction = OrCriterion(self, other)
        return disjunction

    def __invert__(self) -> 'NotCriterion':
        negation = NotCriterion(self)
        return negation

    async def passes(self, event: Any) -> bool:
        """Alias of :meth:`matches`, kept for EventEmitter compatibility."""
        outcome = await self.matches(event)
        return outcome


# --- Logical Combinators ---
class AndCriterion(Criterion):
    """Matches when both wrapped criteria match; ``c2`` is skipped if ``c1`` fails."""

    def __init__(self, c1: Criterion, c2: Criterion):
        self.c1 = c1
        self.c2 = c2

    async def matches(self, event: Any) -> bool:
        left = await self.c1.matches(event)
        if not left:
            # Short-circuit exactly like `and`: hand back the falsy result.
            return left
        return await self.c2.matches(event)


class OrCriterion(Criterion):
    """Matches when either wrapped criterion matches; ``c2`` is skipped if ``c1`` passes."""

    def __init__(self, c1: Criterion, c2: Criterion):
        self.c1 = c1
        self.c2 = c2

    async def matches(self, event: Any) -> bool:
        left = await self.c1.matches(event)
        if left:
            # Short-circuit exactly like `or`: hand back the truthy result.
            return left
        return await self.c2.matches(event)


class NotCriterion(Criterion):
    """Inverts the result of the wrapped criterion."""

    def __init__(self, c: Criterion):
        self.c = c

    async def matches(self, event: Any) -> bool:
        inner = await self.c.matches(event)
        return not inner


class AllCriterion(Criterion):
    """Matches every event unconditionally."""

    async def matches(self, event: Any) -> bool:
        return True
# --- Predicate Implementations ---
class ChatTypeCriterion(Criterion):
    """
    Matches events by chat kind.

    Args:
        is_group: ``True`` selects group chats, ``False`` selects private chats.
    """

    def __init__(self, is_group: bool):
        self.is_group = is_group

    async def matches(self, event: Any) -> bool:
        """Return whether ``event.is_group`` is identical to the configured flag.

        The comparison is an identity check. An event lacking ``is_group``
        (read as ``None``) therefore matches neither kind.
        """
        # Reads is_group straight off the object, so an EventContext or a raw
        # event exposing that attribute both work.
        return self.is_group is getattr(event, 'is_group', None)
class TextMatchCriterion(Criterion):
    """
    Matches events whose text equals or contains a pattern.

    Args:
        pattern: The string to look for.
        exact: If ``True``, the whitespace-stripped text must equal ``pattern``.
            Otherwise ``pattern`` only has to occur somewhere in the text.
        ignore_case: Compare caselessly (via ``str.casefold``).
    """

    def __init__(self, pattern: str, exact: bool = False, ignore_case: bool = True):
        # casefold() is the Unicode-correct caseless form (e.g. "ß" == "ss").
        self.pattern = pattern.casefold() if ignore_case else pattern
        self.exact = exact
        self.ignore_case = ignore_case

    async def matches(self, event: Any) -> bool:
        raw = getattr(event, 'text', None)
        # A missing or None text is treated as empty text. str(None) would
        # yield the literal "None" and let patterns such as "none" match
        # events that carry no text at all.
        text = "" if raw is None else str(raw).strip()
        if self.ignore_case:
            text = text.casefold()
        if self.exact:
            return text == self.pattern
        return self.pattern in text
class RegexCriterion(Criterion):
    """
    Matches events whose text contains a match for a regular expression.

    Args:
        pattern: A regex source string, or an already-compiled pattern.
        flags: ``re`` flags used when compiling a string pattern. They are
            ignored for a compiled pattern, which keeps its own flags.
    """

    def __init__(self, pattern: Union[str, "re.Pattern[str]"], flags: int = re.IGNORECASE):
        if isinstance(pattern, re.Pattern):
            # re.compile() raises ValueError when given flags together with a
            # compiled pattern, so use the pattern as-is.
            self.regex = pattern
        else:
            self.regex = re.compile(pattern, flags)

    async def matches(self, event: Any) -> bool:
        raw = getattr(event, 'text', None)
        # Treat a missing or None text as empty. str(None) == "None" would
        # otherwise match patterns like r"\w+" on caption-less events.
        text = "" if raw is None else str(raw)
        return self.regex.search(text) is not None
class IdentityCriterion(Criterion):
    """
    Matches events whose identity attribute is one of the given IDs.

    Args:
        ids: A single ID or a list of IDs. Every entry is stored as ``str``.
        field: Name of the event attribute to compare (default ``'chat_id'``).
    """

    def __init__(self, ids: Union[str, List[str]], field: str = 'chat_id'):
        if isinstance(ids, str):
            ids = [ids]
        # matches() compares str(value), so normalise the IDs to str as well.
        # Non-str IDs (e.g. integers) could otherwise never match.
        self.ids = {str(i) for i in ids}
        self.field = field

    async def matches(self, event: Any) -> bool:
        val = getattr(event, self.field, None)
        if val is None:
            # A missing field must not be compared as the string "None".
            return False
        return str(val) in self.ids
class DirectionCriterion(Criterion):
    """
    Matches events by direction (sent by this client vs. received).

    Args:
        outgoing: ``True`` matches events whose ``from_me`` is truthy,
            ``False`` matches the rest.
    """

    def __init__(self, outgoing: bool):
        self.outgoing = outgoing

    async def matches(self, event: Any) -> bool:
        sent_by_us = getattr(event, 'from_me', None)
        # When the wrapper lacks from_me, fall back to its wrapped `_event`
        # object, treating an absent flag there as False.
        needs_fallback = sent_by_us is None and hasattr(event, '_event')
        if needs_fallback:
            sent_by_us = getattr(event._event, 'from_me', False)
        return self.outgoing is bool(sent_by_us)
class TypeCriterion(Criterion):
    """
    Matches when ``getattr(event, attr)`` equals ``value``.

    Objects exposing a ``.value`` attribute (such as ``Enum`` members) are
    unwrapped on *both* sides before comparing. ``value`` may therefore be
    given as a raw value (``"image"``) or as the enum member itself.

    Args:
        attr: Name of the event attribute to inspect.
        value: Expected value (defaults to ``True`` for flag-style attributes).
    """

    def __init__(self, attr: str, value: Any = True):
        self.attr = attr
        self.value = value

    @staticmethod
    def _unwrap(obj: Any) -> Any:
        """Return ``obj.value`` when the object has one, otherwise ``obj``."""
        return obj.value if hasattr(obj, 'value') else obj

    async def matches(self, event: Any) -> bool:
        actual = getattr(event, self.attr, None)
        # Unwrapping only `actual` would compare a raw value against an
        # enum-member `value`, which fails for plain Enum members.
        return self._unwrap(actual) == self._unwrap(self.value)
class CommandCriterion(Criterion):
    """
    Matches ``EventContext`` events carrying one of the given commands.

    Args:
        command: A command name, or a list of aliases, without prefix.
        prefixes: Accepted prefixes.
            - A string is treated as a set of single characters: ``"/."``
              accepts ``/`` and ``.``.
            - A list may hold multi-character prefixes.
            - An empty value disables the prefix check (prefix-less commands).
    """

    def __init__(self, command: Union[str, List[str]], prefixes: Union[str, List[str]] = "/."):
        self.commands = {command} if isinstance(command, str) else set(command)
        self.prefixes = prefixes

    async def matches(self, event: Any) -> bool:
        # Imported lazily; presumably to avoid a circular import with
        # .context (TODO confirm).
        from .context import EventContext

        if not isinstance(event, EventContext):
            return False
        # 1. The command name must be one of ours.
        if event.command not in self.commands:
            return False
        # 2. No prefixes configured: accept anything (prefix-less mode).
        if not self.prefixes:
            return True
        # 3. Otherwise the prefix must be exactly one of the accepted ones.
        #    The previous substring test (`prefix in "/."`) wrongly accepted an
        #    empty prefix and raised TypeError for None.
        prefix = getattr(event, 'prefix', None)
        return isinstance(prefix, str) and prefix in set(self.prefixes)
# --- Human-Friendly Namespace ---
class Filters:
    """
    Collection of event matching criteria.

    Class attributes are ready-made, shared criterion instances. The static
    methods build parameterised criteria. All of them combine with
    ``&``, ``|`` and ``~``.
    """

    # --- Direction & Identity ---
    all = AllCriterion()
    incoming = DirectionCriterion(outgoing=False)
    outgoing = DirectionCriterion(outgoing=True)
    me = DirectionCriterion(outgoing=True)  # same predicate as `outgoing`
    private = ChatTypeCriterion(is_group=False)
    group = ChatTypeCriterion(is_group=True)

    # --- Media Types ---
    media = TypeCriterion("is_media", True)
    photo = TypeCriterion("type", "image")
    video = TypeCriterion("type", "video")
    audio = TypeCriterion("type", "audio")
    voice = TypeCriterion("type", "ptt")
    document = TypeCriterion("type", "document")
    sticker = TypeCriterion("type", "sticker")
    location = TypeCriterion("type", "location")
    contact = TypeCriterion("type", "vcard")
    poll = TypeCriterion("type", "poll_creation")

    # --- Structural ---
    service = TypeCriterion("is_service", True)
    forwarded = TypeCriterion("is_forwarded", True)
    quoted = TypeCriterion("has_quoted_msg", True)

    @staticmethod
    def command(name: Union[str, List[str]], prefixes: str = "/.") -> Criterion:
        """
        Matches if the message is a specific command (or one of several aliases).

        A name that already starts with one of ``prefixes`` pins that prefix as
        its only accepted one. A bare name accepts every prefix in ``prefixes``.

        Example:
            command("ping", ".")     -> matches ".ping"
            command(".ping")         -> matches ".ping" (prefix inferred from name)
            command(["ping", "/p"])  -> matches "/ping", ".ping" or "/p"

        Raises:
            ValueError: if ``name`` is an empty list.
        """
        names = [name] if isinstance(name, str) else list(name)
        if not names:
            raise ValueError("Filters.command() requires at least one command name")

        # Split each requested name into (bare_name, accepted_prefixes).
        parsed = []
        for raw in names:
            for p in prefixes:
                if raw.startswith(p):
                    parsed.append((raw[len(p):], p))
                    break
            else:
                parsed.append((raw, prefixes))

        first_prefixes = parsed[0][1]
        if all(accepted == first_prefixes for _, accepted in parsed):
            # Every alias accepts the same prefixes, so one criterion suffices.
            return CommandCriterion([bare for bare, _ in parsed], first_prefixes)

        # Mixed prefixes: OR together one criterion per alias.
        result: Criterion = CommandCriterion(parsed[0][0], parsed[0][1])
        for bare, accepted in parsed[1:]:
            result = result | CommandCriterion(bare, accepted)
        return result

    @staticmethod
    def text_contains(text: str, ignore_case: bool = True) -> Criterion:
        """Matches if the message text contains the given substring."""
        return TextMatchCriterion(text, exact=False, ignore_case=ignore_case)

    @staticmethod
    def text_is(text: str, ignore_case: bool = True) -> Criterion:
        """Matches if the message text exactly matches the given string."""
        return TextMatchCriterion(text, exact=True, ignore_case=ignore_case)

    @staticmethod
    def regex(pattern: str, flags: int = re.IGNORECASE) -> Criterion:
        """
        Matches message text against a regular expression.

        Args:
            pattern: The regex source.
            flags: ``re`` compile flags (case-insensitive by default, as before).
        """
        return RegexCriterion(pattern, flags)

    @staticmethod
    def chat(chat_id: Union[str, List[str]]) -> Criterion:
        """Matches events from specific chat JIDs."""
        return IdentityCriterion(chat_id, field='chat_id')

    @staticmethod
    def sender(user_id: Union[str, List[str]]) -> Criterion:
        """Matches events from specific user JIDs."""
        return IdentityCriterion(user_id, field='sender_id')

    @staticmethod
    def has_attribute(attribute: str, value: Any = True) -> Criterion:
        """Custom attribute matcher for advanced use cases."""
        return TypeCriterion(attribute, value)

    @staticmethod
    def create(func: Callable[[Any], Union[bool, Any]]) -> Criterion:
        """
        Creates a custom filter from a callable.

        ``func`` receives the event. It may return a plain value or an
        awaitable (an ``async def`` function, an async callable object, a
        ``functools.partial`` of one, ...). The final result is coerced to
        ``bool``.
        """
        import inspect

        class CustomCriterion(Criterion):
            async def matches(self, event: Any) -> bool:
                result = func(event)
                # Inspect the returned object rather than the callable:
                # iscoroutinefunction() misses partials, async __call__
                # objects and lambdas that return coroutines.
                if inspect.isawaitable(result):
                    result = await result
                return bool(result)

        return CustomCriterion()