Source code for ai2_kit.core.job
from typing import List, Callable, Optional, Awaitable
from enum import Enum
from abc import abstractmethod
import time
import asyncio
from .future import IFuture
# Copied from parsl
[docs]
class JobState(bytes, Enum):
    """Enumerate the states a job can be in.

    Every member is declared as a ``(value, terminal, status_name)`` tuple.
    The member itself is a one-byte ``bytes`` object holding ``value``, and
    it exposes ``value``, ``terminal`` and ``status_name`` as attributes.
    """

    def __new__(cls, value: int, terminal: bool, status_name: str) -> "JobState":
        # The bytes payload is a single byte carrying the numeric code.
        member = bytes.__new__(cls, [value])
        member.status_name = status_name
        member.terminal = terminal
        # Publish the plain integer (not the whole tuple) as the Enum value,
        # so lookups such as ``JobState(4)`` work.
        member._value_ = value
        return member

    # Annotation-only declarations for type checkers; they create no members.
    value: int
    terminal: bool
    status_name: str

    # (value, terminal, status_name)
    UNKNOWN = (0, False, "UNKNOWN")
    PENDING = (1, False, "PENDING")
    RUNNING = (2, False, "RUNNING")
    CANCELLED = (3, True, "CANCELLED")
    COMPLETED = (4, True, "COMPLETED")
    FAILED = (5, True, "FAILED")
    TIMEOUT = (6, True, "TIMEOUT")
    HELD = (7, False, "HELD")
[docs]
class TimeoutError(RuntimeError):
    """Timeout error declared by this module as a ``RuntimeError`` subclass.

    NOTE(review): inside this module the name shadows the builtin
    ``TimeoutError``, so ``except TimeoutError`` clauses written here match
    this class and not the builtin one.
    """
[docs]
class JobFuture(IFuture[JobState]):
    """Abstract future for a job: an ``IFuture`` whose result is a ``JobState``.

    Subclasses supply the four operations declared below.
    """

    @abstractmethod
    def get_job_state(self) -> JobState:
        """Return the job's state as a ``JobState``."""

    @abstractmethod
    def cancel(self):
        """Cancel the job; the concrete behavior is defined by subclasses."""

    @abstractmethod
    def is_success(self) -> bool:
        """Return a boolean success indicator for the job.

        NOTE(review): presumably true only once the job has completed
        successfully -- confirm against the concrete implementations.
        """

    @abstractmethod
    def resubmit(self) -> 'JobFuture':
        """Resubmit the job and return the ``JobFuture`` to track from then on."""
[docs]
async def gather_jobs(jobs: List[JobFuture], timeout = float('inf'), max_tries: int = 1, raise_error=True) -> List[JobState]:
    """Wait for all ``jobs`` concurrently, resubmitting the ones that do not complete.

    Each job is awaited through ``result_async(timeout)``. A job whose result
    is not ``JobState.COMPLETED`` (or whose wait raises ``TimeoutError``) is
    resubmitted via ``resubmit()`` and awaited again, until ``max_tries``
    attempts have been made. At least one attempt is always made.

    :param jobs: the job futures to wait for.
    :param timeout: value forwarded to ``result_async`` on every attempt.
    :param max_tries: maximum number of attempts per job.
    :param raise_error: when true, a job that never completes raises
        ``RuntimeError``; when false, its last observed state is returned.
    :return: one ``JobState`` per job, in the same order as ``jobs``.
    :raises RuntimeError: if ``raise_error`` is true and a job is still not
        completed after its final attempt.
    """

    async def attempt_once(job: JobFuture) -> JobState:
        # A timed-out wait is reported as a state so the caller can treat it
        # like any other unsuccessful outcome. Other exceptions propagate.
        try:
            return await job.result_async(timeout)
        except TimeoutError:
            return JobState.TIMEOUT

    async def wait_with_retries(job: JobFuture) -> JobState:
        attempts = 0
        while True:
            outcome = await attempt_once(job)
            if outcome is JobState.COMPLETED:
                return outcome
            attempts += 1
            if attempts >= max_tries:
                break
            # Keep tracking the future returned by the resubmission.
            job = job.resubmit()

        if not raise_error:
            return outcome
        raise RuntimeError(f'Job {job} failed with state {outcome}')

    pending = [wait_with_retries(job) for job in jobs]
    return await asyncio.gather(*pending)