Source code for ai2_kit.core.executor

from typing import Optional, Dict, List, TypeVar, Callable, Mapping, Union
from abc import ABC, abstractmethod
from invoke import Result
import cloudpickle as cp
import tempfile
import tarfile
import os
import shlex
import base64
import bz2

from .queue_system import QueueSystemConfig, BaseQueueSystem, Slurm, Lsf, PBS
from .job import JobFuture
from .artifact import Artifact
from .connector import SshConfig, BaseConnector, SshConnector, LocalConnector
from .util import s_uuid
from .log import get_logger
from .pydantic import BaseModel


logger = get_logger(__name__)

class BaseExecutorConfig(BaseModel):
    ssh: Optional[SshConfig] = None
    queue_system: QueueSystemConfig
    work_dir: str
    python_cmd: str = 'python'

ExecutorMap = Mapping[str, BaseExecutorConfig]
FnType = TypeVar('FnType', bound=Callable)
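
# Illustrative sketch (not part of the original module): a config dict as it
# might be parsed into BaseExecutorConfig. The queue_system payload is defined
# in .queue_system; the empty slurm dict below assumes that sub-config has no
# required fields, and all values here are assumptions for illustration only.
def _example_executor_config() -> BaseExecutorConfig:
    return BaseExecutorConfig.parse_obj({
        'ssh': None,                    # None makes HpcExecutor.from_config pick LocalConnector
        'queue_system': {'slurm': {}},  # exactly one of slurm / lsf / pbs should be set
        'work_dir': 'work/ai2-kit',     # relative paths resolve against $HOME (see HpcExecutor.init)
        'python_cmd': 'python3',
    })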

class Executor(ABC):
    name: str
    work_dir: str
    tmp_dir: str
    python_cmd: str

    def init(self):
        ...

    @abstractmethod
    def mkdir(self, path: str):
        ...

    @abstractmethod
    def run_python_script(self, script: str, python_cmd=None):
        ...

    @abstractmethod
    def run_python_fn(self, fn: FnType, python_cmd=None, cwd=None) -> FnType:
        ...

    @abstractmethod
    def dump_text(self, text: str, path: str):
        ...

    @abstractmethod
    def load_text(self, path: str) -> str:
        ...

    @abstractmethod
    def glob(self, pattern: str) -> List[str]:
        ...

    @abstractmethod
    def run(self, script: str, **kwargs) -> Result:
        ...

    @abstractmethod
    def submit(self, script: str, **kwargs) -> JobFuture:
        ...

    @abstractmethod
    def upload(self, from_path: str, to_dir: str) -> str:
        ...

    @abstractmethod
    def download(self, from_path: str, to_dir: str) -> str:
        ...

    @abstractmethod
    def resolve_artifact(self, artifact: Artifact) -> List[str]:
        ...

    def setup_workspace(self, workspace_dir: str, dirs: List[str]):
        paths = [os.path.join(workspace_dir, dir) for dir in dirs]
        for path in paths:
            self.mkdir(path)
            logger.info('create path: %s', path)
        return paths

class HpcExecutor(Executor):

    @classmethod
    def from_config(cls, config: Union[dict, BaseExecutorConfig], name: str = ''):
        if isinstance(config, dict):
            config = BaseExecutorConfig.parse_obj(config)
        if config.ssh:
            connector = SshConnector.from_config(config.ssh)
        else:
            connector = LocalConnector()
        queue_system = None
        if config.queue_system.slurm:
            queue_system = Slurm()
            queue_system.config = config.queue_system.slurm
        elif config.queue_system.lsf:
            queue_system = Lsf()
            queue_system.config = config.queue_system.lsf
        elif config.queue_system.pbs:
            queue_system = PBS()
            queue_system.config = config.queue_system.pbs
        if queue_system is None:
            raise ValueError('Queue system config is missing!')
        queue_system.connector = connector
        return cls(connector, queue_system, config.work_dir, config.python_cmd, name)

    @property
    def is_local(self):
        return isinstance(self.connector, LocalConnector)

    def __init__(self, connector: BaseConnector, queue_system: BaseQueueSystem,
                 work_dir: str, python_cmd: str, name: str):
        self.name = name
        self.connector = connector
        self.queue_system = queue_system
        self.work_dir = work_dir
        self.python_cmd = python_cmd
        self.tmp_dir = os.path.join(self.work_dir, '.tmp')
        self.python_pkgs_dir = os.path.join(self.tmp_dir, 'python_pkgs')

    def init(self):
        # if work_dir is a relative path, resolve it against the user's home directory
        if not os.path.isabs(self.work_dir):
            user_home = self.run('echo $HOME', hide=True).stdout.strip()
            self.work_dir = os.path.normpath(os.path.join(user_home, self.work_dir))
        self.mkdir(self.work_dir)
        self.mkdir(self.tmp_dir)
        self.mkdir(self.python_pkgs_dir)
        self.upload_python_pkg('ai2_kit')

    def upload_python_pkg(self, pkg: str):
        """Upload a python package to the remote server."""
        pkg_path = os.path.dirname(__import__(pkg).__file__)
        with tempfile.NamedTemporaryFile(suffix='.tar.gz') as fp:
            with tarfile.open(fp.name, 'w:gz') as tar_fp:
                tar_fp.add(pkg_path, arcname=os.path.basename(pkg_path), filter=_filter_pyc_files)
            fp.flush()
            self.upload(fp.name, self.python_pkgs_dir)
            file_name = os.path.basename(fp.name)
            self.run(f'cd {shlex.quote(self.python_pkgs_dir)} && tar -xf {shlex.quote(file_name)}')
        logger.info('add python package: %s', pkg_path)

    def mkdir(self, path: str):
        return self.connector.run('mkdir -p {}'.format(shlex.quote(path)))

    def dump_text(self, text: str, path: str):
        return self.connector.dump_text(text, path)

    # TODO: handle error properly
    def load_text(self, path: str) -> str:
        return self.connector.run('cat {}'.format(shlex.quote(path)), hide=True).stdout

    def glob(self, pattern: str):
        return self.connector.glob(pattern)

    def run(self, script: str, **kwargs):
        return self.connector.run(script, **kwargs)

    def run_python_script(self, script: str, python_cmd=None, cwd=None):
        if python_cmd is None:
            python_cmd = self.python_cmd
        if cwd is None:
            cwd = self.work_dir
        base_cmd = f'cd {shlex.quote(cwd)} && PYTHONPATH={shlex.quote(self.python_pkgs_dir)} '
        script_len = len(script)
        logger.info('the size of the generated python script is %s', script_len)
        if script_len < 100_000:
            # the ssh connection will be closed if the command is too large,
            # so only small scripts are passed inline via `python -c`
            return self.connector.run(f'{base_cmd} {python_cmd} -c {shlex.quote(script)}', hide=True)
        else:
            # larger scripts are uploaded as a temporary file and executed from there
            script_path = os.path.join(self.tmp_dir, f'run_python_script_{s_uuid()}.py')
            self.dump_text(script, script_path)
            ret = self.connector.run(f'{base_cmd} {python_cmd} {shlex.quote(script_path)}', hide=True)
            self.connector.run(f'rm {shlex.quote(script_path)}')
            return ret

    def run_python_fn(self, fn: FnType, python_cmd=None, cwd=None) -> FnType:
        def remote_fn(*args, **kwargs):
            script = fn_to_script(fn, args, kwargs, delimiter='@')
            ret = self.run_python_script(script=script, python_cmd=python_cmd, cwd=cwd)
            # the result is printed after the delimiter as a bz2-compressed, base64-encoded pickle
            _, r = ret.stdout.rsplit('@', 1)
            return cp.loads(bz2.decompress(base64.b64decode(r)))
        return remote_fn  # type: ignore

    def submit(self, script: str, cwd: str, **kwargs):
        return self.queue_system.submit(script, cwd=cwd, **kwargs)

    def resolve_artifact(self, artifact: Artifact) -> List[str]:
        if artifact.includes is None:
            return [artifact.url]
        pattern = os.path.join(artifact.url, artifact.includes)
        return self.glob(pattern)

    def upload(self, from_path: str, to_dir: str):
        return self.connector.upload(from_path, to_dir)

    def download(self, from_path: str, to_dir: str):
        return self.connector.download(from_path, to_dir)
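
# Illustrative usage sketch (not part of the original module), assuming a
# valid BaseExecutorConfig: build an executor, initialise the remote
# workspace, then run a plain python function on the remote side.
def _example_hpc_usage(config: BaseExecutorConfig):
    executor = HpcExecutor.from_config(config, name='hpc-1')
    executor.init()  # resolves work_dir, creates tmp dirs, uploads ai2_kit
    [data_dir] = executor.setup_workspace(executor.work_dir, ['data'])
    remote_fn = executor.run_python_fn(lambda: __import__('socket').gethostname())
    return data_dir, remote_fn()  # remote_fn() executes on the remote host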

def create_executor(config: BaseExecutorConfig, name: str) -> Executor:
    if config.queue_system is not None:
        return HpcExecutor.from_config(config, name)
    raise RuntimeError('The executor configuration is not supported!')

class ExecutorManager:
    def __init__(self, executor_configs: Mapping[str, BaseExecutorConfig]):
        self._executor_configs = executor_configs
        self._executors: Dict[str, Executor] = dict()

    def get_executor(self, name: str):
        config = self._executor_configs.get(name)
        if config is None:
            raise ValueError('Executor with name {} is not found!'.format(name))
        if name not in self._executors:
            executor = create_executor(config, name)
            self._executors[name] = executor
        return self._executors[name]
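
# Sketch (not part of the original module): ExecutorManager caches one
# Executor instance per named config, so repeated lookups return the same
# object. The name 'hpc-1' is an illustrative assumption.
def _example_manager(configs: Mapping[str, BaseExecutorConfig]):
    manager = ExecutorManager(configs)
    executor = manager.get_executor('hpc-1')          # created on first access
    assert manager.get_executor('hpc-1') is executor  # same instance afterwards
    return executor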

def fn_to_script(fn: Callable, args, kwargs, delimiter='@'):
    script = [
        'import base64,bz2,sys,cloudpickle as cp',
        f'fn,args,kwargs={pickle_converts((fn, args, kwargs))}',
        'r=fn(*args, **kwargs)',
        'sys.stdout.flush()',  # ensure all output is printed
        f'''print({repr(delimiter)}+base64.b64encode(bz2.compress(cp.dumps(r, protocol=cp.DEFAULT_PROTOCOL),5)).decode('ascii'))''',
        'sys.stdout.flush()',  # ensure all output is printed
    ]
    return ';'.join(script)
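
# Minimal local demo (not part of the original module) of what run_python_fn
# does remotely: execute the generated script in a subprocess and decode the
# payload printed after the delimiter. Assumes cloudpickle is importable by
# the same interpreter.
def _fn_to_script_local_demo():
    import subprocess
    import sys
    script = fn_to_script(lambda x: x * 2, (21,), {}, delimiter='@')
    out = subprocess.run([sys.executable, '-c', script],
                         capture_output=True, text=True, check=True).stdout
    _, payload = out.rsplit('@', 1)
    return cp.loads(bz2.decompress(base64.b64decode(payload)))  # -> 42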

def pickle_converts(obj, pickle_module='cp', bz2_module='bz2', base64_module='base64'):
    """Convert an object to a python expression that rebuilds it from its pickled form."""
    obj_pkl = cp.dumps(obj, protocol=cp.DEFAULT_PROTOCOL)
    compress_level = 5 if len(obj_pkl) > 4096 else 1  # spend more compression effort only on large payloads
    compressed = bz2.compress(obj_pkl, compress_level)
    obj_b64 = base64.b64encode(compressed).decode('ascii')
    return f'{pickle_module}.loads({bz2_module}.decompress({base64_module}.b64decode({repr(obj_b64)})))'
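
# Round-trip sketch (not part of the original module): the expression returned
# by pickle_converts evaluates back to the original object when the referenced
# module names are bound in the eval namespace.
def _pickle_converts_roundtrip_demo():
    obj = {'answer': 42, 'items': [1, 2, 3]}
    expr = pickle_converts(obj)
    restored = eval(expr, {'cp': cp, 'bz2': bz2, 'base64': base64})
    assert restored == obj
    return restored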

def _filter_pyc_files(tarinfo):
    # exclude compiled python files and caches when packing a package
    if tarinfo.name.endswith('.pyc') or tarinfo.name.endswith('__pycache__'):
        return None
    return tarinfo