Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RetryHandlerSkeleton #152

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
This module provides a straggler detection and retry handling framework for Ray tasks.

Within retry_strategy_config.py, the client can provide 3 parameters to start_tasks_execution to perform retries and detect stragglers.
Params:
1. ray_remote_task_info: A list of Ray task objects
2. scaling_strategy: Batch scaling parameters that control how many tasks are executed per batch (Optional)
a. If not provided, a default AIMD (additive increase, multiplicative decrease) strategy is used for retries
3. straggler_detection: Client-provided class that holds the logic for detecting straggler tasks (Optional)
a. The client's algorithm must implement the straggler detection interface, which is used in wait_and_get_results
ekaschaw marked this conversation as resolved.
Show resolved Hide resolved

Use cases:
1. Notifying progress
a. TaskContext (progressNotifier - (send_heartbeat, send_progress, get_progress), timeout_time) from StragglerDetectionInterface
2. Detecting stragglers
Given the straggler detection algorithm supplied by the client, the method get_timeout_val is used to determine how
long a task may run before it is considered a straggler. The client must provide this logic in their own implementation.
3. Retrying retryable exceptions
a. The failure directory contains common errors that are retryable. When a raised exception is detected as an instance
of the retryable class, the task is retried by submitting it again.

The client can provide these inputs to fulfil the following use cases:

For example, given a list of 1000 tasks, we first scale each batch to a reasonable size and then run retry handling and straggler detection on each batch.
4 changes: 4 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/TaskContext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class TaskContext:
    """
    Holds the per-task context handed to task execution: a progress notifier
    and a timeout value.

    :param progress_notifier: object used to report or query task progress
        (expected to implement ProgressNotifierInterface)
    :param timeoutTime: timeout for the task, as a float
        (NOTE(review): units are not stated anywhere in this file - confirm)
    """

    def __init__(self, progress_notifier: "ProgressNotifierInterface", timeoutTime: float):
        self.progress_notifier = progress_notifier
        # The original assignment had no right-hand side (a SyntaxError);
        # store the value that the caller passed in.
        self.timeoutTime = timeoutTime
ekaschaw marked this conversation as resolved.
Show resolved Hide resolved
17 changes: 17 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/exception_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import List, Optional

from ray_manager.models.ray_remote_task_exception_retry_strategy_config import RayRemoteTaskExceptionRetryConfig
def get_retry_strategy_config_for_known_exception(
        exception: Exception,
        exception_retry_strategy_configs: "List[RayRemoteTaskExceptionRetryConfig]",
) -> "Optional[RayRemoteTaskExceptionRetryConfig]":
    """
    Checks whether the exception seen is recognized as a retryable error or not.

    The configs are scanned twice: first for a config whose `exception` has
    exactly the same type as `exception`, then for a config whose `exception`
    type is a base class of it. An exact match therefore always wins over a
    base-class match, regardless of the order of the configs.

    :param exception: the exception raised by the task
    :param exception_retry_strategy_configs: candidate configs, each exposing
        an `exception` attribute holding an exception instance
    :return: the first matching config, or None when nothing matches
    """
    exception_type = type(exception)

    # Pass 1: exact type match (identity comparison of the type objects).
    for retry_config in exception_retry_strategy_configs:
        if exception_type is type(retry_config.exception):
            return retry_config

    # Pass 2: fall back to subclass relationship.
    for retry_config in exception_retry_strategy_configs:
        if isinstance(exception, type(retry_config.exception)):
            return retry_config

    return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error.failures import RetryableError

class AWSSecurityTokenException(RetryableError):
    """
    Retryable error for AWS security token failures (meaning taken from the
    class name); adds no state or behavior beyond RetryableError.
    """

    def __init__(self, *args: object) -> None:
        # Forward every positional argument to RetryableError unchanged.
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error.failures import RetryableError

class BrokenPipe(RetryableError):
    """
    Retryable error for broken-pipe failures (meaning taken from the class
    name); adds no state or behavior beyond RetryableError.
    """

    def __init__(self, *args: object) -> None:
        # Forward every positional argument to RetryableError unchanged.
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error.failures import RetryableError

class CairnsClientException(RetryableError):
    """
    Retryable error for Cairns client failures (meaning taken from the class
    name); adds no state or behavior beyond RetryableError.
    """

    def __init__(self, *args: object) -> None:
        # Forward every positional argument to RetryableError unchanged.
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error.failures import NonRetryableError

class CairnsResourceNotFound(NonRetryableError):
    """
    Non-retryable error for a missing Cairns resource (meaning taken from the
    class name); adds no state or behavior beyond NonRetryableError.
    """

    def __init__(self, *args: object) -> None:
        # Forward every positional argument to NonRetryableError unchanged.
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error.failures import NonRetryableError

class ManifestGenerationException(NonRetryableError):
    """
    Non-retryable error for manifest generation failures (meaning taken from
    the class name); adds no state or behavior beyond NonRetryableError.
    """

    def __init__(self, *args: object) -> None:
        # Forward every positional argument to NonRetryableError unchanged.
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class NonRetryableError(RuntimeError):
    """
    Class represents a non-retryable error.

    Base class for failures that must not be retried; it behaves exactly like
    RuntimeError and exists so callers can classify errors with isinstance().
    """

    # The original annotated the return type with `-->`, which is a
    # SyntaxError; the return annotation arrow is `->`.
    def __init__(self, *args: object) -> None:
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error.failures import RetryableError

class PortConflict(RetryableError):
    """
    Retryable error for port conflicts (meaning taken from the class name);
    adds no state or behavior beyond RetryableError.
    """

    def __init__(self, *args: object) -> None:
        # Forward every positional argument to RetryableError unchanged.
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class RetryableError(RuntimeError):
    """
    Class for errors that can be retried.

    Base class for retryable failures; it behaves exactly like RuntimeError
    and exists so callers can classify errors with isinstance().
    """

    # The original annotated the return type with `-->`, which is a
    # SyntaxError; the return annotation arrow is `->`.
    def __init__(self, *args: object) -> None:
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error.failures import RetryableError

class UploadPartThrottle(RetryableError):
    """
    Retryable error for throttled part uploads (meaning taken from the class
    name); adds no state or behavior beyond RetryableError.
    """

    def __init__(self, *args: object) -> None:
        # Forward every positional argument to RetryableError unchanged.
        super().__init__(*args)
ekaschaw marked this conversation as resolved.
Show resolved Hide resolved
22 changes: 22 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/interface_batch_scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
ekaschaw marked this conversation as resolved.
Show resolved Hide resolved
from typing import List

class BatchScalingInterface(ABC):
    """
    Interface for a generic batch scaling that the client can provide.

    The methods are marked abstract so that, like the other interfaces in this
    package, an incomplete implementation fails at instantiation time instead
    of silently returning None from an inherited stub.
    """

    @abstractmethod
    def init_tasks(self, task_infos):
        """
        Loads all tasks to be executed for retry and straggler detection.

        :param task_infos: the tasks to load
        """
        pass

    @abstractmethod
    def next_batch(self, task_info) -> List:
        """
        Gets the next batch of x size to execute on.

        :param task_info: NOTE(review): purpose is not evident from this file - confirm
        :return: the list of tasks forming the next batch
        """
        pass

    @abstractmethod
    def has_next_batch(self, running_tasks) -> bool:
        """
        Returns true if there are tasks remaining in the overall List of tasks.

        :param running_tasks: NOTE(review): purpose is not evident from this file - confirm
        :return: True when more tasks remain, False otherwise
        """
        pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import ABC, abstractmethod

class ProgressNotifierInterface(ABC):
    """
    Interface through which a task's progress and liveness are reported to,
    and queried by, its parent task.
    """

    @abstractmethod
    def get_progress(self, task):
        """
        Gets progress message regarding current task.

        :param task: the task whose progress is requested
        """
        pass

    @abstractmethod
    def has_heartbeat(self, task) -> bool:
        """
        Tells parent task if the current task has a heartbeat or not.

        :param task: the task to check
        :return: True if the task has a heartbeat, False otherwise
        """
        pass

    @abstractmethod
    def send_progress(self, task):
        """
        Sends progress of current task to parent task.

        :param task: the task whose progress is sent
        """
        pass
34 changes: 34 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/interface_retry_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
class RetryTaskInterface(ABC):
    """
    Contract for client-provided retry behavior: loading tasks, deciding
    whether a task is retried, how long to wait, and performing the retry.
    """

    @abstractmethod
    def init_tasks(self, task_infos):
        """
        Load every task that should be checked for retries on exception.

        :param task_infos: the tasks to load
        :return: List of tasks
        """
        return None

    @abstractmethod
    def should_retry(self, task) -> bool:
        """
        Decide whether the given task can be retried.

        :param task: the task under consideration
        :return: True or False
        """
        return None

    @abstractmethod
    def get_wait_time(self, task):
        """
        Wait time between retries.

        :param task: the task about to be retried
        :return: the wait time (type/units are left to the implementation)
        """
        return None

    @abstractmethod
    def retry(self, task):
        """
        Execute the retry behavior for the exception.

        :param task: the task to retry
        :return: left to the implementation
        """
        return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any

class StragglerDetectionInterface(ABC):
    """
    Contract for client-provided straggler detection logic.
    """

    @abstractmethod
    def is_straggler(self, task, task_context) -> bool:
        """
        Report whether this specific task is a straggler, given all the info.

        :param task: the task to evaluate
        :param task_context: context information about the task
        :return: True if the task is a straggler, False otherwise
        """
        return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from deltacat.utils.ray_utils.retry_handler.task_constants import DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BATCH_SIZE_MULTIPLICATIVE_DECREASE_FACTOR, DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BACK_OFF_IN_MS, DEFAULT_RAY_REMOTE_TASK_BATCH_POSITIVE_FEEDBACK_BATCH_SIZE_ADDITIVE_INCREASE
from dataclasses import dataclass

class RayRemoteTasksBatchScalingParams(BatchScalingStrategy):
    """
    Batch scaling params of the Ray remote tasks.

    TODO: add the constants that this file refers to.
    NOTE(review): BatchScalingStrategy, StragglerDetectionInterface and List
    are not imported in the visible part of this file - confirm they are
    brought into scope before this class is defined.
    """

    def __init__(self, straggler_detection: StragglerDetectionInterface):
        # Keep the detector so that batch sizing can consult it later.
        self.straggler_detection = straggler_detection

    def init_tasks(self, task_infos):
        """Placeholder: not implemented yet, returns None."""
        return None

    def next_batch(self, task_info) -> List:
        """Placeholder: not implemented yet, returns None."""
        return None

    def has_next_batch(self, running_tasks) -> bool:
        """Placeholder: not implemented yet, returns None."""
        return None
149 changes: 149 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/ray_task_submission_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import annotations
from typing import Any, Dict, List, cast, Optional
from deltacat.utils.ray_utils.retry_handler.ray_remote_tasks_batch_scaling_params import RayRemoteTasksBatchScalingParams
import ray
import time
import logging
from deltacat.logs import configure_logger
from deltacat.utils.ray_utils.retry_handler.task_execution_error import RayRemoteTaskExecutionError
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
from deltacat.utils.ray_utils.retry_handler.retry_strategy_config import get_retry_strategy_config_for_known_exception

logger = configure_logger(logging.getLogger(__name__))

@ray.remote
def submit_single_task(taskObj: TaskInfoObject, TaskContext: Optional[Interface] = None) -> Any:
    """
    Ray remote wrapper that runs one task and classifies its failure, if any.

    :param taskObj: task description; this function reads/writes
        `attempt_count` and reads `task_callable`, `task_input` and
        `exception_retry_strategy_configs`
    :param TaskContext: optional context intended for progress tracking and
        straggler detection (not used yet)
    :return: the result of `task_callable(task_input)` on success, or a
        RayRemoteTaskExecutionError when the raised exception matches one of
        the task's retry strategy configs
    :raises Exception: re-raises the original exception when it matches no
        retry strategy config, so that an unknown failure is not reported as a
        successful `None` result
    """
    # NOTE(review): this runs inside a Ray task, so the increment presumably
    # applies to this task's copy of taskObj - confirm how the driver is meant
    # to observe the attempt count.
    taskObj.attempt_count += 1
    curr_attempt = taskObj.attempt_count

    if TaskContext is not None:
        # TODO: custom logic for checking if TaskContext has progress and then
        # use it to detect stragglers; track time/progress in here.
        pass

    # The original message referenced the undefined name
    # `current_attempt_number` (NameError); use the local computed above.
    logger.debug(f"Executing the submitted Ray remote task as part of attempt number: {curr_attempt}")

    try:
        return taskObj.task_callable(taskObj.task_input)
    except Exception as exception:
        exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(
            exception, taskObj.exception_retry_strategy_configs)
        if exception_retry_strategy_config is None:
            # Unknown failure: propagate it instead of silently returning None.
            raise
        return RayRemoteTaskExecutionError(exception_retry_strategy_config.exception, taskObj)


class RayTaskSubmissionHandler:
"""
Starts execution of all given a list of Ray tasks with optional arguments: scaling strategy and straggler detection
"""
# Entry point: starts execution of the given Ray tasks, optionally with a
# client-supplied scaling strategy and straggler detection.
# NOTE(review): `task_context` has no default value but follows parameters that
# do; Python rejects this signature with a SyntaxError. Giving it `= None`
# would keep the call sites backward-compatible.
def start_tasks_execution(self,
ray_remote_task_infos: List[TaskInfoObject],
scaling_strategy: Optional[BatchScalingStrategy] = None,
straggler_detection: Optional[StragglerDetectionInterface] = None,
task_context: Optional[TaskContext]) -> None:
if scaling_strategy is None:
# NOTE(review): RayRemoteTasksBatchScalingParams.__init__ takes a
# `straggler_detection` argument, but an int (the task count) is passed here.
scaling_strategy = RayRemoteTasksBatchScalingParams(len(ray_remote_task_infos))
# NOTE(review): `hasNextBatch` is read as an attribute and never called; the
# batch scaling interface declares `has_next_batch(self, running_tasks)`.
while scaling_strategy.hasNextBatch:
ekaschaw marked this conversation as resolved.
Show resolved Hide resolved
# NOTE(review): the interface declares `next_batch(self, task_info)`, but no
# argument is passed here.
current_batch = scaling_strategy.next_batch()
# NOTE(review): this loop has no statements, only a comment, so it is not
# valid Python yet.
for tasks in current_batch:
#execute and retry and detect straggler if avail
ekaschaw marked this conversation as resolved.
Show resolved Hide resolved



#use interface methods and data to detect stragglers in ray
# Bookkeeping read by the private helper methods of this class.
self.num_of_submitted_tasks = len(ray_remote_task_infos)
# NOTE(review): `get_batch_size` is referenced without being called and is not
# declared on the batch scaling interface shown in this PR.
self.current_batch_size = min(scaling_strategy.get_batch_size, self.num_of_submitted_tasks)
self.num_of_submitted_tasks_completed = 0
self.remaining_ray_remote_task_infos = ray_remote_task_infos
# NOTE(review): `batch_scaling_params` is not defined anywhere in this method.
self.batch_scaling_params = batch_scaling_params
# Keyed by str(object ref); see __invoke_ray_remote_task.
self.task_promise_obj_ref_to_task_info_map: Dict[Any, RayRemoteTaskInfo] = {}

self.unfinished_promises: List[Any] = []
logger.info(f"Starting the execution of {len(ray_remote_task_infos)} Ray remote tasks. Concurrency of tasks execution: {self.current_batch_size}")
# NOTE(review): in the straggler branch results are awaited although nothing
# has been submitted yet (unfinished_promises was just reset to []); in the
# else branch the first batch is submitted but never awaited.
if straggler_detection is not None:
#feed to non-detection only retry handler
self.__wait_and_get_all_task_results(straggler_detection)
else:
self.__submit_tasks(self.remaining_ray_remote_task_infos[:self.current_batch_size])
self.remaining_ray_remote_task_infos = self.remaining_ray_remote_task_infos[self.current_batch_size:]
ekaschaw marked this conversation as resolved.
Show resolved Hide resolved


def __wait_and_get_all_task_results(self, straggler_detection: Optional[StragglerDetectionInterface]) -> List[Any]:
    """
    Collect the results of every submitted task.

    :param straggler_detection: optional straggler detector, passed through
        unchanged to __get_task_results
    :return: whatever __get_task_results returns for
        `self.num_of_submitted_tasks` results
    """
    total_expected = self.num_of_submitted_tasks
    all_results = self.__get_task_results(total_expected, straggler_detection)
    return all_results

def __get_task_results(self, num_of_results: int, straggler_detection: Optional[StragglerDetectionInterface]) -> List[Any]:
    """
    Wait for submitted tasks and return `num_of_results` successful results,
    resubmitting tasks that failed with a retryable, known exception.

    Straggler detection will go in here: when `straggler_detection` is given,
    its `calc_timeout_val` is used as the `ray.wait` timeout.

    Fixes relative to the original body:
      * `unfinished`, `finished_promise` and `successful_results` were never
        defined; the promise list is `self.unfinished_promises`.
      * `self.handle_ray_exception` and `self.wait_and_get_task_results` do not
        exist on this class; the private methods are used instead.
      * `ray.wait` takes `num_returns`/`timeout` as keyword arguments.
      * The self-recursion is replaced by a loop, so a long run of timeouts or
        retries cannot exhaust the recursion limit.
      * `num_returns` is capped at the number of in-flight promises.

    :param num_of_results: number of successful results to collect
    :param straggler_detection: optional straggler detector
    :return: list of successful task results
    :raises ValueError: if results are still owed but no task is in flight
    :raises Exception: the task's exception, once it is not retryable or has
        exceeded `max_retry_attempts`
    """
    successful_results: List[Any] = []

    while len(successful_results) < num_of_results:
        if not self.unfinished_promises:
            raise ValueError(
                f"Expected {num_of_results - len(successful_results)} more task results, "
                "but there are no unfinished Ray remote tasks to wait for")

        wait_arguments: Dict[str, Any] = {
            # ray.wait rejects num_returns larger than the number of refs given.
            "num_returns": min(num_of_results - len(successful_results), len(self.unfinished_promises)),
        }
        if straggler_detection is not None:
            # NOTE(review): `calc_timeout_val` is not declared on
            # StragglerDetectionInterface; it is read as an attribute, exactly
            # as in the original. If it is a method it must be called - confirm.
            wait_arguments["timeout"] = straggler_detection.calc_timeout_val
        finished_promises, self.unfinished_promises = ray.wait(self.unfinished_promises, **wait_arguments)

        num_of_successful_results = 0
        for finished_promise in finished_promises:
            promise_key = str(finished_promise)
            try:
                finished_result = ray.get(finished_promise)
            except Exception as exception:
                # Let __handle_ray_exception classify the failure and return the corresponding error.
                finished_result = self.__handle_ray_exception(
                    exception=exception,
                    ray_remote_task_info=self.task_promise_obj_ref_to_task_info_map[promise_key])

            if isinstance(finished_result, RayRemoteTaskExecutionError):
                failed_task_info = finished_result.ray_remote_task_info
                exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(
                    finished_result.exception, failed_task_info.exception_retry_strategy_configs)
                if (exception_retry_strategy_config is None
                        or failed_task_info.num_of_attempts > exception_retry_strategy_config.max_retry_attempts):
                    logger.error(f"The submitted task has exhausted all the maximum retries configured and finally throws exception - {finished_result.exception}")
                    raise finished_result.exception
                self.__update_ray_remote_task_options_on_exception(finished_result.exception, failed_task_info)
                self.unfinished_promises.append(self.__invoke_ray_remote_task(ray_remote_task_info=failed_task_info))
            else:
                successful_results.append(finished_result)
                num_of_successful_results += 1

            # This promise is done either way (a retry is tracked under the new
            # promise's key), so drop its bookkeeping entry.
            self.task_promise_obj_ref_to_task_info_map.pop(promise_key, None)

        if num_of_successful_results:
            self.num_of_submitted_tasks_completed += num_of_successful_results
            self.current_batch_size -= num_of_successful_results
            # Refill the freed concurrency slots with not-yet-submitted tasks.
            self.__enqueue_new_tasks(num_of_successful_results)

    return successful_results


def __enqueue_new_tasks(self, num_of_tasks: int) -> None:
    """
    Submit up to `num_of_tasks` of the not-yet-submitted tasks and update the
    bookkeeping counters.

    :param num_of_tasks: maximum number of new tasks to take from the front of
        `self.remaining_ray_remote_task_infos`
    """
    pending_task_infos = self.remaining_ray_remote_task_infos
    tasks_to_submit = pending_task_infos[:num_of_tasks]
    # Fewer tasks than requested may be left, so count what was really taken.
    num_of_new_tasks_submitted = len(tasks_to_submit)

    self.__submit_tasks(tasks_to_submit)

    self.remaining_ray_remote_task_infos = self.remaining_ray_remote_task_infos[num_of_tasks:]
    self.current_batch_size = self.current_batch_size + num_of_new_tasks_submitted
    logger.info(f"Enqueued {num_of_new_tasks_submitted} new tasks. Current concurrency of tasks execution: {self.current_batch_size}, Current Task progress: {self.num_of_submitted_tasks_completed}/{self.num_of_submitted_tasks}")

def __submit_tasks(self, ray_remote_task_infos: List[RayRemoteTaskInfo]) -> None:
    """
    Invoke each given task as a Ray remote task and track the returned promise
    in `self.unfinished_promises`.

    :param ray_remote_task_infos: the tasks to submit, in order
    """
    submission_pause_in_seconds = 0.005
    for task_info in ray_remote_task_infos:
        # Short pause before every submission (5 ms).
        time.sleep(submission_pause_in_seconds)
        promise = self.__invoke_ray_remote_task(task_info)
        self.unfinished_promises.append(promise)

def __invoke_ray_remote_task(self, ray_remote_task_info: RayRemoteTaskInfo) -> Any:
    """
    Submit one task through `submit_single_task` and record its promise.

    Only the options that are set (truthy) on
    `ray_remote_task_info.ray_remote_task_options` are forwarded to
    `.options(...)`: memory, num_cpus and placement_group.

    :param ray_remote_task_info: the task to submit
    :return: the Ray object ref (promise) of the submitted task
    """
    task_options = ray_remote_task_info.ray_remote_task_options
    ray_remote_task_options_arguments = {}
    for option_name in ("memory", "num_cpus", "placement_group"):
        option_value = getattr(task_options, option_name)
        if option_value:
            ray_remote_task_options_arguments[option_name] = option_value

    # Pass the task info positionally: submit_single_task names its first
    # parameter `taskObj`, so the original `ray_remote_task_info=` keyword
    # did not match any parameter of the remote function.
    ray_remote_task_promise_obj_ref = submit_single_task.options(
        **ray_remote_task_options_arguments).remote(ray_remote_task_info)
    # Keyed by the string form of the object ref, which is how results are looked up.
    self.task_promise_obj_ref_to_task_info_map[str(ray_remote_task_promise_obj_ref)] = ray_remote_task_info

    return ray_remote_task_promise_obj_ref

def __update_ray_remote_task_options_on_exception(self, exception: Exception, ray_remote_task_info: RayRemoteTaskInfo):
    """
    Scale up the task's memory option after a known exception.

    Nothing happens unless the exception matches one of the task's retry
    strategy configs and the task has a truthy `memory` option; in that case
    the memory is multiplied by the config's
    `ray_remote_task_memory_multiply_factor`.

    :param exception: the exception the task failed with
    :param ray_remote_task_info: the task whose options are updated in place
    """
    matching_config = get_retry_strategy_config_for_known_exception(
        exception, ray_remote_task_info.exception_retry_strategy_configs)
    if not matching_config:
        return
    if not ray_remote_task_info.ray_remote_task_options.memory:
        return

    logger.info(f"Updating the Ray remote task options after encountering exception: {exception}")
    memory_factor = matching_config.ray_remote_task_memory_multiply_factor
    ray_remote_task_info.ray_remote_task_options.memory *= memory_factor
    logger.info(f"Updated ray remote task options Memory: {ray_remote_task_info.ray_remote_task_options.memory}")

def __handle_ray_exception(self, exception: Exception, ray_remote_task_info: RayRemoteTaskInfo) -> RayRemoteTaskExecutionError:
logger.error(f"Ray remote task failed with {type(exception)} Ray exception: {exception}")
if type(exception).__name__ ==
Empty file.
Loading