Asyncio:TaskManager
Python의 asyncio패키지를 사용하여 Task 관리자 구현.
task_manager.py Code
# -*- coding: utf-8 -*-
from abc import ABCMeta, abstractmethod
from asyncio import Task, create_task, shield, wait_for
from typing import Any, Dict, Optional
from overrides import overrides
async def _timeout_wrapper(coro, timeout: Optional[float] = None) -> Any:
if timeout is not None:
assert timeout > 0
return await wait_for(coro, timeout=timeout)
else:
return await coro
class TaskManagerInterface(metaclass=ABCMeta):
@abstractmethod
def on_done(self, task: Task) -> None:
raise NotImplementedError
class TaskManager(TaskManagerInterface, Dict[str, Task]):
async def close(self) -> None:
for task in self.values():
task.cancel()
try:
await task
except: # noqa
pass
self.clear()
async def join(self, key: str, timeout: Optional[float] = None) -> Any:
task = self.__getitem__(key)
if timeout is not None:
assert timeout > 0
return await wait_for(shield(task), timeout=timeout)
else:
return await task
def create_task(self, key: str, coro, timeout: Optional[float] = None) -> Task:
if self.__contains__(key):
raise KeyError(f"Already exists key: {key}")
task = create_task(_timeout_wrapper(coro, timeout), name=key)
task.add_done_callback(self.on_done)
self.__setitem__(key, task)
return task
@overrides
def on_done(self, task: Task) -> None:
assert isinstance(task, Task)
class AutoRemoveTaskManager(TaskManager):
@overrides
def on_done(self, task: Task) -> None:
assert isinstance(task, Task)
self.__delitem__(task.get_name())
TestCase
# -*- coding: utf-8 -*-
from asyncio import CancelledError, TimeoutError
from asyncio import sleep as asyncio_sleep
from unittest import IsolatedAsyncioTestCase, main
from recc.aio.task_manager import AutoRemoveTaskManager, TaskManager
class TaskManagerTestCase(IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.tm = TaskManager()
async def asyncTearDown(self):
await self.tm.close()
self.assertEqual(0, len(self.tm))
async def test_default(self):
async def _func():
return 999
key1 = "key1"
self.assertEqual(0, len(self.tm))
self.tm.create_task(key1, _func())
self.assertEqual(1, len(self.tm))
self.assertEqual(999, await self.tm.join(key1))
self.assertTrue(self.tm[key1].done())
self.assertFalse(self.tm[key1].cancelled())
self.assertIsNone(self.tm[key1].exception())
self.assertEqual(999, self.tm[key1].result())
self.assertEqual(1, len(self.tm))
async def test_timeout(self):
async def _timeout_1s_func():
await asyncio_sleep(1.0)
key1 = "key1"
self.assertEqual(0, len(self.tm))
self.tm.create_task(key1, _timeout_1s_func(), 0.1)
self.assertEqual(1, len(self.tm))
with self.assertRaises(TimeoutError):
await self.tm.join(key1)
self.assertTrue(self.tm[key1].done())
self.assertFalse(self.tm[key1].cancelled())
self.assertIsNotNone(self.tm[key1].exception())
self.assertIsInstance(self.tm[key1].exception(), TimeoutError)
async def test_join_timeout(self):
async def _timeout_10s_func():
await asyncio_sleep(10.0)
key1 = "key1"
self.assertEqual(0, len(self.tm))
self.tm.create_task(key1, _timeout_10s_func())
self.assertEqual(1, len(self.tm))
with self.assertRaises(TimeoutError):
await self.tm.join(key1, 0.001)
self.assertFalse(self.tm[key1].done())
self.assertFalse(self.tm[key1].cancelled())
self.tm[key1].cancel()
self.assertFalse(self.tm[key1].done())
self.assertFalse(self.tm[key1].cancelled()) # Not now
with self.assertRaises(CancelledError):
await self.tm.join(key1)
self.assertTrue(self.tm[key1].done())
self.assertTrue(self.tm[key1].cancelled())
with self.assertRaises(CancelledError):
self.tm[key1].exception()
class AutoRemoveTaskManagerTestCase(IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.tm = AutoRemoveTaskManager()
async def asyncTearDown(self):
await self.tm.close()
self.assertEqual(0, len(self.tm))
async def test_default(self):
async def _func():
return 777
key1 = "key1"
self.assertEqual(0, len(self.tm))
task1 = self.tm.create_task(key1, _func())
self.assertEqual(key1, task1.get_name())
self.assertEqual(1, len(self.tm))
self.assertEqual(777, await task1)
self.assertTrue(task1.done())
self.assertEqual(0, len(self.tm))
async def test_timeout(self):
async def _timeout_1s_func():
await asyncio_sleep(1.0)
key1 = "key1"
self.assertEqual(0, len(self.tm))
task1 = self.tm.create_task(key1, _timeout_1s_func(), 0.1)
self.assertEqual(key1, task1.get_name())
self.assertEqual(1, len(self.tm))
with self.assertRaises(TimeoutError):
await task1
self.assertTrue(task1.done())
self.assertFalse(task1.cancelled())
self.assertIsNotNone(task1.exception())
self.assertIsInstance(task1.exception(), TimeoutError)
self.assertEqual(0, len(self.tm))
if __name__ == "__main__":
main()