# -----------------------------------------------------------
# Astra - WhatsApp Client Framework
# Licensed under the Apache License 2.0.
# -----------------------------------------------------------
"""
The EventDispatcher is the central brain of Astra's event system.
It routes incoming data from the bridge to appropriate handlers and waiters.
"""
import logging
import asyncio
import uuid
import time
from typing import Any, Callable, Dict, List, Optional, Union, Awaitable
from dataclasses import dataclass, field
from .filters import Criterion
from .context import EventContext
from ..network.serializers import DataTransformer
# Module-level logger; every diagnostic emitted by this module goes through the "Events" logger name.
logger = logging.getLogger("Events")
@dataclass
class EventSubscription:
    """
    One handler registration tracked by the event dispatcher.

    Instances are built by ``EventDispatcher.subscribe`` and kept, per event
    name, in descending ``priority`` order.
    """

    # UUID4 string handed back to whoever registered the handler.
    id: str
    # Event this handler listens for (e.g. "message", "reaction").
    event_name: str
    # Coroutine function invoked with the event payload / context.
    callback: Callable[..., Awaitable[None]]
    # Optional filter; when set, the handler runs only if ``criteria.passes`` is truthy.
    criteria: Union[Criterion, None] = None
    # Larger values are executed before smaller ones.
    priority: int = 0
    # When True, the handler only fires for message events whose parsed
    # command equals ``command_name``.
    is_command: bool = False
    command_name: Union[str, None] = None
class EventDispatcher:
    """
    Routes incoming events across the application.

    Handles priority-based execution, command parsing, and event interception
    via middlewares. Each dispatched event flows through: payload
    transformation -> stale/revoked filtering -> command parsing and
    contextualization -> middlewares -> event emitter -> one-shot waiters ->
    registered subscriptions (each run in its own background task).
    """

    # Parsed command names longer than this are treated as suspicious and dropped.
    _MAX_COMMAND_LENGTH = 32

    def __init__(self, client):
        """
        Args:
            client: The owning client. ``dispatch`` reads ``client.startup_at``
                for payloads carrying a ``timestamp`` and calls
                ``client.events.emit`` when the client has an ``events`` attribute.
        """
        self._client = client
        self._subscriptions: Dict[str, List[EventSubscription]] = {}
        self._waiters: Dict[str, List[Dict[str, Any]]] = {}
        self._middlewares: List[Callable[[str, Any], Awaitable[bool]]] = []
        self._on_event_received: Optional[Callable[[], None]] = None
        self._event_count: int = 0
        # Strong references to in-flight handler tasks. The event loop only
        # keeps weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._tasks: set = set()

    def add_middleware(self, middleware: Callable[[str, Any], Awaitable[bool]]):
        """
        Inject a middleware to pre-process or block events.

        The middleware is awaited as ``middleware(name, payload)``; returning a
        falsy value stops the event before emitters, waiters and handlers.
        """
        self._middlewares.append(middleware)

    async def wait_for(
        self,
        event_name: str,
        criteria: Optional[Criterion] = None,
        timeout: Optional[float] = None,
    ) -> Any:
        """
        Suspend execution until a matching event is dispatched.

        Args:
            event_name: Event to wait for (``dispatch`` normalizes ``"msg"``
                to ``"message"``).
            criteria: Optional filter; its ``passes`` coroutine must return a
                truthy value for an event to resolve the wait.
            timeout: Seconds to wait, or ``None`` to wait indefinitely.

        Returns:
            The dispatched payload that satisfied the wait.

        Raises:
            asyncio.TimeoutError: If ``timeout`` elapses first.
        """
        future = asyncio.get_running_loop().create_future()
        waiter = {"future": future, "criteria": criteria}
        self._waiters.setdefault(event_name, []).append(waiter)
        try:
            return await asyncio.wait_for(future, timeout=timeout)
        finally:
            # Always unregister, whether resolved, timed out or cancelled.
            remaining = [w for w in self._waiters.get(event_name, ()) if w["future"] is not future]
            if remaining:
                self._waiters[event_name] = remaining
            else:
                # Drop empty buckets so the dict does not accumulate dead keys.
                self._waiters.pop(event_name, None)

    def subscribe(
        self,
        event_name: str,
        callback: Callable[..., Awaitable[None]],
        criteria: Optional[Criterion] = None,
        priority: int = 0,
        is_command: bool = False,
        command_name: Optional[str] = None
    ) -> str:
        """
        Register a new event handler.

        Args:
            event_name: Event the handler listens to.
            callback: Coroutine function to run for matching events.
            criteria: Optional filter evaluated before the handler runs.
            priority: Higher values run first; equal priorities keep
                registration order (the sort is stable).
            is_command: Restrict the handler to message events whose parsed
                command equals ``command_name``.
            command_name: Command to match when ``is_command`` is True.

        Returns:
            The new subscription's id.
        """
        sub_id = str(uuid.uuid4())
        sub = EventSubscription(
            id=sub_id,
            event_name=event_name,
            callback=callback,
            criteria=criteria,
            priority=priority,
            is_command=is_command,
            command_name=command_name
        )
        self._subscriptions.setdefault(event_name, []).append(sub)
        # Higher priority runs first
        self._subscriptions[event_name].sort(key=lambda s: s.priority, reverse=True)
        logger.debug("Subscribed: %s (id=%s, priority=%s)", event_name, sub_id[:8], priority)
        return sub_id

    def set_event_callback(self, callback: Callable[[], None]):
        """Register a callback invoked on every event dispatch (used by SyncEngine)."""
        self._on_event_received = callback

    async def dispatch(self, name: str, payload: Any) -> None:
        """
        Dispatch an event through the engine.

        Args:
            name: Event name from the bridge; ``"msg"`` is normalized to ``"message"``.
            payload: Raw event data. Dict payloads of message/reaction events
                are converted with ``DataTransformer.to_message``.
        """
        correlation_id = str(uuid.uuid4())[:8]
        self._event_count += 1
        logger.debug("Dispatcher received event '%s' (ID: %s)", name, correlation_id)

        # Notify SyncEngine that an event just arrived (resets stall timer).
        # Best-effort: a failing callback must not block dispatch, but it is logged.
        if self._on_event_received:
            try:
                self._on_event_received()
            except Exception:
                logger.debug("[%s] Event-received callback failed", correlation_id, exc_info=True)

        # 1. Transform raw payload to Astra models
        is_msg = name in {"message", "msg"}
        if is_msg:
            name = "message"
        is_reaction = name == "reaction"
        if (is_msg or is_reaction) and isinstance(payload, dict):
            payload = DataTransformer.to_message(payload, self._client)

        # 1b. Filter stale events (older than startup); a missing/None
        # timestamp cannot be compared and is treated as "not stale".
        timestamp = getattr(payload, "timestamp", None)
        if timestamp is not None and timestamp < self._client.startup_at:
            logger.debug(
                "[%s] Ignoring stale event '%s' (ts: %s, startup: %s)",
                correlation_id, name, timestamp, round(self._client.startup_at),
            )
            return

        # 1c. Ignore revoked (deleted) messages to prevent recursive loops
        if is_msg and getattr(payload, "type", None) == "revoked":
            logger.debug("[%s] Ignoring revoked message.", correlation_id)
            return

        # 2. Contextualize (for Message/Reaction events)
        prefix, cmd, args = None, None, []
        if is_msg:
            text = (getattr(payload, "text", None) or getattr(payload, "body", None) or "").strip()
            logger.debug("[%s] Captured text: '%s'", correlation_id, text)
            prefix, cmd, args = self._parse_command(text, correlation_id)
            logger.debug("[%s] Parsed Command: %s, Prefix: %s", correlation_id, cmd, prefix)
            # Upgrade payload to EventContext
            payload = EventContext(self._client, payload, cmd, args, prefix)
        elif is_reaction:
            # Wrap reaction in context as well (for consistency)
            payload = EventContext(self._client, payload)

        # 3. Middleware Execution
        for middleware in self._middlewares:
            try:
                allowed = await middleware(name, payload)
            except Exception as e:
                # Fail open: a broken middleware must not silently drop every event.
                logger.error("[%s] Middleware error: %s", correlation_id, e, exc_info=True)
                continue
            if not allowed:
                logger.debug("[%s] Event '%s' intercepted by middleware.", correlation_id, name)
                return

        # 4. Notify Listeners (via Client EventEmitter)
        if hasattr(self._client, "events"):
            logger.debug("[%s] Emitting '%s' to event emitter", correlation_id, name)
            self._client.events.emit(name, payload)

        # 5. Handle Waiters
        await self._process_waiters(name, payload, cmd, args, prefix, correlation_id)

        # 6. Run Registered Subscriptions. Iterate a snapshot: subscribe() may
        # append to / re-sort the live list while we are suspended in a criteria check.
        for sub in list(self._subscriptions.get(name, ())):
            # Command Check
            if sub.is_command and (not is_msg or cmd != sub.command_name):
                continue
            # Criteria Check
            if sub.criteria and not await self._criteria_passes(sub.criteria, payload, correlation_id):
                continue
            # Run the handler in a new (strongly referenced) task
            self._spawn(self._execute_safe(sub, payload, cmd, args, prefix, correlation_id))

    # --- Private Execution Logic ---

    def _parse_command(self, text: str, correlation_id: str):
        """
        Split message text into ``(prefix, command, args)``.

        A leading ``/``, ``.`` or ``!`` is taken as the prefix; without one the
        first word is still treated as the command (prefix-less support).
        Commands are lower-cased; names longer than ``_MAX_COMMAND_LENGTH`` are
        rejected and yield ``(prefix, None, [])``.
        """
        prefix, cmd, args = None, None, []
        if not text:
            return prefix, cmd, args
        if text[0] in "/.!":
            prefix = text[0]
            text = text[1:]
        parts = text.split()
        if parts:
            cmd = parts[0].lower()
            args = parts[1:]
        if cmd and len(cmd) > self._MAX_COMMAND_LENGTH:
            logger.warning(
                "[%s] Ignoring suspicious long command name: %s...", correlation_id, cmd[:50]
            )
            cmd = None
            args = []
        return prefix, cmd, args

    def _as_context(self, payload, cmd, args, prefix):
        """Return *payload* as an EventContext, reusing it when dispatch already wrapped it."""
        if isinstance(payload, EventContext):
            return payload
        return EventContext(self._client, payload, cmd, args, prefix)

    async def _criteria_passes(self, criteria, payload, correlation_id) -> bool:
        """Evaluate *criteria* against *payload*; a raising filter counts as a non-match."""
        try:
            return bool(await criteria.passes(payload))
        except Exception as e:
            logger.error("[%s] Criteria check failed: %s", correlation_id, e, exc_info=True)
            return False

    def _spawn(self, coro) -> None:
        """Schedule *coro* as a task and hold a reference to it until it completes."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def _process_waiters(self, name, payload, cmd, args, prefix, correlation_id):
        """Resolve every pending ``wait_for`` future for *name* whose criteria match."""
        waiters = self._waiters.get(name)
        if not waiters:
            return
        # Iterate a copy: wait_for() removes its own entry when it finishes.
        for waiter in list(waiters):
            future = waiter["future"]
            if future.done():
                continue
            criteria = waiter["criteria"]
            if criteria:
                ctx = self._as_context(payload, cmd, args, prefix) if name == "message" else payload
                if not await self._criteria_passes(criteria, ctx, correlation_id):
                    continue
                # The waiter may have timed out or been cancelled while the
                # criteria check was suspended; set_result would then raise
                # InvalidStateError and abort the rest of dispatch.
                if future.done():
                    continue
            future.set_result(payload)

    async def _execute_safe(self, sub, payload, cmd, args, prefix, cid):
        """Run a subscription's callback, logging its duration or any exception it raises."""
        # Callables such as functools.partial objects have no __name__.
        label = getattr(sub.callback, "__name__", repr(sub.callback))
        start = time.perf_counter()
        try:
            if sub.is_command:
                await sub.callback(self._as_context(payload, cmd, args, prefix))
            else:
                await sub.callback(payload)
        except Exception as e:
            logger.error("[%s] Handler '%s' failed: %s", cid, label, e, exc_info=True)
            return
        ms = (time.perf_counter() - start) * 1000
        logger.debug("[%s] '%s' done in %.2fms", cid, label, ms)