Source code for ai2_kit.tool.hpc

from ai2_kit.core.util import expand_globs
from ai2_kit.core.log import get_logger
from ai2_kit.core.job import JobState
from ai2_kit.core.cmd import CmdGroup

import datetime
import shlex
import time
import os
import re


logger = get_logger(__name__)


[docs] def append_if_not_exist(fp, line: str, feat_str = None): if feat_str is None: feat_str = line for line in fp: if feat_str in line: return fp.seek(0, os.SEEK_END) fp.write(f'\n{line}\n')
[docs] class Slurm: _state_table = { 'PD': JobState.PENDING, 'R': JobState.RUNNING, 'CA': JobState.CANCELLED, 'CF': JobState.PENDING, # (configuring), 'CG': JobState.RUNNING, # (completing), 'CD': JobState.COMPLETED, 'F': JobState.FAILED, # (failed), 'TO': JobState.TIMEOUT, # (timeout), 'NF': JobState.FAILED, # (node failure), 'RV': JobState.FAILED, # (revoked) and 'SE': JobState.FAILED # (special exit state) } def __init__(self, sbatch_bin = 'sbatch', squeue_bin = 'squeue', scancel_bin='scancel') -> None: self._sbatch_bin = sbatch_bin self._squeue_bin = squeue_bin self._scancel_bin = scancel_bin self._job_states = {}
[docs] def submit(self, *path_or_glob: str): """ Submit multiple Slurm script at once. :param path_or_glob: path or glob of Slurm script """ files = expand_globs(path_or_glob, raise_invalid=True) if not files: raise ValueError('No files found') for file in files: with open(file, '+a') as fp: append_if_not_exist(fp, 'touch slurm_$SLURM_JOB_ID.done # AUTO GENERATED') try: for file in files: job_id = self._submit(file) logger.info(f'Submitted job {job_id} for {file}') self._job_states[job_id] = JobState.UNKNOWN except: self._cancel() raise return self
[docs] def wait(self, timeout=3600 * 24 * 7, ignore_error = False, fast_fail = False, interval = 10): """ Wait until all jobs are finished. :param timeout: timeout in seconds :param ignore_error: ignore error :param fast_fail: exit if any job failed :param interval: interval in seconds """ fail_cnt = 0 start_at = datetime.datetime.now() while (datetime.datetime.now() - start_at).total_seconds() < timeout: try: self._update_job_states() logger.info('Job states: %s', self._job_states) fail_cnt = 0 except Exception: fail_cnt += 1 logger.exception('Failed to update job states') if fail_cnt > 5: # stop if keep failing raise if all(state.terminal for state in self._job_states.values()): break if fast_fail and self._is_any_failed(): self._cancel() raise RuntimeError('Fast fail!') time.sleep(interval) else: logger.error('Timeout') if not ignore_error: raise RuntimeError('Timeout') if not ignore_error and self._is_any_failed(): raise RuntimeError('Some jobs failed')
def _is_any_failed(self): return any(state == JobState.FAILED for state in self._job_states.values()) def _update_job_states(self): query_cmd = f"{self._squeue_bin} --noheader --format='%i %t' -u $USER" fp = os.popen(query_cmd) out = fp.read() exit_code = fp.close() if exit_code is not None: raise RuntimeError(f'Failed to query job states: {exit_code}') state = {} for line in out.splitlines(): if line: job_id, slurm_state = line.split() state[job_id] = slurm_state for job_id in self._job_states: if job_id in state: self._job_states[job_id] = self._state_table[state[job_id]] else: if os.path.exists(f'slurm_{job_id}.done'): self._job_states[job_id] = JobState.COMPLETED else: self._job_states[job_id] = JobState.FAILED def _cancel(self): for job_id in self._job_states: os.system(f'{self._scancel_bin} {job_id}') def _submit(self, file: str): submit_cmd = f'{self._sbatch_bin} {shlex.quote(file)}' with os.popen(submit_cmd) as fp: stdout = fp.read() m = re.search(r'Submitted batch job (\d+)', stdout) if m is None: raise ValueError(f'Failed to submit job: {stdout}') job_id = m.group(1) return job_id
cmd_entry = CmdGroup(items={ 'slurm': Slurm, }, doc='Tools to facilitate HPC related tasks.')