from typing import Optional, Dict
from abc import ABC, abstractmethod
from collections import defaultdict
import invoke
import shlex
import os
import re
import time
import asyncio
import json
from .connector import BaseConnector
from .log import get_logger
from .job import JobFuture, JobState
from .util import short_hash
from .pydantic import BaseModel
logger = get_logger(__name__)
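
# A minimal usage sketch (hedged): `my_connector` stands for any existing
# BaseConnector implementation and is not defined in this module.
#
#   queue = Slurm()
#   queue.config = QueueSystemConfig.Slurm()
#   queue.connector = my_connector
#   job = queue.submit('#!/bin/bash\nsleep 10\n', cwd='/tmp/jobs')
#   print(job.result(timeout=3600))  # block until the job reaches a terminal state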
class QueueSystemConfig(BaseModel):
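    """Per-scheduler configuration; leave the schedulers you do not use as None."""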
class Slurm(BaseModel):
sbatch_bin: str = 'sbatch'
squeue_bin: str = 'squeue'
scancel_bin: str = 'scancel'
polling_interval: int = 10
class LSF(BaseModel):
bsub_bin: str = 'bsub'
bjobs_bin: str = 'bjobs'
polling_interval: int = 10
class PBS(BaseModel):
qsub_bin: str = 'qsub'
qstat_bin: str = 'qstat'
qdel_bin: str = 'qdel'
slurm: Optional[Slurm] = None
lsf: Optional[LSF] = None
pbs: Optional[PBS] = None
class BaseQueueSystem(ABC):
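    """
    Abstract base of a queue system (Slurm, LSF, PBS, ...).

    Subclasses supply the submit command, job id parsing and state polling;
    `submit` implements the shared, idempotent submission workflow.
    """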
connector: BaseConnector
def get_polling_interval(self) -> int:
return 10
def get_setup_script(self) -> str:
return ''
@abstractmethod
def get_script_suffix(self) -> str:
...
@abstractmethod
def get_submit_cmd(self) -> str:
...
@abstractmethod
def get_job_id_pattern(self) -> str:
...
@abstractmethod
def get_job_id_envvar(self) -> str:
...
@abstractmethod
def get_job_state(self, job_id: str, success_indicator_path: str) -> JobState:
...
@abstractmethod
def cancel(self, job_id: str):
...
def _post_submit(self, job: 'QueueJobFuture'):
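        """Hook called right after a job is submitted; subclasses may invalidate caches here."""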
...
def submit(self,
script: str,
cwd: str,
name: Optional[str] = None,
success_indicator: Optional[str] = None,
):
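        """
        Submit a script to the queue system and return a QueueJobFuture.

        Submission is idempotent: the script name is derived from a hash of
        its content, and a `.running` indicator file is used to recover the
        job id of a previous submission instead of submitting it twice.
        """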
        # use a content hash instead of a uuid so that resubmitting the same script is idempotent
if name is None:
name = 'job-' + short_hash(script) + self.get_script_suffix()
quoted_cwd = shlex.quote(cwd)
        # a placeholder file that will be created when the script ends without error
if success_indicator is None:
success_indicator = name + '.success'
running_indicator = name + '.running'
inject_cmds = '\n'.join([
self.get_setup_script(),
'',
])
script = inject_cmd_to_script(script, inject_cmds)
        # append a command that writes the job id to the success indicator on normal exit
script = '\n'.join([
script,
'',
f'echo ${self.get_job_id_envvar()} > {shlex.quote(success_indicator)}',
'',
])
script_path = os.path.join(cwd, name)
self.connector.run(f'mkdir -p {quoted_cwd}')
self.connector.dump_text(script, script_path)
# submit script
cmd = f"cd {quoted_cwd} && {self.get_submit_cmd()} {shlex.quote(name)}"
        # indirection so that a checkpoint mechanism can wrap the submit command
submit_cmd_fn = self._submit_cmd
        # recover the job id of a previous submission, if any
        # TODO: refactor the following recovery logic into a standalone function
job_id, job_state = None, JobState.UNKNOWN
recover_cmd = f"cd {quoted_cwd} && cat {shlex.quote(running_indicator)}"
try:
job_id = self.connector.run(recover_cmd, hide=True).stdout.strip()
if job_id:
success_indicator_path = os.path.join(cwd, success_indicator)
job_state = self.get_job_state(job_id, success_indicator_path=success_indicator_path)
        except Exception:
            # no running indicator yet, or the state query failed:
            # fall through and submit as a new job
            pass
if job_id and job_state in (JobState.PENDING, JobState.RUNNING, JobState.COMPLETED):
logger.info(f"{script_path} has been submmited ({job_id}) and in {str(job_state)} state, continue!")
else:
logger.info(f'Submit batch script: {script_path}')
job_id = submit_cmd_fn(cmd)
# create running indicator
self.connector.dump_text(str(job_id), os.path.join(cwd, running_indicator))
job = QueueJobFuture(self,
job_id=job_id,
name=name,
script=script,
cwd=cwd,
success_indicator=success_indicator,
polling_interval=self.get_polling_interval() // 2,
)
self._post_submit(job)
return job
def _submit_cmd(self, cmd: str):
result = self.connector.run(cmd)
m = re.search(self.get_job_id_pattern(), result.stdout)
if m is None:
raise RuntimeError("Unable to parse job id")
return m.group(1)
class Slurm(BaseQueueSystem):
config: QueueSystemConfig.Slurm
_last_states = defaultdict(lambda: JobState.UNKNOWN)
_last_update_at: float = 0
translate_table = {
'PD': JobState.PENDING,
'R': JobState.RUNNING,
'CA': JobState.CANCELLED,
        'CF': JobState.PENDING,    # configuring
        'CG': JobState.RUNNING,    # completing
        'CD': JobState.COMPLETED,
        'F': JobState.FAILED,      # failed
        'TO': JobState.TIMEOUT,    # timeout
        'NF': JobState.FAILED,     # node failure
        'RV': JobState.FAILED,     # revoked
        'SE': JobState.FAILED      # special exit state
}
def get_polling_interval(self):
return self.config.polling_interval
def get_script_suffix(self):
return '.sbatch'
def get_submit_cmd(self):
return self.config.sbatch_bin
def get_job_id_pattern(self):
# example: Submitted batch job 123
return r"Submitted batch job\s+(\d+)"
def get_job_id_envvar(self) -> str:
return 'SLURM_JOB_ID'
def get_job_state(self, job_id: str, success_indicator_path: str) -> JobState:
state = self._get_all_states().get(job_id)
if state is None:
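            # the job is no longer visible to squeue: decide COMPLETED vs
            # FAILED by whether the success indicator file was created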
cmd = 'test -f {}'.format(shlex.quote(success_indicator_path))
ret = self.connector.run(cmd, warn=True)
if ret.return_code:
return JobState.FAILED
else:
return JobState.COMPLETED
else:
return state
def cancel(self, job_id: str):
cmd = f'{self.config.scancel_bin} {job_id}'
self.connector.run(cmd)
def _post_submit(self, job: 'QueueJobFuture'):
        self._last_update_at = 0  # force a state refresh on the next poll
def _translate_state(self, slurm_state: str) -> JobState:
return self.translate_table.get(slurm_state, JobState.UNKNOWN)
def _get_all_states(self) -> Dict[str, JobState]:
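        # poll squeue at most once per polling interval and serve the cached
        # result in between, so that many futures share a single query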
current_ts = time.time()
if (current_ts - self._last_update_at) < self.get_polling_interval():
return self._last_states
# call squeue to get all states
cmd = f"{self.config.squeue_bin} --noheader --format='%i %t' -u $USER"
try:
r = self.connector.run(cmd, hide=True)
except invoke.exceptions.UnexpectedExit as e:
logger.warning(f'Error when calling squeue: {e}')
return self._last_states
states: Dict[str, JobState] = dict()
for line in r.stdout.splitlines():
if not line: # skip empty line
continue
job_id, slurm_state = line.split()
state = self._translate_state(slurm_state)
states[job_id] = state
# update cache
self._last_update_at = current_ts
self._last_states = states
return states
class Lsf(BaseQueueSystem):
config: QueueSystemConfig.LSF
def get_polling_interval(self):
return self.config.polling_interval
def get_script_suffix(self):
return '.lsf'
def get_submit_cmd(self):
return self.config.bsub_bin + ' <'
def get_job_id_pattern(self):
# example: Job <123> is submitted to queue <small>.
return r"Job <(\d+)> is submitted to queue"
def get_job_id_envvar(self) -> str:
return 'LSB_JOBID'
    # TODO: implement job state query (e.g. via bjobs)
def get_job_state(self, job_id: str, success_indicator_path: str) -> JobState:
return JobState.UNKNOWN
    # TODO: implement job cancellation
def cancel(self, job_id: str):
...
def _get_all_states(self) -> Dict[str, JobState]:
...
class PBS(BaseQueueSystem):
config: QueueSystemConfig.PBS
translate_table = {
'B': JobState.RUNNING, # This state is returned for running array jobs
'R': JobState.RUNNING,
'C': JobState.COMPLETED, # Completed after having run
'E': JobState.COMPLETED, # Exiting after having run
'H': JobState.HELD, # Held
'Q': JobState.PENDING, # Queued, and eligible to run
        'W': JobState.PENDING,  # job is waiting for its execution time (-a option) to be reached
'S': JobState.HELD # Suspended
}
_last_states = defaultdict(lambda: JobState.UNKNOWN)
_last_update_at: float = 0
def get_setup_script(self) -> str:
return 'cd $PBS_O_WORKDIR'
def get_script_suffix(self) -> str:
return '.pbs'
def get_submit_cmd(self) -> str:
return self.config.qsub_bin
def get_job_id_pattern(self) -> str:
return r"(.+)"
def get_job_id_envvar(self) -> str:
return 'PBS_JOBID'
def cancel(self, job_id: str):
cmd = f'{self.config.qdel_bin} {job_id}'
self.connector.run(cmd)
def _post_submit(self, job: 'QueueJobFuture'):
        self._last_update_at = 0  # force a state refresh on the next poll
def get_job_state(self, job_id: str, success_indicator_path: str) -> JobState:
state = self._get_all_states().get(job_id)
if state is None:
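            # mirror the Slurm fallback: qstat forgets finished jobs, so
            # check the success indicator file to classify the outcome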
cmd = 'test -f {}'.format(shlex.quote(success_indicator_path))
ret = self.connector.run(cmd, warn=True)
if ret.return_code:
return JobState.FAILED
else:
return JobState.COMPLETED
else:
return state
def _get_all_states(self) -> Dict[str, JobState]:
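        # same caching strategy as Slurm: batch all state queries into a
        # single qstat call per polling interval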
current_ts = time.time()
if (current_ts - self._last_update_at) < self.get_polling_interval():
return self._last_states
cmd = f"{self.config.qstat_bin} -f -F json"
try:
r = self.connector.run(cmd, hide=True)
except invoke.exceptions.UnexpectedExit as e:
logger.warning(f'Error when calling qstat: {e}')
return self._last_states
states: Dict[str, JobState] = dict()
qstat_json = json.loads(r.stdout)
for job_id, job in qstat_json.get('Jobs', dict()).items():
states[job_id] = self._translate_state(job['job_state'])
self._last_states = states
self._last_update_at = current_ts
return states
    def _translate_state(self, pbs_state: str) -> JobState:
        return self.translate_table.get(pbs_state, JobState.UNKNOWN)
class QueueJobFuture(JobFuture):
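    """
    Future-like handle of a job submitted to a queue system.

    The terminal state is cached once it is reached, so later queries do
    not hit the queue system after the job has finished.
    """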
def __init__(self,
queue_system: BaseQueueSystem,
job_id: str,
script: str,
cwd: str,
name: str,
success_indicator: str,
polling_interval=10,
):
self._queue_system = queue_system
self._name = name
self._script = script
self._cwd = cwd
self._job_id = job_id
self._success_indicator = success_indicator
self._polling_interval = polling_interval
self._final_state = None
@property
def success_indicator_path(self):
return os.path.join(self._cwd, self._success_indicator)
def get_job_state(self):
if self._final_state is not None:
return self._final_state
state = self._queue_system.get_job_state(
self._job_id, self.success_indicator_path)
if state.terminal:
self._final_state = state
return state
def resubmit(self):
if not self.done():
raise RuntimeError('Cannot resubmit an unfinished job!')
logger.info(f'Resubmit job: {self._job_id}')
return self._queue_system.submit(
script=self._script,
cwd=self._cwd,
name=self._name,
success_indicator=self._success_indicator,
)
def is_success(self):
return self.get_job_state() is JobState.COMPLETED
def cancel(self):
self._queue_system.cancel(self._job_id)
def done(self):
return self.get_job_state().terminal
def result(self, timeout: float = float('inf')) -> JobState:
return asyncio.run(self.result_async(timeout))
async def result_async(self, timeout: float = float('inf')) -> JobState:
        '''
        Wait for the job to reach a terminal state.

        Not fully async (submission and state polling still block), but in
        practice good enough to handle thousands of concurrent jobs.
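
        A hedged usage note: to wait on many jobs concurrently, use
        `asyncio.gather(*(job.result_async() for job in jobs))`.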
'''
timeout_ts = time.time() + timeout
while time.time() < timeout_ts:
if self.done():
return self.get_job_state()
else:
await asyncio.sleep(self._polling_interval)
        else:
            # the while condition expired before the job finished
            raise RuntimeError(f'Timed out while polling job: {self._job_id}')
def __repr__(self):
return repr(dict(
name=self._name,
cwd=self._cwd,
job_id=self._job_id,
success_indicator=self._success_indicator,
polling_interval=self._polling_interval,
state=self.get_job_state(),
))
def inject_cmd_to_script(script: str, cmd: str) -> str:
"""
Find the position of first none comment or empty lines,
and inject command before it
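
    A doctest-style sketch of the behavior:

    >>> script = "\n".join(["#!/bin/bash", "#SBATCH -N 1", "echo hi"])
    >>> print(inject_cmd_to_script(script, "cd /tmp"))
    #!/bin/bash
    #SBATCH -N 1
    cd /tmp
    echo hi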
"""
lines = script.splitlines()
i = 0
for i, line in enumerate(lines):
line = line.strip()
if line and not line.startswith('#'):
break
lines.insert(i, cmd)
return '\n'.join(lines)