Skip to content

Aiogram:Examples:InferenceServerProxy

추론 서버를 연결하는 프록시 Bot

Python Code

# -*- coding: utf-8 -*-

from argparse import Namespace
from asyncio import AbstractEventLoop, get_event_loop, set_event_loop
from dataclasses import dataclass
from io import BytesIO, StringIO
from os import environ
from re import Match
from sys import version_info
from typing import Callable, Final, Iterable, List, Optional, Union

if version_info >= (3, 11):
    from asyncio import Runner  # type: ignore[attr-defined]

from aiogram import Bot, Dispatcher, executor
from aiogram.types import Message, PhotoSize
from aiohttp import ClientSession, ClientTimeout
from type_serialize import deserialize
from uvloop import install as uvloop_install
from uvloop import new_event_loop as uvloop_new_event_loop

from answersteelscrapbot.logging.logging import logger

# Separator between the entries of the ADMIN_IDS environment variable.
ADMIN_ID_SEPARATOR: Final[str] = ":"
# User-Agent header sent with every inference request.
USER_AGENT: Final[str] = "YourBot"
# If the best item's score is below this value, the runner-up is shown as well.
SHOW_TOP2_THRESHOLD: Final[float] = 0.7

# Appended to the reply text of a successful inference.
LOGGING_SUFFIX: Final[str] = "\n\nPowered by Company"

# Default headers of the inference HTTP session: the request body is raw JPEG
# bytes and a JSON response is expected.
INFERENCE_JPEG_HEADERS: Final[dict] = {
    "Content-Type": "image/jpeg",
    "Accept": "application/json",
    "User-Agent": USER_AGENT,
}

# Naming prefixes:
# K_ = bot command name (used without the leading "/")
# E_ = event name (used in log lines)
K_HELP: Final[str] = "help"
E_PHOTO: Final[str] = "photo"

# Text replied to /help when the chat ID is listed as an admin.
# NOTE(review): currently identical to GUEST_USAGE.
ADMIN_USAGE: Final[str] = f"""
You can control me by sending these commands:

/{K_HELP} - Show this help message.
"""

# Text replied to /help for every other chat.
GUEST_USAGE: Final[str] = f"""
You can control me by sending these commands:

/{K_HELP} - Show this help message.
"""


def get_environment(key: str) -> str:
    """Return the value of the required environment variable *key*.

    An empty value is returned as-is; only a missing variable is an error.

    Raises:
        KeyError: if *key* is not present in the process environment.
    """
    if key not in environ:
        raise KeyError(f"The {key} environment variable is required")
    return environ[key]


def message_params_to_text(
    message: Message,
    regexp_command: Optional[Match] = None,
    arg_names: Optional[Iterable[str]] = None,
) -> str:
    """Format the identifying parameters of *message* for a log line.

    The result is ``id=<chat id>,from='<username>'``. If both *regexp_command*
    and *arg_names* are given, it is followed by ``,<name>=<value>`` for each
    name, where the i-th name (0-based) takes capture group ``i + 1`` of
    *regexp_command*.

    ``message.from_user`` is optional in the Telegram Bot API (for example, it
    is empty for channel posts). In that case the username is rendered as
    ``None`` instead of raising ``AttributeError``.
    """
    # Logging must never crash a handler, so tolerate a missing sender.
    sender = message.from_user
    username = sender.username if sender is not None else None

    parts = [f"id={message.chat.id},from='{username}'"]
    if regexp_command and arg_names:
        # Capture groups are 1-based; group 0 is the whole match.
        parts.extend(
            f",{arg_name}={regexp_command.group(group_index)}"
            for group_index, arg_name in enumerate(arg_names, start=1)
        )
    return "".join(parts)


def message_to_logging_text(
    event_name: str,
    message: Message,
    regexp_command: Optional[Match] = None,
    arg_names: Optional[Iterable[str]] = None,
) -> str:
    """Build the log line written when a handler receives *message*.

    The shape is ``On '<event_name>' command (<params>)``, where ``<params>``
    is produced by ``message_params_to_text`` from the same arguments.
    """
    details = message_params_to_text(
        message,
        regexp_command,
        arg_names,
    )
    return f"On '{event_name}' command ({details})"


def get_max_size_photo(photo: List[PhotoSize]) -> PhotoSize:
    """Return the largest variant from a Telegram list of photo sizes.

    Variants are ranked by ``file_size``. That field is optional in the Bot
    API, so a missing value counts as 0 instead of making the comparison
    raise ``TypeError``. Equal sizes, including the all-missing case, are
    broken by pixel area (``width * height``). Among exact ties the first
    variant wins, because ``max`` returns the first maximal item.

    Raises:
        IndexError: if *photo* is empty.
    """
    if not photo:
        raise IndexError("Empty photo")
    # Single pass; replaces the original max() + filter() double scan.
    return max(
        photo,
        key=lambda p: (p.file_size or 0, (p.width or 0) * (p.height or 0)),
    )


@dataclass
class InferenceItem:
    """One entry of the inference server's response.

    Instances are created by ``deserialize`` from the JSON list returned by
    the server and are ranked by :attr:`score`.
    """

    id: int  # item identifier as reported by the server
    name: str  # label shown to the user
    ratio: float  # ranking value; compared against SHOW_TOP2_THRESHOLD via score

    @property
    def score(self) -> float:
        """Ranking key; an alias of ``ratio``."""
        return self.ratio

    def __str__(self) -> str:
        return "'{}' (ratio={:.3f})".format(self.name, self.ratio)


class DefaultApp:
    """Telegram bot that proxies received photos to an inference server.

    ``/help`` replies with the admin or guest usage text. Every photo
    message is downloaded at its largest size and POSTed to the inference
    server. The ranked result is then sent back as a reply.
    """

    def __init__(
        self,
        telegram_api_token: str,
        inference_server_address: str,
        admins: List[int],
        *,
        timeout=8.0,
        use_uvloop=False,
        debug=False,
        verbose=0,
    ):
        """Create the aiogram bot and dispatcher and register the handlers.

        Args:
            telegram_api_token: token passed to ``aiogram.Bot``.
            inference_server_address: URL the photo bytes are POSTed to.
            admins: chat IDs that receive ``admin_usage`` for ``/help``.
            timeout: total timeout, in seconds, of one inference request.
            use_uvloop: run the polling loop on uvloop.
            debug: only stored and logged within this class.
            verbose: only stored and logged within this class.
        """
        self._inference_server_address = inference_server_address
        self._admins = admins

        self._timeout = timeout
        self._use_uvloop = use_uvloop
        self._debug = debug
        self._verbose = verbose

        self._bot = Bot(token=telegram_api_token)
        self._dispatcher = Dispatcher(self._bot)
        self._register()

        logger.info(f"Inference server address: {self._inference_server_address}")
        logger.info(f"Admins: {self._admins}")
        logger.info(f"uvloop flag: {self._use_uvloop}")
        logger.info(f"Debug flag: {self._debug}")
        logger.info(f"Verbose level: {self._verbose}")

    def _register(self):
        """Attach the command and photo handlers to the dispatcher."""
        assert self._dispatcher is not None
        d = self._dispatcher
        d.register_message_handler(self.on_help, commands=[K_HELP])
        # Reuse the shared event-name constant instead of repeating "photo".
        d.register_message_handler(self.on_photo, content_types=[E_PHOTO])

    @property
    def admins(self) -> List[int]:
        """Chat IDs treated as admins by ``on_help``."""
        return self._admins

    @property
    def admin_usage(self) -> str:
        """Usage text replied to admins."""
        return ADMIN_USAGE

    @property
    def guest_usage(self) -> str:
        """Usage text replied to everyone else."""
        return GUEST_USAGE

    def run(self) -> int:
        """Start long polling and block until it stops.

        Returns:
            0 on a normal stop or a keyboard interrupt, 1 if an exception
            escaped (it is logged first).
        """
        try:
            if not self._use_uvloop:
                self.start_polling()
            elif version_info >= (3, 11):
                with Runner(loop_factory=uvloop_new_event_loop) as runner:
                    loop = runner.get_loop()
                    # Runner only installs the loop as "current" when it
                    # creates the loop itself. With a custom loop_factory that
                    # is the factory's job, and uvloop's new_event_loop does
                    # not do it. Install it here so library code calling
                    # get_event_loop() uses this uvloop loop rather than a
                    # stray default loop.
                    set_event_loop(loop)
                    try:
                        self.start_polling(loop)
                    finally:
                        # Runner closes the loop on exit. Do not leave a
                        # closed loop installed as the current one.
                        set_event_loop(None)
            else:
                # Before 3.11 there is no Runner: install the uvloop policy
                # so get_event_loop() in start_polling creates a uvloop loop.
                uvloop_install()
                self.start_polling()
        except KeyboardInterrupt:
            logger.warning("An interrupt signal was detected")
            return 0
        except Exception as e:
            logger.exception(e)
            return 1
        else:
            return 0

    def start_polling(self, loop: Optional[AbstractEventLoop] = None) -> None:
        """Run aiogram's long-polling executor until it stops.

        Updates that are pending at startup are skipped.

        Args:
            loop: event loop to run on; ``get_event_loop()`` is used when
                omitted.
        """
        executor.start_polling(
            self._dispatcher,
            loop=loop if loop is not None else get_event_loop(),
            skip_updates=True,
            on_startup=self.on_open,
            on_shutdown=self.on_close,
        )

    async def inference_with_jpeg(
        self, jpeg: Union[bytes, bytearray, memoryview]
    ) -> List[InferenceItem]:
        """POST *jpeg* to the inference server and decode the JSON reply.

        A new HTTP session is opened per call, with ``self._timeout`` as the
        total request timeout.

        Raises:
            aiohttp.ClientResponseError: on a non-2xx HTTP status.
            Other aiohttp or timeout errors on connection failure or timeout.
        """
        timeout = ClientTimeout(total=self._timeout)
        async with ClientSession(headers=INFERENCE_JPEG_HEADERS, timeout=timeout) as s:
            async with s.post(self._inference_server_address, data=jpeg) as r:
                # Report HTTP errors explicitly instead of trying to decode an
                # error page (or an error object) as a list of items.
                r.raise_for_status()
                return deserialize(await r.json(), List[InferenceItem])

    async def inference_text_with_jpeg(
        self, jpeg: Union[bytes, bytearray, memoryview]
    ) -> str:
        """Run inference on *jpeg* and format the outcome as reply text.

        The top item is always shown. The runner-up is added when the top
        score is below ``SHOW_TOP2_THRESHOLD``. Failures are logged and turned
        into a fixed error message rather than raised.
        """
        try:
            inference_result = await self.inference_with_jpeg(jpeg)
        except Exception as e:
            # Catch Exception, not BaseException. asyncio.CancelledError (a
            # BaseException since 3.8) and KeyboardInterrupt must propagate so
            # that handler cancellation and shutdown keep working.
            logger.exception(e)
            return "Communication with the inference server failed"

        assert isinstance(inference_result, list)
        items = sorted(inference_result, key=lambda x: x.score, reverse=True)
        if not items:
            return "Not found inferred item"

        buffer = StringIO()
        top1 = items[0]
        buffer.write(f"Top 1 is {str(top1)}")
        if top1.score < SHOW_TOP2_THRESHOLD and len(items) >= 2:
            top2 = items[1]
            buffer.write(f"\nTop 2 is {str(top2)}")
        buffer.write(LOGGING_SUFFIX)
        return buffer.getvalue()

    # --------------
    # Event handlers
    # --------------

    async def on_open(self, dispatcher: Dispatcher) -> None:
        """Startup hook for the executor; currently a no-op."""
        assert self
        assert dispatcher

    async def on_close(self, dispatcher: Dispatcher) -> None:
        """Shutdown hook for the executor; currently a no-op."""
        assert self
        assert dispatcher

    async def on_help(self, message: Message):
        """Reply with the admin or guest usage text."""
        logger.info(message_to_logging_text(K_HELP, message))

        # NOTE(review): this compares the *chat* ID, not the sender's user ID.
        # In a group chat it checks the group's ID; confirm this is intended.
        if message.chat.id in self.admins:
            await message.reply(self.admin_usage)
        else:
            await message.reply(self.guest_usage)

    async def on_photo(self, message: Message):
        """Download the largest photo variant, run inference, and reply."""
        logger.info(message_to_logging_text(E_PHOTO, message))

        photo = get_max_size_photo(message.photo)
        buffer = BytesIO()
        await photo.download(destination_file=buffer)
        # getbuffer() exposes the downloaded bytes without an extra copy.
        result_text = await self.inference_text_with_jpeg(buffer.getbuffer())
        await message.reply(result_text)


def default_main(args: Namespace, printer: Callable[..., None] = print) -> int:
    """Build a ``DefaultApp`` from the environment and *args*, then run it.

    Required environment variables:
        TELEGRAM_API_TOKEN: Telegram Bot API token.
        INFERENCE_SERVER_ADDRESS: URL the photos are POSTed to.
        ADMIN_IDS: integer chat IDs separated by ``ADMIN_ID_SEPARATOR``;
            may be empty.

    Args:
        args: parsed options providing ``use_uvloop`` (bool), ``debug``
            (bool) and ``verbose`` (int). An optional ``timeout`` attribute,
            when present and not ``None``, overrides the app's default
            inference request timeout.
        printer: accepted for interface compatibility; not used here.

    Returns:
        The exit code returned by ``DefaultApp.run``.
    """
    assert args is not None
    assert printer is not None

    telegram_api_token = get_environment("TELEGRAM_API_TOKEN")
    inference_server_address = get_environment("INFERENCE_SERVER_ADDRESS")
    admin_ids = get_environment("ADMIN_IDS")
    assert isinstance(args.use_uvloop, bool)
    assert isinstance(args.debug, bool)
    assert isinstance(args.verbose, int)

    # DefaultApp supports a request timeout but it was never forwarded. Pass
    # it through only when the caller's parser defines one; otherwise keep
    # DefaultApp's own default.
    extra_options = {}
    timeout = getattr(args, "timeout", None)
    if timeout is not None:
        extra_options["timeout"] = float(timeout)

    # Skip empty and whitespace-only segments (e.g. "1: :2"); int() already
    # tolerates whitespace around a number but not a blank string.
    admins = [
        int(token) for token in admin_ids.split(ADMIN_ID_SEPARATOR) if token.strip()
    ]

    app = DefaultApp(
        telegram_api_token=telegram_api_token,
        inference_server_address=inference_server_address,
        admins=admins,
        use_uvloop=args.use_uvloop,
        debug=args.debug,
        verbose=args.verbose,
        **extra_options,
    )
    return app.run()

.env file

TELEGRAM_API_TOKEN=[YOUR API TOKEN]
INFERENCE_SERVER_ADDRESS=http://192.168.0.66:10000/api/v1/jobs/ai/analysis-steel-image/
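# Colon-separated Telegram chat IDs that get the admin /help text, e.g. 123456789:987654321 (may be left empty)
ADMIN_IDS=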

See also