diff --git a/PILOTVERSION b/PILOTVERSION index e3b1f212..2fb403f4 100644 --- a/PILOTVERSION +++ b/PILOTVERSION @@ -1 +1 @@ -3.9.1.13 \ No newline at end of file +3.9.2.23b \ No newline at end of file diff --git a/pilot.py b/pilot.py index 5405cdb9..67ffd4db 100755 --- a/pilot.py +++ b/pilot.py @@ -147,7 +147,7 @@ def main() -> int: return error.get_error_code() # update the OIDC token if necessary (after queuedata has been downloaded, since PQ.catchall can contain instruction to prevent token renewal) - if 'no_token_renewal' in infosys.queuedata.catchall: + if 'no_token_renewal' in infosys.queuedata.catchall or args.token_renewal is False: logger.info("OIDC token will not be renewed by the pilot") else: update_local_oidc_token_info(args.url, args.port) @@ -182,8 +182,6 @@ def main() -> int: f"pilot.workflow.{args.workflow}", globals(), locals(), [args.workflow], 0 ) - # check if real-time logging is requested for this queue - #rtloggingtype # update the pilot heartbeat file update_pilot_heartbeat(time.time()) @@ -451,6 +449,16 @@ def get_args() -> Any: help="Maximum number of getjob request failures in Harvester mode", ) + # no_token_renewal + arg_parser.add_argument( + "-y", + "--notokenrenewal", + dest="token_renewal", + action="store_false", + default=True, + help="Disable token renewal", + ) + arg_parser.add_argument( "--subscribe-to-msgsvc", dest="subscribe_to_msgsvc", diff --git a/pilot/api/data.py b/pilot/api/data.py index 93cc76f1..dc0e8764 100644 --- a/pilot/api/data.py +++ b/pilot/api/data.py @@ -20,7 +20,7 @@ # - Mario Lassnig, mario.lassnig@cern.ch, 2017 # - Paul Nilsson, paul.nilsson@cern.ch, 2017-2024 # - Tobias Wegner, tobias.wegner@cern.ch, 2017-2018 -# - Alexey Anisenkov, anisyonk@cern.ch, 2018-2019 +# - Alexey Anisenkov, anisyonk@cern.ch, 2018-2024 """API for data transfers.""" @@ -1072,7 +1072,7 @@ class StageOutClient(StagingClient): mode = "stage-out" - def prepare_destinations(self, files: list, activities: list or str) -> list: + def prepare_destinations(self, files: list, activities: list or str, alt_exclude: list = []) -> list: """ Resolve destination RSE (filespec.ddmendpoint) for each entry from `files` according to requested `activities`. @@ -1080,6 +1080,7 @@ def prepare_destinations(self, files: list, activities: list or str) -> list: :param files: list of FileSpec objects to be processed (list) :param activities: ordered list of activities to be used to resolve astorages (list or str) + :param alt_exclude: global list of destinations that should be excluded / not used for alternative stage-out :return: updated fspec entries (list). """ if not self.infosys.queuedata: # infosys is not initialized: not able to fix destination if need, nothing to do @@ -1108,11 +1109,26 @@ def prepare_destinations(self, files: list, activities: list or str) -> list: raise PilotException(f"Failed to resolve destination: no associated storages defined for activity={activity} ({act})", code=ErrorCodes.NOSTORAGE, state='NO_ASTORAGES_DEFINED') - # take the fist choice for now, extend the logic later if need - ddm = storages[0] - ddm_alt = storages[1] if len(storages) > 1 else None + def resolve_alt_destination(primary, exclude=None): + """ resolve alt destination as the next to primary entry not equal to `primary` and `exclude` """ + + cur = storages.index(primary) if primary in storages else 0 + inext = (cur + 1) % len(storages) # cycle storages, take the first elem when reach end + exclude = set([primary] + list(exclude if exclude is not None else [])) + alt = None + for attempt in range(len(exclude) or 1): # apply several tries to jump exclude entries (in case of dublicated data will stack) + inext = (cur + 1) % len(storages) # cycle storages, start from the beginning when reach end + if storages[inext] not in exclude: + alt = storages[inext] + break + cur += 1 + return alt + + # default destination + ddm = storages[0] # take the fist choice for now, extend the logic later if need + ddm_alt = resolve_alt_destination(ddm, exclude=alt_exclude) - self.logger.info(f"[prepare_destinations][{activity}]: allowed (local) destinations: {storages}") + self.logger.info(f"[prepare_destinations][{activity}]: allowed (local) destinations: {storages}, alt_exclude={alt_exclude}") self.logger.info(f"[prepare_destinations][{activity}]: resolved default destination: ddm={ddm}, ddm_alt={ddm_alt}") for e in files: @@ -1121,17 +1137,18 @@ def prepare_destinations(self, files: list, activities: list or str) -> list: " .. will use default ddm=%s as (local) destination; ddm_alt=%s", activity, e.lfn, ddm, ddm_alt) e.ddmendpoint = ddm e.ddmendpoint_alt = ddm_alt - elif e.ddmendpoint not in storages: # fspec.ddmendpoint is not in associated storages => assume it as final (non local) alternative destination + #elif e.ddmendpoint not in storages and is_unified: ## customize nucleus logic if need + # pass + elif e.ddmendpoint not in storages: # fspec.ddmendpoint is not in associated storages => use it as (non local) alternative destination self.logger.info("[prepare_destinations][%s]: Requested fspec.ddmendpoint=%s is not in the list of allowed (local) destinations" - " .. will consider default ddm=%s for transfer and tag %s as alt. location", activity, e.ddmendpoint, ddm, e.ddmendpoint) - e.ddmendpoint_alt = e.ddmendpoint # verify me - e.ddmendpoint = ddm - else: # set corresponding ddmendpoint_alt if exist (next entry in available storages list) - cur = storages.index(e.ddmendpoint) - ddm_next = storages[cur + 1] if (cur + 1) < len(storages) else storages[0] # cycle storages, take the first elem when reach end - e.ddmendpoint_alt = ddm_next if e.ddmendpoint != ddm_next else None - self.logger.info("[prepare_destinations][%s]: set ddmendpoint_alt=%s for fspec.ddmendpoint=%s", - activity, e.ddmendpoint_alt, e.ddmendpoint) + " .. will consider default ddm=%s as primary and set %s as alt. location", activity, e.ddmendpoint, ddm, e.ddmendpoint) + e.ddmendpoint_alt = e.ddmendpoint if e.ddmendpoint not in alt_exclude else None + e.ddmendpoint = ddm # use default destination, check/verify nucleus case + else: # set corresponding ddmendpoint_alt if exist (next entry in cycled storages list) + e.ddmendpoint_alt = resolve_alt_destination(e.ddmendpoint, exclude=alt_exclude) + + self.logger.info("[prepare_destinations][%s]: use ddmendpoint_alt=%s for fspec.ddmendpoint=%s", + activity, e.ddmendpoint_alt, e.ddmendpoint) return files diff --git a/pilot/common/errorcodes.py b/pilot/common/errorcodes.py index 123a50ad..76effc16 100644 --- a/pilot/common/errorcodes.py +++ b/pilot/common/errorcodes.py @@ -182,6 +182,7 @@ class ErrorCodes: PREEMPTION = 1379 ARCPROXYFAILURE = 1380 ARCPROXYLIBFAILURE = 1381 + PROXYTOOSHORT = 1382 # used at the beginning of the pilot to indicate that the proxy is too short _error_messages = { GENERALERROR: "General pilot error, consult batch log", @@ -326,6 +327,7 @@ class ErrorCodes: PREEMPTION: "Job was preempted", ARCPROXYFAILURE: "General arcproxy failure", ARCPROXYLIBFAILURE: "Arcproxy failure while loading shared libraries", + PROXYTOOSHORT: "Proxy is too short", } put_error_codes = [1135, 1136, 1137, 1141, 1152, 1181] diff --git a/pilot/control/data.py b/pilot/control/data.py index c3d1af73..a71e9e53 100644 --- a/pilot/control/data.py +++ b/pilot/control/data.py @@ -917,19 +917,22 @@ def _do_stageout(job: JobData, args: object, xdata: list, activity: list, title: kwargs = {'workdir': job.workdir, 'cwd': job.workdir, 'usecontainer': False, 'job': job, 'output_dir': args.output_dir, 'catchall': job.infosys.queuedata.catchall, 'rucio_host': args.rucio_host} #, mode='stage-out') - is_unified = job.infosys.queuedata.type == 'unified' + #is_unified = job.infosys.queuedata.type == 'unified' # prod analy unification: use destination preferences from PanDA server for unified queues - if not is_unified: - client.prepare_destinations(xdata, activity) ## FIX ME LATER: split activities: for astorages and for copytools (to unify with ES workflow) + #if not is_unified: + # client.prepare_destinations(xdata, activity) ## FIX ME LATER: split activities: for astorages and for copytools (to unify with ES workflow) - altstageout = not is_unified and job.allow_altstageout() # do not use alt stage-out for unified queues + ## FIX ME LATER: split activities: for `astorages` and `copytools` (to unify with ES workflow) + client.prepare_destinations(xdata, activity, alt_exclude=list(filter(None, [job.nucleus]))) + + altstageout = job.allow_altstageout() client.transfer(xdata, activity, raise_exception=not altstageout, **kwargs) remain_files = [entry for entry in xdata if entry.require_transfer()] # check if alt stageout can be applied (all remain files must have alt storage declared ddmendpoint_alt) has_altstorage = all(entry.ddmendpoint_alt and entry.ddmendpoint != entry.ddmendpoint_alt for entry in remain_files) - logger.info('alt stage-out settings: %s, is_unified=%s, altstageout=%s, remain_files=%s, has_altstorage=%s', - activity, is_unified, altstageout, len(remain_files), has_altstorage) + logger.info('alt stage-out settings: %s, allow_altstageout=%s, remain_files=%s, has_altstorage=%s', + activity, altstageout, len(remain_files), has_altstorage) if altstageout and remain_files and has_altstorage: # apply alternative stageout for failed transfers for entry in remain_files: @@ -992,8 +995,12 @@ def _stage_out_new(job: JobData, args: object) -> bool: logger.info('this job does not have any output files, only stage-out log file') job.stageout = 'log' + is_unified = job.infosys.queuedata.type == 'unified' + is_analysis = job.is_analysis() + activities = ['write_lan_analysis', 'write_lan', 'w'] if is_unified and is_analysis else ['write_lan', 'w'] + if job.stageout != 'log': ## do stage-out output files - if not _do_stageout(job, args, job.outdata, ['pw', 'w'], title='output', + if not _do_stageout(job, args, job.outdata, activities, title='output', ipv=args.internet_protocol_version): is_success = False logger.warning('transfer of output file(s) failed') @@ -1037,7 +1044,7 @@ def _stage_out_new(job: JobData, args: object) -> bool: # write time stamps to pilot timing file add_to_pilot_timing(job.jobid, PILOT_POST_LOG_TAR, time.time(), args) - if not _do_stageout(job, args, [logfile], ['pl', 'pw', 'w'], title='log', + if not _do_stageout(job, args, [logfile], ['pl'] + activities, title='log', ipv=args.internet_protocol_version): is_success = False logger.warning('log transfer failed') diff --git a/pilot/control/job.py b/pilot/control/job.py index 506b0ffc..6c8c4c28 100644 --- a/pilot/control/job.py +++ b/pilot/control/job.py @@ -1550,13 +1550,13 @@ def proceed_with_getjob(timefloor: int, starttime: int, jobnumber: int, getjob_r if verify_proxy: userproxy = __import__(f'pilot.user.{pilot_user}.proxy', globals(), locals(), [pilot_user], 0) - # is the proxy still valid? - exit_code, diagnostics = userproxy.verify_proxy(test=False) + # is the proxy still valid? at pilot startup, the proxy lifetime must be at least 72h + exit_code, diagnostics = userproxy.verify_proxy(test=False, pilotstartup=True) if traces.pilot['error_code'] == 0: # careful so we don't overwrite another error code traces.pilot['error_code'] = exit_code if exit_code == errors.ARCPROXYLIBFAILURE: logger.warning("currently ignoring arcproxy library failure") - if exit_code in {errors.NOPROXY, errors.NOVOMSPROXY, errors.CERTIFICATEHASEXPIRED}: + if exit_code in {errors.NOPROXY, errors.NOVOMSPROXY, errors.CERTIFICATEHASEXPIRED, errors.PROXYTOOSHORT}: logger.warning(diagnostics) return False diff --git a/pilot/control/monitor.py b/pilot/control/monitor.py index 9c92c9ec..89ff1f5f 100644 --- a/pilot/control/monitor.py +++ b/pilot/control/monitor.py @@ -109,7 +109,7 @@ def control(queues: namedtuple, traces: Any, args: object): # noqa: C901 if tokendownloadchecktime and queuedata: if int(time.time() - last_token_check) > tokendownloadchecktime: last_token_check = time.time() - if 'no_token_renewal' in queuedata.catchall: + if 'no_token_renewal' in queuedata.catchall or args.token_renewal is False: logger.info("OIDC token will not be renewed by the pilot") else: update_local_oidc_token_info(args.url, args.port) diff --git a/pilot/control/payloads/generic.py b/pilot/control/payloads/generic.py index 2317ff7b..056d2c32 100644 --- a/pilot/control/payloads/generic.py +++ b/pilot/control/payloads/generic.py @@ -842,7 +842,7 @@ def run(self) -> tuple[int, str]: # noqa: C901 if stdout and stderr else "General payload setup verification error (check setup logs)" ) - # check for special errors in thw output + # check for special errors in the output exit_code = errors.resolve_transform_error(exit_code, diagnostics) diagnostics = errors.format_diagnostics(exit_code, diagnostics) return exit_code, diagnostics diff --git a/pilot/info/jobdata.py b/pilot/info/jobdata.py index 2cc149be..b769f831 100644 --- a/pilot/info/jobdata.py +++ b/pilot/info/jobdata.py @@ -16,7 +16,7 @@ # under the License. # # Authors: -# - Alexey Anisenkov, anisyonk@cern.ch, 2018-19 +# - Alexey Anisenkov, anisyonk@cern.ch, 2018-24 # - Paul Nilsson, paul.nilsson@cern.ch, 2018-24 # - Wen Guan, wen.guan@cern.ch, 2018 @@ -177,6 +177,7 @@ class JobData(BaseData): noexecstrcnv = None # server instruction to the pilot if it should take payload setup from job parameters swrelease = "" # software release string writetofile = "" # + nucleus = "" # cmtconfig encoded info alrbuserplatform = "" # ALRB_USER_PLATFORM encoded in platform/cmtconfig value @@ -195,7 +196,7 @@ class JobData(BaseData): 'swrelease', 'zipmap', 'imagename', 'imagename_jobdef', 'accessmode', 'transfertype', 'datasetin', ## TO BE DEPRECATED: moved to FileSpec (job.indata) 'infilesguids', 'memorymonitor', 'allownooutput', 'pandasecrets', 'prodproxy', 'alrbuserplatform', - 'debug_command', 'dask_scheduler_ip', 'jupyter_session_ip', 'altstageout'], + 'debug_command', 'dask_scheduler_ip', 'jupyter_session_ip', 'altstageout', 'nucleus'], list: ['piloterrorcodes', 'piloterrordiags', 'workdirsizes', 'zombies', 'corecounts', 'subprocesses', 'logdata', 'outdata', 'indata'], dict: ['status', 'fileinfo', 'metadata', 'utilities', 'overwrite_queuedata', 'sizes', 'preprocess', diff --git a/pilot/user/atlas/proxy.py b/pilot/user/atlas/proxy.py index b6defea9..7b0295e5 100644 --- a/pilot/user/atlas/proxy.py +++ b/pilot/user/atlas/proxy.py @@ -90,7 +90,7 @@ def get_and_verify_proxy(x509: str, voms_role: str = '', proxy_type: str = '', w return exit_code, diagnostics, x509 -def verify_proxy(limit: int = None, x509: bool = None, proxy_id: str = "pilot", test: bool = False) -> tuple[int, str]: +def verify_proxy(limit: int = None, x509: bool = None, proxy_id: str = "pilot", test: bool = False, pilotstartup: bool = False) -> tuple[int, str]: """ Check for a valid voms/grid proxy longer than N hours. @@ -100,8 +100,11 @@ def verify_proxy(limit: int = None, x509: bool = None, proxy_id: str = "pilot", :param x509: points to the proxy file. If not set (=None) - get proxy file from X509_USER_PROXY environment (bool) :param proxy_id: proxy id (str) :param test: free Boolean test parameter (bool) + :param pilotstartup: free Boolean pilotstartup parameter (bool) :return: exit code (NOPROXY or NOVOMSPROXY) (int), diagnostics (error diagnostics string) (str) (tuple). """ + if pilotstartup: + limit = 72 # 3 days if limit is None: limit = 1 @@ -161,9 +164,6 @@ def verify_arcproxy(envsetup: str, limit: int, proxy_id: str = "pilot", test: bo # validityLeft - duration of proxy validity left in seconds. # vomsACvalidityEnd - timestamp when VOMS attribute validity ends. # vomsACvalidityLeft - duration of VOMS attribute validity left in seconds. - cmd = f"{envsetup}arcproxy -i subject" - _exit_code, _, _ = execute(cmd, shell=True) # , usecontainer=True, copytool=True) - cmd = f"{envsetup}arcproxy -i validityEnd -i validityLeft -i vomsACvalidityEnd -i vomsACvalidityLeft" _exit_code, stdout, stderr = execute_nothreads(cmd, shell=True) # , usecontainer=True, copytool=True) if stdout is not None: @@ -173,6 +173,7 @@ def verify_arcproxy(envsetup: str, limit: int, proxy_id: str = "pilot", test: bo exit_code = -1 else: exit_code, diagnostics, validity_end_cert, validity_end = interpret_proxy_info(_exit_code, stdout, stderr, limit) + # validity_end = int(time()) + 71 * 3600 # 71 hours test if proxy_id and validity_end: # setup cache if requested if exit_code == 0: @@ -222,7 +223,12 @@ def check_time_left(proxyname: str, validity: int, limit: int) -> tuple[int, str logger.info(f"cache: check {proxyname} validity: wanted={limit}h ({limit * 3600 - 20 * 60}s with grace) " f"left={float(seconds_left) / 3600:.2f}h (now={tnow} validity={validity} left={seconds_left}s)") - if seconds_left < limit * 3600 - 20 * 60: + # special case for limit=72h (3 days) for pilot startup + if limit == 72 and seconds_left < limit * 3600 - 20 * 60: + diagnostics = f'proxy is too short for pilot startup: {float(seconds_left) / 3600:.2f}h' + logger.warning(diagnostics) + exit_code = errors.PROXYTOOSHORT + elif seconds_left < limit * 3600 - 20 * 60: diagnostics = f'cert/proxy is about to expire: {float(seconds_left) / 3600:.2f}h' logger.warning(diagnostics) exit_code = errors.CERTIFICATEHASEXPIRED if proxyname == 'cert' else errors.VOMSPROXYABOUTTOEXPIRE diff --git a/pilot/user/generic/proxy.py b/pilot/user/generic/proxy.py index 579f92e0..bc91ae0f 100644 --- a/pilot/user/generic/proxy.py +++ b/pilot/user/generic/proxy.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) -def verify_proxy(limit: int = None, x509: str = None, proxy_id: str = "pilot", test: bool = False) -> (int, str): +def verify_proxy(limit: int = None, x509: str = None, proxy_id: str = "pilot", test: bool = False, pilotstartup: bool = False) -> (int, str): """ Check for a valid voms/grid proxy longer than N hours. @@ -37,6 +37,7 @@ def verify_proxy(limit: int = None, x509: str = None, proxy_id: str = "pilot", t :param x509: points to the proxy file. If not set (=None) - get proxy file from X509_USER_PROXY environment (str) :param proxy_id: proxy id (str) :param test: free Boolean test parameter (bool) + :param pilotstartup: free Boolean pilotstartup parameter (bool) :return: exit code (NOPROXY or NOVOMSPROXY), diagnostics (error diagnostics string) (int, str). """ if limit or x509 or proxy_id or test: # to bypass pylint score 0 diff --git a/pilot/user/rubin/copytool_definitions.py b/pilot/user/rubin/copytool_definitions.py index 7a9ef40d..163f2a33 100644 --- a/pilot/user/rubin/copytool_definitions.py +++ b/pilot/user/rubin/copytool_definitions.py @@ -17,34 +17,33 @@ # under the License. # # Authors: -# - Paul Nilsson, paul.nilsson@cern.ch, 2022-23 +# - Paul Nilsson, paul.nilsson@cern.ch, 2022-24 from hashlib import md5 -def mv_to_final_destination(): +def mv_to_final_destination() -> bool: """ - Is mv allowed to move files to/from final destination? + Decide if mv is allowed to move files to/from final destination. - :return: Boolean. + :return: True if mv is allowed to move files to final destination, False otherwise (bool). """ - return False -def get_path(scope, lfn): +def get_path(scope: str, lfn: str) -> str: """ - Construct a partial Rucio PFN using the scope and the LFN + Construct a partial Rucio PFN using the scope and the LFN. + /md5(:)[0:2]/md5()[2:4]/ E.g. scope = 'user.jwebb2', lfn = 'user.jwebb2.66999._000001.top1outDS.tar' -> 'user/jwebb2/01/9f/user.jwebb2.66999._000001.top1outDS.tar' - :param scope: scope (string). - :param lfn: LFN (string). - :return: partial rucio path (string). + :param scope: scope (str) + :param lfn: LFN (str) + :return: partial rucio path (str). """ - s = f'{scope}:{lfn}' hash_hex = md5(s.encode('utf-8')).hexdigest() paths = scope.split('.') + [hash_hex[0:2], hash_hex[2:4], lfn] diff --git a/pilot/user/rubin/diagnose.py b/pilot/user/rubin/diagnose.py index 34844259..0b44d51f 100644 --- a/pilot/user/rubin/diagnose.py +++ b/pilot/user/rubin/diagnose.py @@ -17,26 +17,26 @@ # under the License. # # Authors: -# - Paul Nilsson, paul.nilsson@cern.ch, 2021-23 +# - Paul Nilsson, paul.nilsson@cern.ch, 2021-24 # - Tadashi Maeno, tadashi.maeno@cern.ch, 2020 +import logging import os +from pilot.info.jobdata import JobData from pilot.util.config import config from pilot.util.filehandling import read_file, tail -import logging logger = logging.getLogger(__name__) -def interpret(job): +def interpret(job: JobData) -> int: """ Interpret the payload, look for specific errors in the stdout. - :param job: job object + :param job: job object (JobData) :return: exit code (payload) (int). """ - stdout = os.path.join(job.workdir, config.Payload.payloadstdout) if os.path.exists(stdout): message = 'payload stdout dump\n' @@ -55,16 +55,16 @@ def interpret(job): return 0 -def get_log_extracts(job, state): +def get_log_extracts(job: JobData, state: str) -> str: """ Extract special warnings and other info from special logs. + This function also discovers if the payload had any outbound connections. - :param job: job object. - :param state: job state (string). - :return: log extracts (string). + :param job: job object (JobData) + :param state: job state (str) + :return: log extracts (str). """ - logger.info("building log extracts (sent to the server as \'pilotLog\')") # for failed/holding jobs, add extracts from the pilot log file, but always add it to the pilot log itself @@ -72,20 +72,19 @@ def get_log_extracts(job, state): _extracts = get_pilot_log_extracts(job) if _extracts != "": logger.warning(f'detected the following tail of warning/fatal messages in the pilot log:\n{_extracts}') - if state == 'failed' or state == 'holding': + if state in {'failed', 'holding'}: extracts += _extracts return extracts -def get_pilot_log_extracts(job): +def get_pilot_log_extracts(job: JobData) -> str: """ Get the extracts from the pilot log (warning/fatal messages, as well as tail of the log itself). - :param job: job object. - :return: tail of pilot log (string). + :param job: job object (JobData) + :return: tail of pilot log (str). """ - extracts = "" path = os.path.join(job.workdir, config.Pilot.pilotlog) diff --git a/pilot/user/rubin/esprocessfinegrainedproc.py b/pilot/user/rubin/esprocessfinegrainedproc.py index 9a90ed3d..838593fd 100644 --- a/pilot/user/rubin/esprocessfinegrainedproc.py +++ b/pilot/user/rubin/esprocessfinegrainedproc.py @@ -60,17 +60,17 @@ def __init__(self, max_workers=None, thread_name_prefix='', initializer=None, in self.outputs = {} self._lock = threading.RLock() self.max_workers = max_workers - super(ESRunnerThreadPool, self).__init__(max_workers=max_workers, - thread_name_prefix=thread_name_prefix, - initializer=initializer, - initargs=initargs) + super().__init__(max_workers=max_workers, + thread_name_prefix=thread_name_prefix, + initializer=initializer, + initargs=initargs) def submit(self, fn, *args, **kwargs): - future = super(ESRunnerThreadPool, self).submit(fn, *args, **kwargs) + future = super().submit(fn, *args, **kwargs) return future def run_event(self, fn, event): - future = super(ESRunnerThreadPool, self).submit(fn, event) + future = super().submit(fn, event) with self._lock: self.futures[event['eventRangeID']] = {'event': event, 'future': future} diff --git a/pilot/user/rubin/loopingjob_definitions.py b/pilot/user/rubin/loopingjob_definitions.py index 664b1744..2bada33a 100644 --- a/pilot/user/rubin/loopingjob_definitions.py +++ b/pilot/user/rubin/loopingjob_definitions.py @@ -20,29 +20,30 @@ # - Paul Nilsson, paul.nilsson@cern.ch, 2018-24 -def allow_loopingjob_detection(): +def allow_loopingjob_detection() -> bool: """ - Should the looping job detection algorithm be allowed? + Decide if the looping job detection algorithm should be allowed. + The looping job detection algorithm finds recently touched files within the job's workdir. If a found file has not been touched during the allowed time limit (see looping job section in util/default.cfg), the algorithm will kill the job/payload process. - :return: boolean. + :return: True if yes (bool). """ - return True -def remove_unwanted_files(workdir, files): +def remove_unwanted_files(workdir: str, files: list) -> list: """ Remove files from the list that are to be ignored by the looping job algorithm. - :param workdir: working directory (string). Needed in case the find command includes the workdir in the list of - recently touched files. - :param files: list of recently touched files (file names). - :return: filtered files list. - """ + The workdir is needed in case the find command includes the workdir + in the list of recently touched files. + :param workdir: working directory (str) + :param files: list of recently touched files (list) + :return: filtered files (list). + """ _files = [] for _file in files: if not (workdir == _file or diff --git a/pilot/user/rubin/proxy.py b/pilot/user/rubin/proxy.py index 13662df0..52426148 100644 --- a/pilot/user/rubin/proxy.py +++ b/pilot/user/rubin/proxy.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) -def verify_proxy(limit: int = None, x509: str = None, proxy_id: str = "pilot", test: bool = False) -> (int, str): +def verify_proxy(limit: int = None, x509: str = None, proxy_id: str = "pilot", test: bool = False, pilotstartup: bool = False) -> (int, str): """ Check for a valid voms/grid proxy longer than N hours. @@ -37,6 +37,7 @@ def verify_proxy(limit: int = None, x509: str = None, proxy_id: str = "pilot", t :param x509: points to the proxy file. If not set (=None) - get proxy file from X509_USER_PROXY environment (str) :param proxy_id: proxy id (str) :param test: free Boolean test parameter (bool) + :param pilotstartup: free Boolean pilotstartup parameter (bool) :return: exit code (NOPROXY or NOVOMSPROXY), diagnostics (error diagnostics string) (int, str). """ if limit or x509 or proxy_id or test: # to bypass pylint score 0 diff --git a/pilot/user/sphenix/cpu.py b/pilot/user/sphenix/cpu.py index 2126ab0b..5da6c85e 100644 --- a/pilot/user/sphenix/cpu.py +++ b/pilot/user/sphenix/cpu.py @@ -17,22 +17,26 @@ # under the License. # # Authors: -# - Paul Nilsson, paul.nilsson@cern.ch, 2020-2024 +# - Paul Nilsson, paul.nilsson@cern.ch, 2020-24 -from typing import Any +import logging + +from pilot.info.jobdata import JobData from pilot.util.container import execute -import logging logger = logging.getLogger(__name__) -def get_core_count(job: Any) -> int: +def get_core_count(job: JobData) -> int: """ Return the core count. - :param job: job object (Any) + :param job: job object (JobData) :return: core count (int). """ + if not job: # to bypass pylint warning + pass + return 0 @@ -46,6 +50,7 @@ def add_core_count(corecount: int, core_counts: list = None) -> list: """ if core_counts is None: core_counts = [] + return core_counts.append(corecount) @@ -58,11 +63,11 @@ def set_core_counts(**kwargs): job = kwargs.get('job', None) if job and job.pgrp: cmd = f"ps axo pgid,psr | sort | grep {job.pgrp} | uniq | awk '{{print $1}}' | grep -x {job.pgrp} | wc -l" - exit_code, stdout, stderr = execute(cmd, mute=True) + _, stdout, _ = execute(cmd, mute=True) logger.debug(f'{cmd}: {stdout}') try: job.actualcorecount = int(stdout) - except Exception as e: + except (ValueError, TypeError) as e: logger.warning(f'failed to convert number of actual cores to int: {e}') else: logger.debug(f'set number of actual cores to: {job.actualcorecount}') diff --git a/pilot/user/sphenix/monitoring.py b/pilot/user/sphenix/monitoring.py index b42e5917..8ef2750d 100644 --- a/pilot/user/sphenix/monitoring.py +++ b/pilot/user/sphenix/monitoring.py @@ -17,18 +17,20 @@ # under the License. # # Authors: -# - Paul Nilsson, paul.nilsson@cern.ch, 2021-23 +# - Paul Nilsson, paul.nilsson@cern.ch, 2021-24 -from typing import Any +from pilot.info.jobdata import JobData -def fast_monitor_tasks(job: Any) -> int: +def fast_monitor_tasks(job: JobData) -> int: """ Perform fast monitoring tasks. - :param job: job object (Any) + :param job: job object (JobData) :return: exit code (int). """ + if not job: # to bypass pylint warning + pass exit_code = 0 return exit_code diff --git a/pilot/user/sphenix/proxy.py b/pilot/user/sphenix/proxy.py index 187f4d80..83fcea49 100644 --- a/pilot/user/sphenix/proxy.py +++ b/pilot/user/sphenix/proxy.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) -def verify_proxy(limit: int = None, x509: bool = None, proxy_id: str = "pilot", test: bool = False) -> (int, str): +def verify_proxy(limit: int = None, x509: bool = None, proxy_id: str = "pilot", test: bool = False, pilotstartup: bool = False) -> (int, str): """ Check for a valid voms/grid proxy longer than N hours. Use `limit` to set required time limit. @@ -37,6 +37,7 @@ def verify_proxy(limit: int = None, x509: bool = None, proxy_id: str = "pilot", :param x509: points to the proxy file. If not set (=None) - get proxy file from X509_USER_PROXY environment (bool) :param proxy_id: proxy id (str) :param test: free Boolean test parameter (bool) + :param pilotstartup: free Boolean pilotstartup parameter (bool) :return: exit code (NOPROXY or NOVOMSPROXY) (int), diagnostics (error diagnostics string) (str). """ if limit or x509 or proxy_id or test: # to bypass pylint score 0 diff --git a/pilot/util/auxiliary.py b/pilot/util/auxiliary.py index 7c5e3146..ddcd8547 100644 --- a/pilot/util/auxiliary.py +++ b/pilot/util/auxiliary.py @@ -179,6 +179,7 @@ def get_error_code_translation_dictionary() -> dict: """ error_code_translation_dictionary = { -1: [64, "Site offline"], + errors.CVMFSISNOTALIVE: [64, "CVMFS is not responding"], # same exit code as site offline errors.GENERALERROR: [65, "General pilot error, consult batch log"], # added to traces object errors.MKDIR: [66, "Could not create directory"], # added to traces object errors.NOSUCHFILE: [67, "No such file or directory"], # added to traces object @@ -196,7 +197,7 @@ def get_error_code_translation_dictionary() -> dict: errors.MISSINGINPUTFILE: [77, "Missing input file in SE"], # should pilot report this type of error to wrapper? errors.PANDAQUEUENOTACTIVE: [78, "PanDA queue is not active"], errors.COMMUNICATIONFAILURE: [79, "PanDA server communication failure"], - errors.CVMFSISNOTALIVE: [64, "CVMFS is not responding"], # same exit code as site offline + errors.PROXYTOOSHORT: [80, "Proxy too short"], # added to traces object errors.KILLSIGNAL: [137, "General kill signal"], # Job terminated by unknown kill signal errors.SIGTERM: [143, "Job killed by signal: SIGTERM"], # 128+15 errors.SIGQUIT: [131, "Job killed by signal: SIGQUIT"], # 128+3 diff --git a/pilot/util/constants.py b/pilot/util/constants.py index e90371e1..1334b93b 100644 --- a/pilot/util/constants.py +++ b/pilot/util/constants.py @@ -27,8 +27,8 @@ # Pilot version RELEASE = '3' # released number should be fixed at 3 for Pilot 3 VERSION = '9' # version number is '1' for first release, '0' until then, increased for bigger updates -REVISION = '1' # revision number should be reset to '0' for every new version release, increased for small updates -BUILD = '13' # build number should be reset to '1' for every new development cycle +REVISION = '2' # revision number should be reset to '0' for every new version release, increased for small updates +BUILD = '23' # build number should be reset to '1' for every new development cycle SUCCESS = 0 FAILURE = 1 diff --git a/pilot/util/container.py b/pilot/util/container.py index e5837d14..a3661230 100644 --- a/pilot/util/container.py +++ b/pilot/util/container.py @@ -100,7 +100,8 @@ def execute(executable: Any, **kwargs: dict) -> Any: # noqa: C901 stdout=kwargs.get('stdout', subprocess.PIPE), stderr=kwargs.get('stderr', subprocess.PIPE), cwd=kwargs.get('cwd', getcwd()), - preexec_fn=os.setsid, # setpgrp + start_new_session=True, # alternative to use os.setsid + # preexec_fn=os.setsid, # setpgrp encoding='utf-8', errors='replace') if kwargs.get('returnproc', False): @@ -131,8 +132,10 @@ def read_output(stream, queue): stdout_thread = threading.Thread(target=read_output, args=(process.stdout, stdout_queue)) stderr_thread = threading.Thread(target=read_output, args=(process.stderr, stderr_queue)) - stdout_thread.start() - stderr_thread.start() + # start the threads and use thread synchronization + with threading.Lock(): + stdout_thread.start() + stderr_thread.start() try: logger.debug(f'subprocess.communicate() will use timeout {timeout} s') @@ -183,95 +186,6 @@ def read_output(stream, queue): return exit_code, stdout, stderr -def execute_old2(executable: Any, **kwargs: dict) -> Any: # noqa: C901 - usecontainer = kwargs.get('usecontainer', False) - job = kwargs.get('job') - obscure = kwargs.get('obscure', '') - - if isinstance(executable, list): - executable = ' '.join(executable) - - if job and job.imagename != "" and "runcontainer" in executable: - usecontainer = False - job.usecontainer = usecontainer - - if usecontainer: - executable, diagnostics = containerise_executable(executable, **kwargs) - if not executable: - return None if kwargs.get('returnproc', False) else -1, "", diagnostics - - if not kwargs.get('mute', False): - print_executable(executable, obscure=obscure) - - timeout = get_timeout(kwargs.get('timeout', None)) - exe = ['/usr/bin/python'] + executable.split() if kwargs.get('mode', 'bash') == 'python' else ['/bin/bash', '-c', executable] - - exit_code = 0 - stdout = '' - stderr = '' - - def read_output(pipe, output_list): - for line in iter(pipe.readline, ''): - output_list.append(line) - pipe.close() - - process = None - with execute_lock: - process = subprocess.Popen(exe, - bufsize=-1, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=kwargs.get('cwd', getcwd()), - preexec_fn=os.setsid, - encoding='utf-8', - errors='replace') - if kwargs.get('returnproc', False): - return process - - stdout_lines = [] - stderr_lines = [] - - stdout_thread = threading.Thread(target=read_output, args=(process.stdout, stdout_lines)) - stderr_thread = threading.Thread(target=read_output, args=(process.stderr, stderr_lines)) - - stdout_thread.start() - stderr_thread.start() - - try: - logger.debug(f'subprocess.communicate() will use timeout {timeout} s') - process.wait(timeout=timeout) - except subprocess.TimeoutExpired as exc: - stderr += f'subprocess communicate sent TimeoutExpired: {exc}' - logger.warning(stderr) - exit_code = errors.COMMANDTIMEDOUT - stderr = kill_all(process, stderr) - except Exception as exc: - logger.warning(f'exception caused when executing command: {executable}: {exc}') - exit_code = errors.UNKNOWNEXCEPTION - stderr = kill_all(process, str(exc)) - else: - exit_code = process.poll() - - stdout_thread.join() - stderr_thread.join() - - stdout = ''.join(stdout_lines) - stderr = ''.join(stderr_lines) - - try: - if process: - process.wait(timeout=60) - except subprocess.TimeoutExpired: - if process: - logger.warning("process did not complete within the timeout of 60s - terminating") - process.terminate() - - if stdout and stdout.endswith('\n'): - stdout = stdout[:-1] - - return exit_code, stdout, stderr - - def execute_nothreads(executable: Any, **kwargs: dict) -> Any: """ Execute the command with its options in the provided executable list using subprocess time-out handler. diff --git a/pilot/util/https.py b/pilot/util/https.py index 70bcacc8..9b483e9a 100644 --- a/pilot/util/https.py +++ b/pilot/util/https.py @@ -791,6 +791,14 @@ def get_auth_token_content(auth_token: str, key: bool = False) -> str: return auth_token_content +class IPv4HTTPHandler(urllib.request.HTTPHandler): + def http_open(self, req): + return self.do_open(self._create_connection, req) + + def _create_connection(self, host, port=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None): + return socket.create_connection((host, port), timeout, source_address, family=socket.AF_INET) + + def request2(url: str = "", data: dict = None, secure: bool = True, @@ -799,6 +807,121 @@ def request2(url: str = "", """ Send a request using HTTPS (using urllib module). + :param url: the URL of the resource (str) + :param data: data to send (dict) + :param secure: use secure connection (bool) + :param compressed: compress data (bool) + :param panda: True for panda server interactions (bool) + :return: server response (str or dict). + """ + if data is None: + data = {} + + ipv = os.environ.get("PILOT_IP_VERSION") + + # https might not have been set up if running in a [middleware] container + if not _ctx.cacert: + https_setup(None, get_pilot_version()) + + # should tokens be used? + auth_token, auth_origin = get_local_oidc_token_info() + use_oidc_token = auth_token and auth_origin and panda + auth_token_content = get_auth_token_content(auth_token) if use_oidc_token else "" + if not auth_token_content and use_oidc_token: + logger.warning('OIDC_AUTH_TOKEN/PANDA_AUTH_TOKEN content could not be read') + return "" + + # get the relevant headers + headers = get_headers(use_oidc_token, auth_token_content, auth_origin) + logger.info(f'headers = {hide_token(headers.copy())}') + logger.info(f'data = {data}') + + # encode data as compressed JSON + if compressed: + rdata_out = BytesIO() + with GzipFile(fileobj=rdata_out, mode="w") as f_gzip: + f_gzip.write(json.dumps(data).encode()) + data_json = rdata_out.getvalue() + else: + data_json = json.dumps(data).encode('utf-8') + + # set up the request + req = urllib.request.Request(url, data_json, headers=headers) + + # create a context with certificate verification + ssl_context = get_ssl_context() + #ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.load_cert_chain(certfile=_ctx.cacert, keyfile=_ctx.cacert) + + if not secure: + ssl_context.verify_mode = False + ssl_context.check_hostname = False + + if ipv == 'IPv4': + logger.info("will use IPv4 in server communication") + install_ipv4_opener() + else: + logger.info("will use IPv6 in server communication") + + # ssl_context = ssl.create_default_context(capath=_ctx.capath, cafile=_ctx.cacert) + # Send the request securely + try: + logger.debug('sending data to server') + with urllib.request.urlopen(req, context=ssl_context, timeout=config.Pilot.http_maxtime) as response: + # Handle the response here + logger.info(f"response.status={response.status}, response.reason={response.reason}") + ret = response.read().decode('utf-8') + if 'getProxy' not in url: + logger.info(f"response={ret}") + logger.debug('sent request to server') + except (urllib.error.URLError, urllib.error.HTTPError, TimeoutError) as exc: + logger.warning(f'failed to send request: {exc}') + ret = "" + else: + if secure and isinstance(ret, str): + if ret == 'Succeeded': # this happens for sending modeOn (debug mode) + ret = {'StatusCode': '0'} + elif ret.startswith('{') and ret.endswith('}'): + try: + ret = json.loads(ret) + except json.JSONDecodeError as e: + logger.warning(f'failed to parse response: {e}') + else: # response="StatusCode=_some number_" + # Parse the query string into a dictionary + query_dict = parse_qs(ret) + + # Convert lists to single values + ret = {k: v[0] if len(v) == 1 else v for k, v in query_dict.items()} + + return ret + + +def install_ipv4_opener(): + """Install the IPv4 opener.""" + http_proxy = os.environ.get("http_proxy") + all_proxy = os.environ.get("all_proxy") + if http_proxy and all_proxy: + logger.info(f"using http_proxy={http_proxy}, all_proxy={all_proxy}") + proxy_handler = urllib.request.ProxyHandler({ + 'http': http_proxy, + 'https': http_proxy, + 'all': all_proxy + }) + opener = urllib.request.build_opener(proxy_handler, IPv4HTTPHandler()) + else: + logger.info("no http_proxy found, will use IPv4 without proxy") + opener = urllib.request.build_opener(IPv4HTTPHandler()) + urllib.request.install_opener(opener) + + +def request2_old(url: str = "", + data: dict = None, + secure: bool = True, + compressed: bool = True, + panda: bool = False) -> str or dict: + """ + Send a request using HTTPS (using urllib module). + :param url: the URL of the resource (str) :param data: data to send (dict) :param secure: use secure connection (bool) diff --git a/pilot/util/monitoring.py b/pilot/util/monitoring.py index 9b7e1d14..bbd64286 100644 --- a/pilot/util/monitoring.py +++ b/pilot/util/monitoring.py @@ -62,7 +62,8 @@ from pilot.util.psutils import ( is_process_running, get_pid, - get_subprocesses + get_subprocesses, + find_actual_payload_pid ) from pilot.util.timing import get_time_since from pilot.util.workernode import ( @@ -136,8 +137,8 @@ def job_monitor_tasks(job: JobData, mt: MonitoringTime, args: object) -> tuple[i if exit_code != 0: return exit_code, diagnostics - # display OOM process info (once) - display_oom_info(job.pid) + # update the OOM process info to prevent killing processes in the wrong order in case the job is killed (once) + update_oom_info(job.pid, job.transformation) # should the pilot abort the payload? exit_code, diagnostics = should_abort_payload(current_time, mt) @@ -199,28 +200,57 @@ def still_running(pid): return running -def display_oom_info(payload_pid): +def update_oom_info(bash_pid, payload_cmd): """ - Display OOM process info. + Update OOM process info. - :param payload_pid: payload pid (int). + In case the job is killed, the OOM process info should be updated to prevent killing processes in the wrong order. + It will otherwise lead to lingering processes. + + :param bash_pid: bash chain pid (int) + :param payload_cmd: payload command (string). """ - #fname = f"/proc/{payload_pid}/oom_score_adj" + # use the pid of the bash chain to get the actual payload pid which should be a child process + payload_pid = find_actual_payload_pid(bash_pid, payload_cmd) + if not payload_pid: + return + + fname = f"/proc/{payload_pid}/oom_score" + fname_adj = fname + "_adj" payload_score = get_score(payload_pid) if payload_pid else 'UNKNOWN' pilot_score = get_score(os.getpid()) + + cmd = "whoami" + _, stdout, _ = execute(cmd) + logger.debug(f"stdout = {stdout}") + cmd = f"ls -l {fname_adj}" + _, stdout, _ = execute(cmd) + logger.debug(f"stdout = {stdout}") + if isinstance(pilot_score, str) and pilot_score == 'UNKNOWN': logger.warning(f'could not get oom_score for pilot process: {pilot_score}') else: - #relative_payload_score = "1" + relative_payload_score = "1" + write_to_oom_score_adj(payload_pid, relative_payload_score) + logger.info(f'oom_score(pilot) = {pilot_score}, oom_score(payload) = {payload_score} (attempted writing relative score 1 to {fname})') + - # write the payload oom_score to the oom_score_adj file - #try: - # write_file(path=fname, contents=relative_payload_score) - #except Exception as e: # FileHandlingFailure - # logger.warning(f'could not write oom_score to file: {e}') +def write_to_oom_score_adj(pid, value): + """Writes the specified value to the oom_score_adj file for the given PID. - #logger.info(f'oom_score(pilot) = {pilot_score}, oom_score(payload) = {payload_score} (attempted writing relative score 1 to {fname})') - logger.info(f'oom_score(pilot) = {pilot_score}, oom_score(payload) = {payload_score}') + Args: + pid: The PID of the process. + value: The value to write to the oom_score_adj file. + """ + command = f"echo {value} > /proc/{pid}/oom_score_adj" + try: + subprocess.check_call(command, shell=True) + logger.info(f"successfully wrote {value} to /proc/{pid}/oom_score_adj") + except subprocess.CalledProcessError as e: + logger.warning(f"error writing to /proc/{pid}/oom_score_adj: {e}") + ec, stdout, stderr = execute(command) + logger.debug(f"ec = {ec} stdout = {stdout}\nstderr = {stderr}") + _, stdout, _ = execute(f"cat /proc/{pid}/oom_score_adj") def get_score(pid) -> str: diff --git a/pilot/util/parameters.py b/pilot/util/parameters.py index 46f47931..614992be 100644 --- a/pilot/util/parameters.py +++ b/pilot/util/parameters.py @@ -17,15 +17,18 @@ # under the License. # # Authors: -# - Paul Nilsson, paul.nilsson@cern.ch, 2017-23 +# - Paul Nilsson, paul.nilsson@cern.ch, 2017-24 # This module contains functions that are used with the get_parameters() function defined in the information module. # WARNING: IN GENERAL, NEEDS TO USE PLUG-IN MANAGER +import logging +from typing import Any + from pilot.info import infosys +from pilot.util.config import config -import logging logger = logging.getLogger(__name__) @@ -40,7 +43,6 @@ def get_maximum_input_sizes(): try: _maxinputsizes = infosys.queuedata.maxwdir # normally 14336+2000 MB except TypeError as exc: - from pilot.util.config import config _maxinputsizes = config.Pilot.maximum_input_file_sizes # MB logger.warning(f'could not convert schedconfig value for maxwdir: {exc} (will use default value instead - {_maxinputsizes})') @@ -49,25 +51,24 @@ def get_maximum_input_sizes(): try: _maxinputsizes = int(_maxinputsizes) - except Exception as exc: + except (ValueError, TypeError) as exc: _maxinputsizes = 14336 + 2000 logger.warning(f'failed to convert maxinputsizes to int: {exc} (using value: {_maxinputsizes} MB)') return _maxinputsizes -def convert_to_int(parameter, default=None): +def convert_to_int(parameter: Any, default: Any = None) -> Any: """ Try to convert a given parameter to an integer value. + The default parameter can be used to force the function to always return a given value in case the integer conversion, int(parameter), fails. - :param parameter: parameter (any type). - :param default: None by default (if set, always return an integer; the given value will be returned if - conversion to integer fails). - :return: converted integer. + :param parameter: parameter (Any) + :param default: None by default (Any) + :return: converted integer (Any). """ - try: value = int(parameter) except (ValueError, TypeError): diff --git a/pilot/util/psutils.py b/pilot/util/psutils.py index eb70f263..ba18e1b1 100644 --- a/pilot/util/psutils.py +++ b/pilot/util/psutils.py @@ -291,3 +291,33 @@ def find_process_by_jobid(jobid: int) -> int or None: return proc.pid return None + + +def find_actual_payload_pid(bash_pid: int, payload_cmd: str) -> int or None: + """ + Find the actual payload PID. + + Identify all subprocesses of the given bash PID and search for the payload command. Return its PID. + + :param bash_pid: bash PID (int) + :param payload_cmd: payload command (partial) (str) + :return: payload PID (int or None). + """ + if not _is_psutil_available: + logger.warning('find_actual_payload_pid(): psutil not available - aborting') + return None + + children = get_subprocesses(bash_pid) + if not children: + logger.warning(f'no children found for bash PID {bash_pid}') + return bash_pid + + for pid in children: + cmd = get_command_by_pid(pid) + logger.debug(f'pid={pid} cmd={cmd}') + if payload_cmd in cmd: + logger.info(f'found payload PID={pid} for bash PID={bash_pid}') + return pid + + logger.warning(f'could not find payload PID for bash PID {bash_pid}') + return None diff --git a/pilot/util/timer.py b/pilot/util/timer.py index 22e28ee9..841d953b 100644 --- a/pilot/util/timer.py +++ b/pilot/util/timer.py @@ -18,7 +18,7 @@ # # Authors: # - Alexey Anisenkov, anisyonk@cern.ch, 2018 -# - Paul Nilsson, paul.nilsson@cern.ch, 2019-23 +# - Paul Nilsson, paul.nilsson@cern.ch, 2019-24 """ Standalone implementation of time-out check on function call. @@ -43,32 +43,32 @@ from pilot.util.auxiliary import TimeoutException -class TimedThread(object): +class TimedThread: """ Thread-based Timer implementation (`threading` module) (shared memory space, GIL limitations, no way to kill thread, Windows compatible) """ - def __init__(self, timeout): + def __init__(self, _timeout): """ :param timeout: timeout value for operation in seconds. """ - self.timeout = timeout + self.timeout = _timeout self.is_timeout = False def execute(self, func, args, kwargs): try: ret = (True, func(*args, **kwargs)) - except Exception: + except (TypeError, ValueError, AttributeError, KeyError): ret = (False, sys.exc_info()) self.result = ret return ret - def run(self, func, args, kwargs, timeout=None): + def run(self, func, args, kwargs, _timeout=None): """ :raise: TimeoutException if timeout value is reached before function finished """ @@ -78,30 +78,30 @@ def run(self, func, args, kwargs, timeout=None): thread.start() - timeout = timeout if timeout is not None else self.timeout + _timeout = _timeout if _timeout is not None else self.timeout try: - thread.join(timeout) - except Exception as exc: + thread.join(_timeout) + except (RuntimeError, KeyboardInterrupt) as exc: print(f'exception caught while joining timer thread: {exc}') if thread.is_alive(): self.is_timeout = True - raise TimeoutException("Timeout reached", timeout=timeout) + raise TimeoutException("Timeout reached", timeout=_timeout) ret = self.result if ret[0]: return ret[1] - else: - try: - _r = ret[1][0](ret[1][1]).with_traceback(ret[1][2]) - except AttributeError: - exec("raise ret[1][0], ret[1][1], ret[1][2]") - raise _r + + try: + _r = ret[1][0](ret[1][1]).with_traceback(ret[1][2]) + except AttributeError: + exec("raise ret[1][0], ret[1][1], ret[1][2]") + raise _r -class TimedProcess(object): +class TimedProcess: """ Process-based Timer implementation (`multiprocessing` module). Uses shared Queue to keep result. (completely isolated memory space) @@ -110,22 +110,22 @@ class TimedProcess(object): Traceback data is printed to stderr """ - def __init__(self, timeout): + def __init__(self, _timeout): """ - :param timeout: timeout value for operation in seconds. + :param _timeout: timeout value for operation in seconds. """ - self.timeout = timeout + self.timeout = _timeout self.is_timeout = False - def run(self, func, args, kwargs, timeout=None): + def run(self, func, args, kwargs, _timeout=None): def _execute(func, args, kwargs, queue): try: ret = func(*args, **kwargs) queue.put((True, ret)) - except Exception as e: - print('Exception occurred while executing %s' % func, file=sys.stderr) + except (TypeError, ValueError, AttributeError, KeyError) as e: + print(f'exception occurred while executing {func}', file=sys.stderr) traceback.print_exc(file=sys.stderr) queue.put((False, e)) @@ -137,14 +137,14 @@ def _execute(func, args, kwargs, queue): process.daemon = True process.start() - timeout = timeout if timeout is not None else self.timeout + _timeout = _timeout if _timeout is not None else self.timeout try: - ret = queue.get(block=True, timeout=timeout) - except Empty: + ret = queue.get(block=True, timeout=_timeout) + except Empty as exc: self.is_timeout = True process.terminate() - raise TimeoutException("Timeout reached", timeout=timeout) + raise TimeoutException("Timeout reached", timeout=_timeout) from exc finally: while process.is_alive(): process.join(1) @@ -158,20 +158,20 @@ def _execute(func, args, kwargs, queue): if ret[0]: return ret[1] - else: - raise ret[1] + raise ret[1] Timer = TimedProcess -def timeout(seconds, timer=None): +def timeout(seconds: int, timer: Timer = None): """ Decorator for a function which causes it to timeout (stop execution) once passed given number of seconds. - :param timer: timer class (by default is Timer) - :raise: TimeoutException in case of timeout interrupt - """ + :param seconds: timeout value in seconds (int) + :param timer: timer class (None or Timer) + :raise: TimeoutException in case of timeout interrupt. + """ timer = timer or Timer def decorate(function):