Source code for ai2_kit.tool.batch
import os
import sys
import shutil
import shlex
from typing import Optional
from ai2_kit.core.util import ensure_dir, expand_globs, list_split
from ai2_kit.core.log import get_logger
# Module-level logger; BatchTool reports every file operation, command and generated script through it.
logger = get_logger(__name__)
[docs]
class BatchTool:
    """
    A toolkit to help generate batch scripts.

    Every public method returns ``self`` so that calls can be chained.
    """

    def run_cmd(self, *work_dirs: str, cmd: str):
        """
        Run command in each work directory.

        For each matched directory the shell command
        ``cd <work_dir> && <cmd>`` is executed with ``os.system``. A non-zero
        return status is logged as a warning; the remaining directories are
        still processed.

        :param work_dirs: path or glob of work directories
        :param cmd: command to run, use {work_dir} to represent the work directory.
            The path is substituted verbatim (it is NOT shell-quoted), so quote
            ``{work_dir}`` inside ``cmd`` if paths may contain spaces or shell
            metacharacters.
        :raises AssertionError: if a matched path is not a directory
        """
        for path in expand_globs(work_dirs):
            assert os.path.isdir(path), f'{path} is not a directory'
            _cmd = f"cd {shlex.quote(path)} && {cmd.format(work_dir=path)}"
            # Log before running so the message precedes the command's own output.
            logger.info(f'Run command: {_cmd}')
            status = os.system(_cmd)
            if status != 0:
                # On POSIX os.system returns the raw wait status, not the exit code.
                logger.warning(f'Command returned non-zero status {status}: {_cmd}')
        return self

    def map_path(self, *sources: str, target: str, copy: bool = False):
        """
        Map source files or directory to target path, use link by default.

        ``ensure_dir`` is called on every target path before it is written
        (presumably creating its parent directory). Symlinks point at the
        absolute path of the source.

        :param sources: path or glob of source files or directories
        :param target: target path, use {i} to represent the index of the source path,
            or {basename} to represent the basename of the source path
        :param copy: use copy instead of link
        """
        for i, path in enumerate(expand_globs(sources)):
            target_path = target.format(i=i, basename=os.path.basename(path))
            ensure_dir(target_path)
            if not copy:
                # Absolute source so the link stays valid wherever target lives.
                path = os.path.abspath(path)
                os.symlink(path, target_path)
                logger.info(f'Link {path} to {target_path}')
            elif os.path.isdir(path):
                shutil.copytree(path, target_path)
                logger.info(f'Copy directory {path} to {target_path}')
            else:
                shutil.copy(path, target_path)
                logger.info(f'Copy file {path} to {target_path}')
        return self

    def gen_batches(self, *work_dirs: str,
                    out: str,
                    cmd: Optional[str] = None,
                    concurrency: int = 1,
                    header_file: Optional[str] = None,
                    suppress_error: bool = False,
                    checkpoint: bool = True,
                    checkpoint_file: str = 'done.ckpt',
                    rel_path: bool = False,
                    ):
        """
        Generate batch scripts for each work directory.

        The work directories are split into `concurrency` groups and one bash
        script is written per group. Each script loops over its directories:
        it `pushd`s into the directory, runs `cmd`, touches `checkpoint_file`
        and `popd`s back.

        :param work_dirs: path or glob of work directories
        :param out: path to write batch scripts, use {i} to represent the index of concurrent job
        :param cmd: shell command to run in each work directory; if None, it is read from stdin.
            It is inserted into the script verbatim (no `{...}` substitution is performed).
            While it runs, the current directory is the work directory and the shell
            variable `$work_dir` holds its path.
        :param concurrency: number of batch scripts to generate; if <= 0, one batch script
            is generated for each work directory
        :param header_file: path to header file; its content replaces the default
            `#!/bin/bash` line at the beginning of each batch script
        :param suppress_error: if False (the default), `set -e` is added to each batch script
            so a failing command aborts it; if True, `set -e` is omitted
        :param checkpoint: if True, a work directory that already contains `checkpoint_file`
            is skipped
        :param checkpoint_file: checkpoint file name; it is touched in each work directory
            after `cmd` runs, whether or not `checkpoint` is enabled
        :param rel_path: if True, work directories are written as given (relative paths
            are resolved from the directory the script is launched in); otherwise they are
            converted to absolute paths
        :raises AssertionError: if a matched path is not a directory
        """
        _work_dirs = expand_globs(work_dirs)
        # Validate everything up front so a bad path does not leave a partially
        # written set of scripts behind.
        for work_dir in _work_dirs:
            assert os.path.isdir(work_dir), f'{work_dir} is not a directory'
        if not rel_path:
            _work_dirs = [os.path.abspath(work_dir) for work_dir in _work_dirs]
        else:
            _work_dirs = list(_work_dirs)

        # read cmd
        if cmd is None:
            cmd = sys.stdin.read()

        # read template
        header = '#!/bin/bash'
        if header_file is not None:
            with open(header_file, encoding='utf-8') as f:
                header = f.read()

        if concurrency <= 0:
            # One script per directory; never 0, which list_split cannot divide by.
            concurrency = max(len(_work_dirs), 1)

        quoted_ckpt = shlex.quote(checkpoint_file)
        for i, job_group in enumerate(list_split(_work_dirs, concurrency)):
            batch = [header]
            if not suppress_error:
                batch.append('set -e')
            batch.append('for work_dir in \\')
            batch.extend(f'  {shlex.quote(work_dir)} \\' for work_dir in job_group)
            batch.extend([
                '; do',
                # Quoted so paths containing spaces are not word-split.
                '  pushd "$work_dir"',
            ])
            if checkpoint:
                batch.extend([
                    f'  if [ -f {quoted_ckpt} ]; then',
                    '    echo "hit checkpoint, skip"',
                    # Must pop before `continue`, otherwise the next (possibly
                    # relative) pushd starts from the skipped directory.
                    '    popd',
                    '    continue',
                    '  fi',
                ])
            batch.extend([
                f'  {cmd}',
                f'  touch {quoted_ckpt}',
                '  popd',
                'done',
            ])

            # write batch script
            out_path = out.format(i=i)
            ensure_dir(out_path)
            with open(out_path, 'w', encoding='utf-8') as f:
                # Trailing newline keeps the script a well-formed text file.
                f.write('\n'.join(batch) + '\n')
            logger.info(f'Write batch script to {out_path}')
        return self

    def __str__(self) -> str:
        # suppress fire help message
        return ''