Skip to content

Commit

Permalink
fix: Improve AutoscaledPool state management (#241)
Browse files Browse the repository at this point in the history
- closes #236
  • Loading branch information
janbuchar authored Jun 27, 2024
1 parent 671f54b commit fdea3d1
Showing 1 changed file with 48 additions and 37 deletions.
85 changes: 48 additions & 37 deletions src/crawlee/autoscaling/autoscaled_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
if TYPE_CHECKING:
from crawlee.autoscaling import SystemStatus

__all__ = ['ConcurrencySettings', 'AutoscaledPool']

logger = getLogger(__name__)


Expand Down Expand Up @@ -63,6 +65,16 @@ def __init__(
self.max_tasks_per_minute = max_tasks_per_minute


class _AutoscaledPoolRun:
def __init__(self) -> None:
self.worker_tasks = list[asyncio.Task]()
"""A list of worker tasks currently in progress"""

self.worker_tasks_updated = asyncio.Event()
self.cleanup_done = asyncio.Event()
self.result: asyncio.Future = asyncio.Future()


class AutoscaledPool:
"""Manages a pool of asynchronous resource-intensive tasks that are executed in parallel.
Expand Down Expand Up @@ -131,13 +143,6 @@ def __init__(

self._autoscale_task = RecurringTask(self._autoscale, autoscale_interval)

self._worker_tasks = list[asyncio.Task]()
"""A list of worker tasks currently in progress"""

self._worker_tasks_updated = asyncio.Event()
self._cleanup_done = asyncio.Event()
self._run_result: asyncio.Future = asyncio.Future()

if desired_concurrency_ratio < 0 or desired_concurrency_ratio > 1:
raise ValueError('desired_concurrency_ratio must be between 0 and 1 (non-inclusive)')

Expand All @@ -154,32 +159,33 @@ def __init__(

self._max_tasks_per_minute = concurrency_settings.max_tasks_per_minute
self._is_paused = False
self._is_running = False
self._current_run: _AutoscaledPoolRun | None = None

async def run(self) -> None:
"""Start the autoscaled pool and return when all tasks are completed and `is_finished_function` returns True.
If there is an exception in one of the tasks, it will be re-raised.
"""
if self._is_running:
if self._current_run is not None:
raise RuntimeError('The pool is already running')

self._is_running = True
self._cleanup_done.clear()
run = _AutoscaledPoolRun()
self._current_run = run

logger.debug('Starting the pool')

self._autoscale_task.start()
self._log_system_status_task.start()

orchestrator = asyncio.create_task(
self._worker_task_orchestrator(), name='autoscaled pool worker task orchestrator'
self._worker_task_orchestrator(run), name='autoscaled pool worker task orchestrator'
)

try:
await self._run_result
await run.result
except AbortError:
orchestrator.cancel()
for task in self._worker_tasks:
for task in run.worker_tasks:
if not task.done():
task.cancel()
finally:
Expand All @@ -195,21 +201,23 @@ async def run(self) -> None:

logger.info('Waiting for remaining tasks to finish')

for task in self._worker_tasks:
for task in run.worker_tasks:
if not task.done():
with suppress(BaseException):
await task

self._run_result = asyncio.Future()
self._cleanup_done.set()
self._is_running = False
run.cleanup_done.set()
self._current_run = None

logger.debug('Pool cleanup finished')

async def abort(self) -> None:
    """Interrupt the autoscaled pool and all the tasks in progress.

    Returns once the current run has finished its cleanup.

    Raises:
        RuntimeError: If the pool is not running.
    """
    # Keep a local reference - `run()` resets `self._current_run` to None when its cleanup ends.
    run = self._current_run
    if run is None:
        raise RuntimeError('The pool is not running')

    # The result may already be settled (a worker task failed, the orchestrator finished, or `abort`
    # was already called). Settling a done future again would raise `asyncio.InvalidStateError`,
    # so in that case just wait for the cleanup that is already under way.
    if not run.result.done():
        run.result.set_exception(AbortError())

    await run.cleanup_done.wait()

def pause(self) -> None:
"""Pause the autoscaled pool so that it does not start new tasks."""
Expand All @@ -227,7 +235,10 @@ def desired_concurrency(self) -> int:
@property
def current_concurrency(self) -> int:
    """The number of concurrent tasks in progress (0 when the pool is not running)."""
    # Worker tasks are tracked per run; without an active run there is nothing in progress.
    if self._current_run is None:
        return 0

    return len(self._current_run.worker_tasks)

def _autoscale(self) -> None:
"""Inspect system load status and adjust desired concurrency if necessary. Do not call directly."""
Expand Down Expand Up @@ -258,16 +269,16 @@ def _log_system_status(self) -> None:
f'{system_status!s}'
)

async def _worker_task_orchestrator(self) -> None:
async def _worker_task_orchestrator(self, run: _AutoscaledPoolRun) -> None:
"""Launches worker tasks whenever there is free capacity and a task is ready.
Exits when `is_finished_function` returns True.
"""
finished = False

try:
while not (finished := await self._is_finished_function()) and not self._run_result.done():
self._worker_tasks_updated.clear()
while not (finished := await self._is_finished_function()) and not run.result.done():
run.worker_tasks_updated.clear()

current_status = self._system_status.get_current_system_info()
if not current_status.is_system_idle:
Expand All @@ -281,44 +292,44 @@ async def _worker_task_orchestrator(self) -> None:
else:
logger.debug('Scheduling a new task')
worker_task = asyncio.create_task(self._worker_task(), name='autoscaled pool worker task')
worker_task.add_done_callback(lambda task: self._reap_worker_task(task))
self._worker_tasks.append(worker_task)
worker_task.add_done_callback(lambda task: self._reap_worker_task(task, run))
run.worker_tasks.append(worker_task)

if math.isfinite(self._max_tasks_per_minute):
await asyncio.sleep(60 / self._max_tasks_per_minute)

continue

with suppress(asyncio.TimeoutError):
await asyncio.wait_for(self._worker_tasks_updated.wait(), timeout=0.5)
await asyncio.wait_for(run.worker_tasks_updated.wait(), timeout=0.5)
finally:
if finished:
logger.debug('`is_finished_function` reports that we are finished')
elif self._run_result.done() and self._run_result.exception() is not None:
elif run.result.done() and run.result.exception() is not None:
logger.debug('Unhandled exception in `run_task_function`')

if self._worker_tasks:
if run.worker_tasks:
logger.debug('Terminating - waiting for tasks to complete')
await asyncio.wait(self._worker_tasks, return_when=asyncio.ALL_COMPLETED)
await asyncio.wait(run.worker_tasks, return_when=asyncio.ALL_COMPLETED)
logger.debug('Worker tasks finished')
else:
logger.debug('Terminating - no running tasks to wait for')

if not self._run_result.done():
self._run_result.set_result(object())
if not run.result.done():
run.result.set_result(object())

def _reap_worker_task(self, task: asyncio.Task) -> None:
def _reap_worker_task(self, task: asyncio.Task, run: _AutoscaledPoolRun) -> None:
"""A callback for finished worker tasks.
- It interrupts the run in case of an exception,
- keeps track of tasks in progress,
- notifies the orchestrator
"""
self._worker_tasks_updated.set()
self._worker_tasks.remove(task)
run.worker_tasks_updated.set()
run.worker_tasks.remove(task)

if not task.cancelled() and (exception := task.exception()) and not self._run_result.done():
self._run_result.set_exception(exception)
if not task.cancelled() and (exception := task.exception()) and not run.result.done():
run.result.set_exception(exception)

async def _worker_task(self) -> None:
try:
Expand Down

0 comments on commit fdea3d1

Please sign in to comment.