Aiogram:Examples:InferenceServerProxy
추론 서버를 연결하는 프록시 Bot
Python Code
# -*- coding: utf-8 -*-
from argparse import Namespace
from asyncio import AbstractEventLoop, get_event_loop
from dataclasses import dataclass
from io import BytesIO, StringIO
from os import environ
from re import Match
from sys import version_info
from typing import Callable, Final, Iterable, List, Optional, Union
if version_info >= (3, 11):
from asyncio import Runner # type: ignore[attr-defined]
from aiogram import Bot, Dispatcher, executor
from aiogram.types import Message, PhotoSize
from aiohttp import ClientSession, ClientTimeout
from type_serialize import deserialize
from uvloop import install as uvloop_install
from uvloop import new_event_loop as uvloop_new_event_loop
from answersteelscrapbot.logging.logging import logger
# Separator between the integer IDs listed in the ADMIN_IDS environment variable.
ADMIN_ID_SEPARATOR: Final[str] = ":"
# User-Agent sent with every request to the inference server.
USER_AGENT: Final[str] = "YourBot"
# When the best prediction's score is below this value, the runner-up is shown too.
SHOW_TOP2_THRESHOLD: Final[float] = 0.7
# Appended to each successful inference reply (in the visible code it is not
# used for log lines, despite its name).
LOGGING_SUFFIX: Final[str] = "\n\nPowered by Company"
# Headers for POSTing raw JPEG bytes and asking for a JSON answer.
INFERENCE_JPEG_HEADERS: Final = {
    "Content-Type": "image/jpeg",
    "Accept": "application/json",
    "User-Agent": USER_AGENT,
}
# K = Command
# E = Event
K_HELP: Final[str] = "help"
E_PHOTO: Final[str] = "photo"
# Reply text of /help for chats listed as admins.
ADMIN_USAGE: Final[str] = f"""
You can control me by sending these commands:
/{K_HELP} - Show this help message.
"""
# Reply text of /help for everyone else.
GUEST_USAGE: Final[str] = f"""
You can control me by sending these commands:
/{K_HELP} - Show this help message.
"""
def get_environment(key: str) -> str:
    """Return the value of the environment variable *key*.

    An empty value is returned as-is; only a missing variable is an error.

    Raises:
        KeyError: if *key* is not present in the environment.
    """
    if key not in environ:
        raise KeyError(f"The {key} environment variable is required")
    return environ[key]
def message_params_to_text(
    message: Message,
    regexp_command: Optional[Match] = None,
    arg_names: Optional[Iterable[str]] = None,
) -> str:
    """Render the identifying parameters of *message* as a compact string.

    The result has the form ``id=<chat id>,from='<username>'``. When both
    *regexp_command* and *arg_names* are given, ``,<name>=<value>`` is
    appended for each name, where the value is the matching (1-based)
    capture group of *regexp_command*.

    Args:
        message: The incoming Telegram message.
        regexp_command: Match object of a regexp command filter, if any.
        arg_names: Names for the capture groups of *regexp_command*, in order.

    Returns:
        The formatted parameter string (used for logging).

    Raises:
        IndexError: if *arg_names* names more groups than the match has.
    """
    # ``from_user`` is optional in the Bot API (e.g. channel posts). Report
    # such a message like a sender without a username instead of crashing.
    sender = message.from_user
    username = sender.username if sender is not None else None
    parts = [f"id={message.chat.id},from='{username}'"]
    if regexp_command and arg_names:
        parts.extend(
            f"{arg_name}={regexp_command.group(group_index)}"
            for group_index, arg_name in enumerate(arg_names, start=1)
        )
    return ",".join(parts)
def message_to_logging_text(
    event_name: str,
    message: Message,
    regexp_command: Optional[Match] = None,
    arg_names: Optional[Iterable[str]] = None,
) -> str:
    """Build the log line written when a handler receives *message*.

    Args:
        event_name: Command or event name, shown in quotes in the log line.
        message: The message that triggered the handler.
        regexp_command: Optional regexp match whose groups are logged.
        arg_names: Names for the logged capture groups, in order.

    Returns:
        A string of the form ``On '<event_name>' command (<params>)``, where
        ``<params>`` is produced by ``message_params_to_text``.
    """
    details = message_params_to_text(
        message,
        regexp_command=regexp_command,
        arg_names=arg_names,
    )
    return f"On '{event_name}' command ({details})"
def get_max_size_photo(photo: List[PhotoSize]) -> PhotoSize:
    """Return the largest rendition (by ``file_size``) of a Telegram photo.

    Telegram delivers a photo as a list of ``PhotoSize`` renditions. The first
    entry with the greatest ``file_size`` is returned. ``file_size`` is
    optional in the Bot API, so a missing (``None``) size ranks as 0 instead
    of breaking the comparison.

    Args:
        photo: The renditions of one photo (``Message.photo``).

    Returns:
        The selected ``PhotoSize``.

    Raises:
        IndexError: if *photo* is empty.
    """
    if not photo:
        raise IndexError("Empty photo")
    # ``max`` scans once and keeps the first of equally large items, which
    # matches the previous "first item with the maximum size" selection.
    return max(photo, key=lambda size: size.file_size or 0)
@dataclass
class InferenceItem:
    """A single prediction as decoded from the inference server's JSON reply."""

    # Identifier reported by the server (meaning not visible here).
    id: int
    # Label shown to the user in replies.
    name: str
    # Value used to rank predictions (see ``score``).
    ratio: float

    def __str__(self) -> str:
        return f"'{self.name}' (ratio={self.score:.3f})"

    @property
    def score(self) -> float:
        """Ranking key of this prediction; an alias of ``ratio``."""
        return self.ratio
class DefaultApp:
    """Telegram bot that forwards received photos to an inference server.

    The largest rendition of each incoming photo is downloaded, POSTed to the
    inference server with JPEG headers, and the ranked predictions are sent
    back as a reply. ``/help`` answers with the admin or guest usage text.
    """

    def __init__(
        self,
        telegram_api_token: str,
        inference_server_address: str,
        admins: List[int],
        *,
        timeout: float = 8.0,
        use_uvloop: bool = False,
        debug: bool = False,
        verbose: int = 0,
    ):
        """Create the bot and dispatcher and register the handlers.

        Args:
            telegram_api_token: Bot API token passed to ``aiogram.Bot``.
            inference_server_address: URL that receives the JPEG POST requests.
            admins: Chat IDs that get the admin usage text from ``/help``.
            timeout: Total timeout, in seconds, of one inference request.
            use_uvloop: Run polling on a uvloop event loop.
            debug: Only stored and logged by this class.
            verbose: Only stored and logged by this class.
        """
        self._inference_server_address = inference_server_address
        self._admins = admins
        self._timeout = timeout
        self._use_uvloop = use_uvloop
        self._debug = debug
        self._verbose = verbose
        self._bot = Bot(token=telegram_api_token)
        self._dispatcher = Dispatcher(self._bot)
        self._register()
        logger.info(f"Inference server address: {self._inference_server_address}")
        logger.info(f"Admins: {self._admins}")
        logger.info(f"uvloop flag: {self._use_uvloop}")
        logger.info(f"Debug flag: {self._debug}")
        logger.info(f"Verbose level: {self._verbose}")

    def _register(self):
        """Attach the message handlers to the dispatcher."""
        assert self._dispatcher is not None
        d = self._dispatcher
        d.register_message_handler(self.on_help, commands=[K_HELP])
        d.register_message_handler(self.on_photo, content_types=[E_PHOTO])

    @property
    def admins(self) -> List[int]:
        return self._admins

    @property
    def admin_usage(self) -> str:
        return ADMIN_USAGE

    @property
    def guest_usage(self) -> str:
        return GUEST_USAGE

    def run(self) -> int:
        """Poll Telegram for updates until the bot stops.

        Returns:
            0 on normal shutdown or keyboard interrupt, 1 if an unexpected
            exception escaped (it is logged).
        """
        try:
            if not self._use_uvloop:
                self.start_polling()
            elif version_info >= (3, 11):
                # aiogram 2.x's executor runs *and closes* the loop it is
                # given. Wrapping that loop in ``asyncio.Runner`` made the
                # runner's own shutdown touch the already-closed loop and
                # raise RuntimeError on every normal exit. Hand the executor a
                # plain uvloop loop instead (no deprecated global policy).
                self.start_polling(uvloop_new_event_loop())
            else:
                uvloop_install()
                self.start_polling()
        except KeyboardInterrupt:
            logger.warning("An interrupt signal was detected")
            return 0
        except Exception as e:
            logger.exception(e)
            return 1
        else:
            return 0

    def start_polling(self, loop: Optional[AbstractEventLoop] = None) -> None:
        """Start aiogram long polling on *loop*.

        When *loop* is ``None`` the current event loop from
        ``asyncio.get_event_loop()`` is used. Pending updates are skipped.
        """
        executor.start_polling(
            self._dispatcher,
            loop=loop if loop else get_event_loop(),
            skip_updates=True,
            on_startup=self.on_open,
            on_shutdown=self.on_close,
        )

    async def inference_with_jpeg(
        self, jpeg: Union[bytes, bytearray, memoryview]
    ) -> List[InferenceItem]:
        """POST *jpeg* to the inference server and decode its JSON reply.

        Args:
            jpeg: Raw JPEG image bytes.

        Returns:
            The predictions deserialized as ``InferenceItem`` objects.

        Raises:
            aiohttp.ClientResponseError: if the server replies with an HTTP
                error status. Network, timeout and deserialization errors
                propagate as raised by aiohttp / ``deserialize``.
        """
        timeout = ClientTimeout(total=self._timeout)
        async with ClientSession(headers=INFERENCE_JPEG_HEADERS, timeout=timeout) as session:
            async with session.post(self._inference_server_address, data=jpeg) as response:
                # Report HTTP errors by status instead of a confusing JSON
                # decoding failure on an error page.
                response.raise_for_status()
                return deserialize(await response.json(), List[InferenceItem])

    async def inference_text_with_jpeg(
        self, jpeg: Union[bytes, bytearray, memoryview]
    ) -> str:
        """Run inference on *jpeg* and format the outcome as reply text.

        Returns:
            The top-1 prediction, plus the top-2 one when the top-1 score is
            below ``SHOW_TOP2_THRESHOLD``, followed by ``LOGGING_SUFFIX``.
            If the request fails or yields no items, a short error message is
            returned instead.
        """
        try:
            inference_result = await self.inference_with_jpeg(jpeg)
        except Exception as e:
            # Only ``Exception``: cancellation (``asyncio.CancelledError``),
            # ``KeyboardInterrupt`` and ``SystemExit`` derive from
            # ``BaseException`` and must propagate so shutdown still works.
            logger.exception(e)
            return "Communication with the inference server failed"

        assert isinstance(inference_result, list)
        items = sorted(inference_result, key=lambda x: x.score, reverse=True)
        if not items:
            return "Not found inferred item"

        buffer = StringIO()
        top1 = items[0]
        buffer.write(f"Top 1 is {top1}")
        if top1.score < SHOW_TOP2_THRESHOLD and len(items) >= 2:
            buffer.write(f"\nTop 2 is {items[1]}")
        buffer.write(LOGGING_SUFFIX)
        return buffer.getvalue()

    # --------------
    # Event handlers
    # --------------

    async def on_open(self, dispatcher: Dispatcher) -> None:
        """Startup hook passed to the executor; currently does nothing."""
        assert self
        assert dispatcher

    async def on_close(self, dispatcher: Dispatcher) -> None:
        """Shutdown hook passed to the executor; currently does nothing."""
        assert self
        assert dispatcher

    async def on_help(self, message: Message):
        """Reply with the admin usage text if the chat ID is in ``admins``,
        otherwise with the guest usage text."""
        logger.info(message_to_logging_text(K_HELP, message))
        if message.chat.id in self.admins:
            await message.reply(self.admin_usage)
        else:
            await message.reply(self.guest_usage)

    async def on_photo(self, message: Message):
        """Download the largest photo rendition, run inference, reply with the result."""
        logger.info(message_to_logging_text(E_PHOTO, message))
        photo = get_max_size_photo(message.photo)
        buffer = BytesIO()
        await photo.download(destination_file=buffer)
        # ``getbuffer`` hands the downloaded bytes over without an extra copy.
        result_text = await self.inference_text_with_jpeg(buffer.getbuffer())
        await message.reply(result_text)
def default_main(args: Namespace, printer: Callable[..., None] = print) -> int:
    """Create a ``DefaultApp`` from environment variables and CLI flags, then run it.

    Required environment variables are ``TELEGRAM_API_TOKEN``,
    ``INFERENCE_SERVER_ADDRESS`` and ``ADMIN_IDS``. ``ADMIN_IDS`` holds
    integer IDs joined by ``ADMIN_ID_SEPARATOR``; empty entries are skipped.

    Args:
        args: Parsed command-line namespace providing ``use_uvloop`` (bool),
            ``debug`` (bool) and ``verbose`` (int).
        printer: Output callable; accepted but not used by this function.

    Returns:
        The exit code returned by ``DefaultApp.run``.

    Raises:
        KeyError: if a required environment variable is missing.
        ValueError: if an ``ADMIN_IDS`` entry is not an integer.
    """
    assert args is not None
    assert printer is not None

    token = get_environment("TELEGRAM_API_TOKEN")
    server_address = get_environment("INFERENCE_SERVER_ADDRESS")
    raw_admin_ids = get_environment("ADMIN_IDS")

    assert isinstance(args.use_uvloop, bool)
    assert isinstance(args.debug, bool)
    assert isinstance(args.verbose, int)

    # ``filter(None, ...)`` drops the empty strings produced by an empty
    # ADMIN_IDS value or by leading/trailing/doubled separators.
    admin_ids = list(map(int, filter(None, raw_admin_ids.split(ADMIN_ID_SEPARATOR))))

    application = DefaultApp(
        telegram_api_token=token,
        inference_server_address=server_address,
        admins=admin_ids,
        use_uvloop=args.use_uvloop,
        debug=args.debug,
        verbose=args.verbose,
    )
    return application.run()
.env file
TELEGRAM_API_TOKEN=[YOUR API TOKEN]
INFERENCE_SERVER_ADDRESS=http://192.168.0.66:10000/api/v1/jobs/ai/analysis-steel-image/
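# Colon-separated integer chat IDs that get the admin /help text, e.g. ADMIN_IDS=11111111:22222222 (may be left empty)
ADMIN_IDS=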