Source code for pyrobale.filters

import re
from typing import Callable, List, Optional, Union, TYPE_CHECKING
import asyncio

if TYPE_CHECKING:
    from ..objects.user import User
    from ..client import Client

[docs] class Filter: def __init__(self, check_func: Callable, inv: bool = False) -> None:
[docs] self.state = (check_func, inv)
[docs] self.lst = [self.state]
[docs] def __invert__(self): self.state = (self.state[0], not self.state[1]) self.lst = [self.state] return self
[docs] def __and__(self, other): if not isinstance(other, Filter): raise TypeError(f"Cannot combine Filter with {type(other)}") def combined_check(event, client=None, *args): if self.state[1]: res1 = not self.state[0](event, client, *args) else: res1 = self.state[0](event, client, *args) if other.state[1]: res2 = not other.state[0](event, client, *args) else: res2 = other.state[0](event, client, *args) return res1 and res2 new_filter = Filter(self.state[0], self.state[1]) new_filter.lst = self.lst + other.lst new_filter.state = (combined_check, False) return new_filter
[docs] def __call__(self, event, client=None, *args): if self.state[1]: return not self.state[0](event, client, *args) return self.state[0](event, client, *args)
[docs] def equals(expected_text: str): """ Check if the event text or caption or callbackQuery data is equal to the expected text. Args: expected_text (str): The expected text to compare with. Returns: Callable: A function that checks if the event text or caption or callbackQuery data is equal to the expected text. """ async def check(event, *args): try: return getattr(event, "text", None) == expected_text or getattr(event, "caption", None) == expected_text or getattr(event, "data", None) == expected_text except: return False return Filter(check)
[docs] def startswith(expected_text: str): """ Check if the event text or caption or callbackQuery data is started with to the expected text. Args: expected_text (str): The expected text to compare with. Returns: Callable: A function that checks if the event text or caption or callbackQuery data is started with to the expected text. """ async def check(event, *args): try: return getattr(event, "text", "").startswith(expected_text) or getattr(event, "caption", "").startswith(expected_text) or getattr(event, "data", "").startswith(expected_text) except: return False return Filter(check)
[docs] def regex(pattern: str): """ checks the event text or caption with given pattern using regex Args: pattern (str): The pattern to check with text Returns: Callable: A function that checks if the event text or caption is match with given pattern """ async def check(event, *args): try: return re.search(pattern, getattr(event, "text", "")) or re.search(pattern, getattr(event, "caption", "")) except: return False return Filter(check)
[docs] def from_users(allowed_users: Union[List[Union["User", int, str]], int, str]): """ Check if the event text or caption or callbackQuery sender is in allowed user. Args: allowed_users (List[Union["User", int]]): Allowed users to use this handler. Returns: Callable: A function that checks if the event text or caption or callbackQuery sender is in allowed user. """ if type(allowed_users) in [str, int]: try: allowed_users = [int(allowed_users)] except: raise ValueError("Chat IDs can only be digits") async def check(event, *args): try: event_user = getattr(event, "user", None) event_user_id = getattr(event_user, "id") if event_user_id in allowed_users: return True except: return False return Filter(check)
[docs] def is_joined(chat_ids: Union[List[Union["User", int, str]], int, str]): """ Checks if the event User is joined in specified chats. Args: allowed_users (List[Union["User", int]]): Allowed users to use this handler. Returns: Callable: A function that checks if the event User is joined in specified chats """ if type(chat_ids) in [str, int]: try: chat_ids = [int(chat_ids)] except: raise ValueError("Chat IDs can only be digits") async def check(event, client: 'Client', *args): try: event_user = getattr(event, "user", None) event_user_id = getattr(event_user, "id") for chat in chat_ids: data = await client.make_post( client.requests_base + "/getChatMember", data={"chat_id": chat, "user_id": event_user_id}, ) joined = data.get("result", {}).get("status") in ["member", "creator", "administrator"] if not joined: return False return True except: return False return Filter(check)
[docs] def at_state(state: Optional[str] = None): """ Checks if the event User is at specified state. Args: state (Optional[str]): state condition Returns: Callable: A function that checks if the event User is at specified state. """ async def check(event, client: 'Client', *args): try: event_user = getattr(event, "user", None) event_user_id = getattr(event_user, "id") return client.state_machine.get_state(event_user_id) == state except: return False return Filter(check)
def _private(): """ checks if the event is happening in a private chat """ async def check(event, *args): try: chat = getattr(event, "chat") return getattr(chat, "private") except: return False return Filter(check) def _group(): """ checks if the event is happening in a group chat """ async def check(event, *args): try: chat = getattr(event, "chat") return chat.type == chat.type.GROUP except: return False return Filter(check) def _reply(): """ Checks if the event is a reply to a message. """ async def check(event, *args): try: return getattr(event, "reply_to_message") is not None except: return False return Filter(check) def _forward(): """ Checks if the event is a forwarded message. """ async def check(event, *args): try: return getattr(event, "forward_from") is not None except: return False return Filter(check) def _gif(): """ Checks if the event has a gif media. """ async def check(event, *args): try: return getattr(event, "animation") is not None except: return False return Filter(check) def _digit(): """ Check if the event text or caption or callbackQuery data is digit. Returns: Callable: A function that checks if the event text or caption or callbackQuery data is digit. """ async def check(event, *args): try: return getattr(event, "text", "").isdigit() or getattr(event, "caption", "").isdigit() or getattr(event, "data", "").isdigit() except: return False return Filter(check) def _channel(): """ checks if the event is happening in a channel """ async def check(event, args): try: chat = getattr(event, "chat") return getattr(chat, "channel") except: return False return Filter(check)
[docs] def func(function: Callable): async def check(event, *args): try: return function(event) except: return False return Filter(check)
def _text(): async def check(event, *args): return hasattr(event, "text") and bool(getattr(event, "text")) return Filter(check) def _photo(): async def check(event, *args): return hasattr(event, "photo") and bool(getattr(event, "photo")) return Filter(check) def _video(): async def check(event, *args): return hasattr(event, "video") and bool(getattr(event, "video")) return Filter(check) def _audio(): async def check(event, *args): return hasattr(event, "audio") and bool(getattr(event, "audio")) return Filter(check) def _voice(): async def check(event, *args): return hasattr(event, "voice") and bool(getattr(event, "voice")) return Filter(check) def _contact(): async def check(event, *args): return hasattr(event, "contact") and bool(getattr(event, "contact")) return Filter(check) def _location(): async def check(event, *args): return hasattr(event, "location") and bool(getattr(event, "location")) return Filter(check) private = pv = _private()
[docs] channel = _channel()
[docs] group = _group()
[docs] digit = _digit()
[docs] reply = _reply()
[docs] forward = _forward()
[docs] gif = _gif()
[docs] text = _text()
[docs] photo = _photo()
[docs] video = _video()
[docs] audio = _audio()
[docs] voice = _voice()
[docs] contact = _contact()
[docs] location = _location()