Skip to content

Asyncio: TaskManager

Python의 asyncio 패키지를 사용하여 Task 관리자를 구현합니다.

task_manager.py Code

# -*- coding: utf-8 -*-

from abc import ABCMeta, abstractmethod
from asyncio import Task, create_task, gather, iscoroutine, shield, wait_for
from typing import Any, Dict, Optional

from overrides import overrides


async def _timeout_wrapper(coro, timeout: Optional[float] = None) -> Any:
    """Await ``coro``, optionally bounded by ``timeout`` seconds.

    :param coro: Awaitable to run. This wrapper takes ownership of it.
    :param timeout: Maximum number of seconds to wait, or ``None`` to wait
        without a limit. Must be a positive number when given.
    :return: Whatever ``coro`` returns.
    :raises ValueError: If ``timeout`` is given but is not positive
        (including NaN).
    :raises asyncio.TimeoutError: If ``coro`` does not finish within
        ``timeout`` seconds (raised by ``asyncio.wait_for``).
    """
    if timeout is None:
        return await coro

    # ``not timeout > 0`` (rather than ``timeout <= 0``) also rejects NaN.
    # An explicit raise is used instead of ``assert`` so the check still
    # runs under ``python -O``.
    if not timeout > 0:
        # The coroutine will never be awaited; close it so Python does not
        # emit a "coroutine ... was never awaited" RuntimeWarning.
        if iscoroutine(coro):
            coro.close()
        raise ValueError(f"timeout must be positive: {timeout}")

    return await wait_for(coro, timeout=timeout)


class TaskManagerInterface(metaclass=ABCMeta):
    """Abstract contract for managers that react to task completion."""

    @abstractmethod
    def on_done(self, task: Task) -> None:
        """React to ``task`` having finished.

        Concrete subclasses must override this. Task managers register it as
        a done-callback on the tasks they create.
        """
        raise NotImplementedError()


class TaskManager(TaskManagerInterface, Dict[str, Task]):
    """Dictionary of named asyncio tasks.

    Each task is stored under the key it was created with, and that key is
    also used as the task's name (``Task.get_name()``). In this base class
    ``on_done`` does nothing, so finished tasks stay in the mapping (with
    their results/exceptions) until :meth:`close` is called.
    """

    async def close(self) -> None:
        """Cancel every managed task, wait for all of them, then clear.

        Outcomes of the managed tasks (results, exceptions, and their own
        ``CancelledError``) are discarded. Cancelling ``close()`` itself
        still propagates to the caller.
        """
        # Take a snapshot first. Subclasses such as AutoRemoveTaskManager
        # delete entries from their done-callbacks, which run while we await.
        # Iterating the live dict would then raise
        # "RuntimeError: dictionary changed size during iteration".
        tasks = list(self.values())
        for task in tasks:
            task.cancel()
        if tasks:
            # return_exceptions=True collects each task's outcome instead of
            # raising it. Unlike a bare ``except:``, it does not swallow a
            # cancellation aimed at close() itself.
            await gather(*tasks, return_exceptions=True)
        self.clear()

    async def join(self, key: str, timeout: Optional[float] = None) -> Any:
        """Wait for the task stored under ``key`` and return its result.

        :param key: Key the task was created with.
        :param timeout: Maximum seconds to wait, or ``None`` for no limit.
            Must be a positive number when given.
        :return: The task's result. If the task failed or was cancelled,
            its exception (or ``CancelledError``) is raised instead.
        :raises KeyError: If no task is stored under ``key``.
        :raises ValueError: If ``timeout`` is given but is not positive.
        :raises asyncio.TimeoutError: If the task does not finish within
            ``timeout``. The task itself keeps running because it is
            wrapped in ``shield()``.
        """
        task = self[key]
        if timeout is None:
            return await task
        # Explicit check (not ``assert``) so it survives ``python -O``;
        # ``not timeout > 0`` also rejects NaN.
        if not timeout > 0:
            raise ValueError(f"timeout must be positive: {timeout}")
        return await wait_for(shield(task), timeout=timeout)

    def create_task(self, key: str, coro, timeout: Optional[float] = None) -> Task:
        """Schedule ``coro`` as a task named ``key`` and store it under ``key``.

        :param key: Unique key. It is also used as the task name.
        :param coro: Coroutine (or other awaitable) to run.
        :param timeout: Optional timeout in seconds for the task itself.
            When it expires, the task finishes with ``asyncio.TimeoutError``.
        :return: The newly created task.
        :raises KeyError: If ``key`` is already in use.
        """
        if key in self:
            raise KeyError(f"Already exists key: {key}")
        task = create_task(_timeout_wrapper(coro, timeout), name=key)
        task.add_done_callback(self.on_done)
        self[key] = task
        return task

    @overrides
    def on_done(self, task: Task) -> None:
        """Done-callback hook; the base manager keeps finished tasks."""
        assert isinstance(task, Task)


class AutoRemoveTaskManager(TaskManager):
    """TaskManager that drops each task from the mapping once it finishes."""

    @overrides
    def on_done(self, task: Task) -> None:
        """Remove ``task`` from the mapping when it completes.

        The task is looked up by its name, which ``create_task`` sets to the
        task's key.
        """
        assert isinstance(task, Task)
        key = task.get_name()
        # Only remove the entry if it still refers to *this* task. The entry
        # may already have been deleted (e.g. ``del manager[key]``) or replaced
        # by a newer task under the same key before this callback ran. An
        # unconditional ``del`` would raise KeyError inside the event loop's
        # callback, or evict the wrong task.
        if self.get(key) is task:
            del self[key]

TestCase

# -*- coding: utf-8 -*-

from asyncio import CancelledError, TimeoutError
from asyncio import sleep as asyncio_sleep
from unittest import IsolatedAsyncioTestCase, main

from recc.aio.task_manager import AutoRemoveTaskManager, TaskManager


class TaskManagerTestCase(IsolatedAsyncioTestCase):
    """Tests for TaskManager, whose finished tasks stay stored until close()."""

    async def asyncSetUp(self):
        # Each test gets a fresh, empty manager.
        self.tm = TaskManager()

    async def asyncTearDown(self):
        # close() cancels and awaits the managed tasks and must leave the
        # manager empty.
        await self.tm.close()
        self.assertEqual(0, len(self.tm))

    async def test_default(self):
        """A successful task can be joined and remains stored afterwards."""
        async def _func():
            return 999

        key1 = "key1"
        self.assertEqual(0, len(self.tm))
        self.tm.create_task(key1, _func())
        self.assertEqual(1, len(self.tm))

        self.assertEqual(999, await self.tm.join(key1))
        self.assertTrue(self.tm[key1].done())
        self.assertFalse(self.tm[key1].cancelled())
        self.assertIsNone(self.tm[key1].exception())
        self.assertEqual(999, self.tm[key1].result())
        # The base TaskManager does not auto-remove finished tasks.
        self.assertEqual(1, len(self.tm))

    async def test_timeout(self):
        """A per-task timeout makes the task itself fail with TimeoutError."""
        async def _timeout_1s_func():
            await asyncio_sleep(1.0)

        key1 = "key1"
        self.assertEqual(0, len(self.tm))
        # 0.1 s task timeout, shorter than the 1 s sleep.
        self.tm.create_task(key1, _timeout_1s_func(), 0.1)
        self.assertEqual(1, len(self.tm))

        with self.assertRaises(TimeoutError):
            await self.tm.join(key1)

        # The task finished with TimeoutError; it is not reported as cancelled.
        self.assertTrue(self.tm[key1].done())
        self.assertFalse(self.tm[key1].cancelled())
        self.assertIsNotNone(self.tm[key1].exception())
        self.assertIsInstance(self.tm[key1].exception(), TimeoutError)

    async def test_join_timeout(self):
        """A join() timeout leaves the (shielded) task running."""
        async def _timeout_10s_func():
            await asyncio_sleep(10.0)

        key1 = "key1"
        self.assertEqual(0, len(self.tm))
        self.tm.create_task(key1, _timeout_10s_func())
        self.assertEqual(1, len(self.tm))

        # Only the wait times out, not the task.
        with self.assertRaises(TimeoutError):
            await self.tm.join(key1, 0.001)

        self.assertFalse(self.tm[key1].done())
        self.assertFalse(self.tm[key1].cancelled())

        # cancel() only requests cancellation; the task is not done or
        # cancelled until the event loop runs it again.
        self.tm[key1].cancel()
        self.assertFalse(self.tm[key1].done())
        self.assertFalse(self.tm[key1].cancelled())  # Not yet: cancellation is still pending

        with self.assertRaises(CancelledError):
            await self.tm.join(key1)

        self.assertTrue(self.tm[key1].done())
        self.assertTrue(self.tm[key1].cancelled())

        # exception() on a cancelled task raises CancelledError.
        with self.assertRaises(CancelledError):
            self.tm[key1].exception()


class AutoRemoveTaskManagerTestCase(IsolatedAsyncioTestCase):
    """Tests for AutoRemoveTaskManager, which drops tasks once they finish."""

    async def asyncSetUp(self):
        # Each test gets a fresh, empty manager.
        self.tm = AutoRemoveTaskManager()

    async def asyncTearDown(self):
        # NOTE(review): both tests below finish their task before teardown,
        # so close() only ever runs on an empty manager here. close() with
        # still-pending tasks is not covered.
        await self.tm.close()
        self.assertEqual(0, len(self.tm))

    async def test_default(self):
        """A successful task is removed from the manager once awaited."""
        async def _func():
            return 777

        key1 = "key1"
        self.assertEqual(0, len(self.tm))
        task1 = self.tm.create_task(key1, _func())
        # The task is named after its key; on_done uses the name to find it.
        self.assertEqual(key1, task1.get_name())
        self.assertEqual(1, len(self.tm))

        self.assertEqual(777, await task1)
        self.assertTrue(task1.done())
        # The done-callback has already removed the entry.
        self.assertEqual(0, len(self.tm))

    async def test_timeout(self):
        """A timed-out task fails with TimeoutError and is also removed."""
        async def _timeout_1s_func():
            await asyncio_sleep(1.0)

        key1 = "key1"
        self.assertEqual(0, len(self.tm))
        # 0.1 s task timeout, shorter than the 1 s sleep.
        task1 = self.tm.create_task(key1, _timeout_1s_func(), 0.1)
        self.assertEqual(key1, task1.get_name())
        self.assertEqual(1, len(self.tm))

        with self.assertRaises(TimeoutError):
            await task1

        self.assertTrue(task1.done())
        self.assertFalse(task1.cancelled())
        self.assertIsNotNone(task1.exception())
        self.assertIsInstance(task1.exception(), TimeoutError)
        # Failed tasks are removed just like successful ones.
        self.assertEqual(0, len(self.tm))


if __name__ == "__main__":
    # Run the test cases when this module is executed directly.
    main()

See also