RetryHandlerSkeleton #152

Open
wants to merge 9 commits into base: main
@@ -0,0 +1,49 @@
from typing import List

from deltacat.utils.ray_utils.retry_handler.batch_scaling_interface import BatchScalingInterface
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject


class AIMDBasedBatchScalingStrategy(BatchScalingInterface):
    """
    Default AIMD (additive increase, multiplicative decrease) batch scaling strategy,
    used when the client does not provide their own batch scaling parameters.
    """
    def __init__(self, additive_increase: int, multiplicative_decrease: float):
        self.task_infos = []
        self.batch_index = 0
        self.batch_size = None
        self.max_batch_size = None
        self.min_batch_size = None
        self.additive_increase = additive_increase
        self.multiplicative_decrease = multiplicative_decrease

    def init_tasks(self, initial_batch_size: int, max_batch_size: int, min_batch_size: int,
                   task_infos: List[TaskInfoObject]) -> None:
        """
        Sets up AIMD scaling over the given tasks with the provided batch size bounds.
        """
        self.task_infos = task_infos
        self.batch_size = initial_batch_size
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size


    def has_next_batch(self) -> bool:
        """
        Returns True if there are remaining tasks from which a next batch can be formed.
        """
        return self.batch_index < len(self.task_infos)

    def next_batch(self) -> List[TaskInfoObject]:
        """
        Returns the next batch of tasks, sized according to the current AIMD batch size.
        """
        batch_end = min(self.batch_index + self.batch_size, len(self.task_infos))
        batch = self.task_infos[self.batch_index:batch_end]
        self.batch_index = batch_end
        return batch

    def mark_task_complete(self, task_info: TaskInfoObject) -> None:
        task_info.completed = True

    def increase_batch_size(self) -> None:
        self.batch_size = min(self.batch_size + self.additive_increase, self.max_batch_size)

    def decrease_batch_size(self) -> None:
        # Keep the batch size integral so it can be used for list slicing.
        self.batch_size = max(int(self.batch_size * self.multiplicative_decrease), self.min_batch_size)
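For illustration, a minimal sketch (not part of this diff) of how the AIMD strategy could be driven; execute_batch and the pre-built task_infos list are hypothetical stand-ins for the caller's own code:

# Illustrative sketch only: drive the AIMD strategy over a list of TaskInfoObject instances.
# execute_batch() and task_infos are hypothetical stand-ins for the caller's own code.
scaling = AIMDBasedBatchScalingStrategy(additive_increase=2, multiplicative_decrease=0.5)
scaling.init_tasks(initial_batch_size=4, max_batch_size=16, min_batch_size=1, task_infos=task_infos)

while scaling.has_next_batch():
    batch = scaling.next_batch()
    execute_batch(batch)                      # run the batch of tasks
    for task in batch:
        scaling.mark_task_complete(task)      # flag each finished task
    if all(task.completed for task in batch):
        scaling.increase_batch_size()         # additive increase after a fully successful batch
    else:
        scaling.decrease_batch_size()         # multiplicative decrease after failures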
27 changes: 27 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/README.md
@@ -0,0 +1,27 @@
This module provides a straggler detection and retry handler framework.

Within retry_strategy_config.py, the client can provide three parameters to start_tasks_execution to perform retries and detect stragglers (a usage sketch appears at the end of this README).
Params:
1. ray_remote_task_info: A list of Ray task objects
2. scaling_strategy: Batch scaling parameters for how many tasks to execute per batch (Optional)
   a. If not provided, a default AIMD (additive increase, multiplicative decrease) strategy will be assigned for retries
3. straggler_detection: Client-provided class that holds the logic for how the client wants to detect straggler tasks (Optional)
   a. The client's algorithm must implement the detection interface, which will be used in wait_and_get_results

Use cases:
1. Notifying progress
   This is done through ProgressNotifierInterface. The client can implement has_heartbeat and send_heartbeat from the interface
   to receive updates on task-level progress. The notification channel can be an SNS queue or any other indicator the client may choose.
2. Detecting stragglers
   Given the straggler detection algorithm implemented against StragglerDetectionInterface, the method is_straggler reports
   whether the current task is a straggler according to the client's own logic. To support that decision, the framework provides
   a TaskContext containing the fields and data the client can use to decide whether a task is a straggler.
3. Retrying retryable exceptions
   The failure directory contains common errors that are retryable; when a caught exception is an instance of a retryable
   class, the task will be retried. If the client would like their own exceptions to be handled, they can create a class that
   extends retryable_error or non_retryable_error, and the framework will handle it based on the configured retry strategy.
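
A hedged usage sketch of the three parameters described above (illustrative only; SimpleStragglerDetection, its TaskContext fields, and build_task_infos are assumptions rather than APIs defined in this change):

# Illustrative sketch: wiring the three start_tasks_execution parameters together.
# SimpleStragglerDetection, its TaskContext fields, and build_task_infos() are assumptions.
from deltacat.utils.ray_utils.retry_handler.ray_task_submission_handler import RayTaskSubmissionHandler

class SimpleStragglerDetection:
    """Flags a task as a straggler once it runs far longer than the batch average."""
    def is_straggler(self, task, task_context) -> bool:
        return task_context.elapsed_time > 3 * task_context.average_task_time

handler = RayTaskSubmissionHandler()
handler.start_tasks_execution(
    ray_remote_task_infos=build_task_infos(),        # 1. list of Ray task objects
    scaling_strategy=None,                            # 2. omitted, so the default AIMD strategy is used
    straggler_detection=SimpleStragglerDetection(),   # 3. client-provided straggler detection
)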




42 changes: 42 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/batch_scaling_interface.py
@@ -0,0 +1,42 @@
from typing import List, Protocol

from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject


class BatchScalingInterface(Protocol):
    """
    Interface for a generic batch scaling strategy that the client can provide.
    """
    def init_tasks(self, initial_batch_size: int, max_batch_size: int, min_batch_size: int,
                   task_infos: List[TaskInfoObject]) -> None:
        """
        Loads all tasks to be executed for retry batching.
        """
        pass

    def has_next_batch(self) -> bool:
        """
        Returns True if there are tasks remaining in the overall list from which a new batch can be created.
        """
        pass

    def next_batch(self) -> List[TaskInfoObject]:
        """
        Gets the next batch of tasks to execute on.
        """
        pass

    def mark_task_complete(self, task_info: TaskInfoObject) -> None:
        """
        Marks the given task as completed so the strategy can distinguish finished tasks
        from those that still need to be executed.
        """
        pass

    def increase_batch_size(self) -> None:
        """
        Increases the batch size by some amount according to client specifications.
        """
        pass

    def decrease_batch_size(self) -> None:
        """
        Decreases the batch size by some amount according to client specifications.
        """
        pass

16 changes: 16 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/exception_util.py
@@ -0,0 +1,16 @@
from typing import List, Optional

from ray_manager.models.ray_remote_task_exception_retry_strategy_config import RayRemoteTaskExceptionRetryConfig


def get_retry_strategy_config_for_known_exception(
        exception: Exception,
        exception_retry_strategy_configs: List[RayRemoteTaskExceptionRetryConfig]) -> Optional[RayRemoteTaskExceptionRetryConfig]:
    """
    Checks whether the exception seen is recognized as a retryable error, returning the
    matching retry config if one exists and None otherwise.
    """
    # First pass: look for an exact type match.
    for exception_retry_strategy_config in exception_retry_strategy_configs:
        if type(exception) is type(exception_retry_strategy_config.exception):
            return exception_retry_strategy_config

    # Second pass: fall back to a subclass match.
    for exception_retry_strategy_config in exception_retry_strategy_configs:
        if isinstance(exception, type(exception_retry_strategy_config.exception)):
            return exception_retry_strategy_config

    return None
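
For illustration, a sketch of how this lookup might be used when an exception is caught; the RayRemoteTaskExceptionRetryConfig constructor arguments and run_task are assumptions, since that class is defined outside this diff:

# Illustrative sketch: look up the retry config for a caught exception.
# The RayRemoteTaskExceptionRetryConfig arguments and run_task() are assumed for this example.
configs = [RayRemoteTaskExceptionRetryConfig(exception=ConnectionError(), max_retry_attempts=3)]

try:
    run_task()
except Exception as e:
    config = get_retry_strategy_config_for_known_exception(e, configs)
    if config is not None:
        pass   # retryable: resubmit the task under the matched config
    else:
        raise  # unknown or non-retryable: surface the failure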
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error.failures import RetryableError

class AWSSecurityTokenRateExceededException(RetryableError):

    def __init__(self, *args: object) -> None:
        super().__init__(*args)
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error.failures import RetryableError

class CairnsClientException(RetryableError):

    def __init__(self, *args: object) -> None:
        super().__init__(*args)
@@ -0,0 +1,7 @@
class NonRetryableError(Exception):
    """
    Represents a non-retryable error.
    """
    def __init__(self, *args: object) -> None:
        super().__init__(*args)
@@ -0,0 +1,7 @@
class RetryableError(Exception):
    """
    Base class for errors that can be retried.
    """
    def __init__(self, *args: object) -> None:
        super().__init__(*args)
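
As described in the README, clients can define their own retryable exceptions by extending RetryableError; a minimal sketch follows (the class name is illustrative, not part of this diff):

# Illustrative sketch: a client-defined retryable error; the name is an assumption.
class ThrottlingException(RetryableError):
    """Raised when a downstream dependency asks the caller to back off and retry."""
    def __init__(self, *args: object) -> None:
        super().__init__(*args)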
@@ -0,0 +1,17 @@
from typing import Protocol

from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject


class ProgressNotifierInterface(Protocol):
    """
    Interface for a client-injected progress notification system.
    """
    def has_heartbeat(self, task_info: TaskInfoObject) -> bool:
        """
        Tells the parent task whether the given task still has a heartbeat.
        """
        pass

    def send_heartbeat(self, parent_task_info: TaskInfoObject) -> bool:
        """
        Sends progress of the current task to the parent task.
        """
        pass
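
A minimal sketch of one possible implementation of this interface, using an in-memory timestamp map; the timeout value and the use of object identity as a key are assumptions, and an SNS queue (as the README mentions) would work just as well:

# Illustrative sketch: an in-memory heartbeat notifier. The timeout and the use of
# id(task_info) as a key are assumptions made for this example.
import time

class InMemoryProgressNotifier:
    def __init__(self, heartbeat_timeout_seconds: float = 60.0):
        self.heartbeat_timeout_seconds = heartbeat_timeout_seconds
        self.last_heartbeat = {}

    def send_heartbeat(self, parent_task_info) -> bool:
        # Record that the task reported progress just now.
        self.last_heartbeat[id(parent_task_info)] = time.time()
        return True

    def has_heartbeat(self, task_info) -> bool:
        # A task is considered alive if it reported progress within the timeout window.
        last = self.last_heartbeat.get(id(task_info))
        return last is not None and (time.time() - last) < self.heartbeat_timeout_seconds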

196 changes: 196 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/ray_task_submission_handler.py
@@ -0,0 +1,196 @@
from __future__ import annotations
from typing import Any, Dict, List, cast, Optional
from deltacat.utils.ray_utils.retry_handler.ray_remote_tasks_batch_scaling_strategy import RayRemoteTasksBatchScalingStrategy
import ray
import time
import logging
from deltacat.logs import configure_logger
from deltacat.utils.ray_utils.retry_handler.task_execution_error import RayRemoteTaskExecutionError
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
from deltacat.utils.ray_utils.retry_handler.retry_strategy_config import get_retry_strategy_config_for_known_exception

logger = configure_logger(logging.getLogger(__name__))

@ray.remote
def submit_single_task(taskObj: TaskInfoObject, task_context: Optional[TaskContext] = None) -> Any:
    """
    Submits a single task for execution, handles any exceptions that may occur during execution,
    and applies appropriate retry strategies if they are defined.
    """
    try:
        taskObj.attempt_count += 1
        curr_attempt = taskObj.attempt_count
        logger.debug(f"Executing the submitted Ray remote task as part of attempt number: {curr_attempt}")
        return taskObj.task_callable(taskObj.task_input)
    except Exception as exception:
        exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(
            exception, taskObj.exception_retry_strategy_configs)
        if exception_retry_strategy_config is not None:
            return RayRemoteTaskExecutionError(exception_retry_strategy_config.exception, taskObj)

        logger.error(f"The exception thrown by the submitted Ray task during attempt number {taskObj.attempt_count} "
                     f"is non-retryable or unexpected, hence raising a non-retryable exception: {exception}")
        raise UnexpectedRayTaskError(str(exception))

class RayTaskSubmissionHandler:
    """
    Starts execution of a given list of Ray tasks with optional arguments: scaling strategy,
    straggler detection, retry strategy, and task context.
    """
    def start_tasks_execution(self,
                              ray_remote_task_infos: List[TaskInfoObject],
                              scaling_strategy: Optional[BatchScalingStrategy] = None,
                              straggler_detection: Optional[StragglerDetectionInterface] = None,
                              retry_strategy: Optional[RetryTaskInterface] = None,
                              task_context: Optional[TaskContext] = None) -> None:
        """
        Prepares and initiates the execution of a batch of tasks and can optionally support
        custom client batch scaling, straggler detection, and task context.
        """
        if scaling_strategy is None:
            scaling_strategy = AIMDBasedBatchScalingStrategy(ray_remote_task_infos)
        if retry_strategy is None:
            retry_strategy = RetryTaskDefault(max_retries=3)

        active_tasks = []

        while scaling_strategy.has_next_batch():
            current_batch = scaling_strategy.next_batch()
            for task in current_batch:
                try:
                    self._submit_tasks([task])
                    active_tasks.append(task)
                except Exception as e:
                    if retry_strategy.should_retry(task, e):
                        retry_strategy.retry(task, e)
                        continue
                    else:
                        raise  # TODO: decide how non-retryable submission errors should be surfaced
            completed_tasks = self._wait_and_get_all_task_results(active_tasks)

            for task in completed_tasks:
                scaling_strategy.mark_task_complete(task)
                active_tasks.remove(task)

            if all(task.completed for task in current_batch):
                scaling_strategy.increase_batch_size()
            else:
                scaling_strategy.decrease_batch_size()

            # Handle stragglers among the tasks that are still running.
            if straggler_detection is not None:
                for task in list(active_tasks):
                    if straggler_detection.is_straggler(task, task_context):
                        ray.cancel(task)
                        active_tasks.remove(task)
                        # TODO: consider re-queueing the cancelled task by adding it back to ray_remote_task_infos


        # Remaining skeleton notes for this loop:
        # - call _wait_and_get_all_task_results so there is a window to collect completed tasks
        # - when Ray returns results, set the completed flag on each TaskInfoObject and remove it from active_tasks
        # - check that at least one task from the current batch completed
        # - compare completed promises against the total tasks in the batch to decide whether to additively
        #   increase or multiplicatively decrease the batch size, and continue until nothing remains


    def _wait_and_get_all_task_results(self, straggler_detection: Optional[StragglerDetectionInterface]) -> List[Any]:
        return self._get_task_results(self.num_of_submitted_tasks, straggler_detection)

    def _get_task_results(self, num_of_results: int, straggler_detection: Optional[StragglerDetectionInterface]) -> List[Any]:
        """
        Gets results from the submitted tasks and catches exceptions to manage the retry strategy.
        Optional: given a StragglerDetectionInterface, detects and handles straggler tasks according to the client logic.
        """
        if not self.unfinished_promises or num_of_results == 0:
            return []
        elif num_of_results > len(self.unfinished_promises):
            num_of_results = len(self.unfinished_promises)

        finished_promises, self.unfinished_promises = ray.wait(self.unfinished_promises, num_returns=num_of_results)
        successful_results = []

        for finished_promise in finished_promises:
            finished_result = None
            try:
                finished_result = ray.get(finished_promise)
            except Exception as exception:
                # Evaluate the exception via _handle_ray_exception and assign the corresponding error.
                finished_result = self._handle_ray_exception(
                    exception=exception,
                    ray_remote_task_info=self.task_promise_obj_ref_to_task_info_map[str(finished_promise)])

            if finished_result and type(finished_result) == RayRemoteTaskExecutionError:
                finished_result = cast(RayRemoteTaskExecutionError, finished_result)

                if straggler_detection and straggler_detection.is_straggler(finished_result):
                    ray.cancel(finished_result)
                exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(
                    finished_result.exception,
                    finished_result.ray_remote_task_info.exception_retry_strategy_configs)
                if (exception_retry_strategy_config is None
                        or finished_result.ray_remote_task_info.num_of_attempts > exception_retry_strategy_config.max_retry_attempts):
                    logger.error(f"The submitted task has exhausted all the maximum retries configured and finally throws exception - {finished_result.exception}")
                    raise finished_result.exception
                self._update_ray_remote_task_options_on_exception(finished_result.exception,
                                                                  finished_result.ray_remote_task_info)
                self.unfinished_promises.append(
                    self._invoke_ray_remote_task(ray_remote_task_info=finished_result.ray_remote_task_info))
            else:
                successful_results.append(finished_result)
                del self.task_promise_obj_ref_to_task_info_map[str(finished_promise)]

        num_of_successful_results = len(successful_results)
        self.num_of_submitted_tasks_completed += num_of_successful_results
        self.current_batch_size -= num_of_successful_results

        self._enqueue_new_tasks(num_of_successful_results)

        if num_of_successful_results < num_of_results:
            successful_results.extend(
                self._get_task_results(num_of_results - num_of_successful_results, straggler_detection))
        return successful_results


    def _enqueue_new_tasks(self, num_of_tasks: int) -> None:
        """
        Helper method to submit a specified number of tasks.
        """
        new_tasks_submitted = self.remaining_ray_remote_task_infos[:num_of_tasks]
        num_of_new_tasks_submitted = len(new_tasks_submitted)
        self._submit_tasks(new_tasks_submitted)
        self.remaining_ray_remote_task_infos = self.remaining_ray_remote_task_infos[num_of_tasks:]
        self.current_batch_size += num_of_new_tasks_submitted
        logger.info(f"Enqueued {num_of_new_tasks_submitted} new tasks. Current concurrency of tasks execution: {self.current_batch_size}, Current Task progress: {self.num_of_submitted_tasks_completed}/{self.num_of_submitted_tasks}")

    def _submit_tasks(self, info_objs: List[TaskInfoObject]) -> None:
        for info_obj in info_objs:
            time.sleep(0.005)
            self.unfinished_promises.append(self._invoke_ray_remote_task(info_obj))

    # TODO: replace with ray.options
    def _invoke_ray_remote_task(self, ray_remote_task_info: RayRemoteTaskInfo) -> Any:
        ray_remote_task_options_arguments = dict()

        if ray_remote_task_info.ray_remote_task_options.memory:
            ray_remote_task_options_arguments['memory'] = ray_remote_task_info.ray_remote_task_options.memory

        if ray_remote_task_info.ray_remote_task_options.num_cpus:
            ray_remote_task_options_arguments['num_cpus'] = ray_remote_task_info.ray_remote_task_options.num_cpus

        if ray_remote_task_info.ray_remote_task_options.placement_group:
            ray_remote_task_options_arguments['placement_group'] = ray_remote_task_info.ray_remote_task_options.placement_group

        ray_remote_task_promise_obj_ref = submit_single_task.options(**ray_remote_task_options_arguments).remote(ray_remote_task_info)
        self.task_promise_obj_ref_to_task_info_map[str(ray_remote_task_promise_obj_ref)] = ray_remote_task_info

        return ray_remote_task_promise_obj_ref

    # TODO: replace with ray.options
    def _update_ray_remote_task_options_on_exception(self, exception: Exception, ray_remote_task_info: RayRemoteTaskInfo) -> None:
        exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(
            exception, ray_remote_task_info.exception_retry_strategy_configs)
        if exception_retry_strategy_config and ray_remote_task_info.ray_remote_task_options.memory:
            logger.info(f"Updating the Ray remote task options after encountering exception: {exception}")
            ray_remote_task_memory_multiply_factor = exception_retry_strategy_config.ray_remote_task_memory_multiply_factor
            ray_remote_task_info.ray_remote_task_options.memory *= ray_remote_task_memory_multiply_factor
            logger.info(f"Updated Ray remote task options memory: {ray_remote_task_info.ray_remote_task_options.memory}")

    # TODO: replace with this module's own exceptions
    def _handle_ray_exception(self, exception: Exception, ray_remote_task_info: RayRemoteTaskInfo) -> RayRemoteTaskExecutionError:
        logger.error(f"Ray remote task failed with {type(exception)} Ray exception: {exception}")
        if type(exception).__name__ == "AWSSecurityTokenRateExceededException":