from ai2_kit.core.script import BashTemplate, BashStep, BashScript
from ai2_kit.core.artifact import Artifact
from ai2_kit.core.job import gather_jobs
from ai2_kit.core.util import list_split, dump_text
from ai2_kit.core.pydantic import BaseModel
from ai2_kit.core.connector import get_ln_cmd
from typing import List, Optional, Mapping, Any
from dataclasses import dataclass
from string import Template
import os
import itertools
import random
import ase.io
from .data import artifacts_to_ase_atoms_v2, DataFormat
from .iface import BaseCllContext, ICllExploreOutput
from .cp2k import dump_coord_n_cell
from .lammps import _get_dp_models_variables
class AnywareContextConfig(BaseModel):
    script_template: BashTemplate
    concurrency: int = 5

class AnywareConfig(BaseModel):
    system_files: List[str]
    """
    Artifact keys of the systems to explore.
    """

    template_files: Mapping[str, str]
    """
    Template files that will be generated for each explore task.
    You can use $$VAR_NAME to reference the variables defined in product_vars and broadcast_vars.
    Besides, the following built-in variables are also available:
    - SYSTEM_FILE: the path of the system file
    - DP_MODELS: the path of the deep potential models, in the format of '1.pb 2.pb 3.pb 4.pb'

    You can use a literal string to define a template file,
    or use !load_text to load its content from a file.
    For example, the following defines two template files, one inline and one loaded from disk:

    cp2k-warmup.inp: |
        &GLOBAL
        ...
        &END GLOBAL
    cp2k.inp: !load_text cp2k.inp
    """

    product_vars: Mapping[str, List[str]] = {}
    """
    Define template variables by Cartesian product.
    A variable can be referenced in the template files with the following format:

    $$VAR_NAME

    If there are too many variables, the Cartesian product will generate a large number of tasks;
    in that case, you can use broadcast_vars to reduce the number of tasks.
    """

    broadcast_vars: Mapping[str, List[str]] = {}
    """
    Define template variables by broadcast (broadcast as in numpy).
    It's the same as product_vars, except that the values are broadcast across the
    generated tasks instead of expanding the product: task i takes value i % len(values).
    """

    system_file_name: str
    """
    The name of the system file to generate,
    for example 'system.xyz', 'coord_n_cell.inc', etc.
    """

    system_file_format: str
    """
    The format of the system file to generate,
    for example `lammps-data`, `cp2k-inc`, etc.
    For all formats supported by ase.io, refer to
    https://wiki.fysik.dtu.dk/ase/ase/io/io.html

    Custom formats:
    - cp2k-inc: coord & cell in the format of a CP2K include file,
      which can be used in a CP2K input file via `@include coord_n_cell.inc`
    """

    submit_script: str
    """
    A bash script that will be executed in each task directory to submit the task.
    For example:

    mpirun cp2k.popt -i cp2k.inp &> cp2k.out
    """

    post_process_fn: Optional[str] = None
    """
    A Python function that will be executed after the tasks are finished.
    You may use this function to post-process the results.
    The function must be named `post_process_fn` and accept a list of task directories as input.
    The below is an example of merging multiple files into one by keeping only the last line of each file.

    post_process_fn: |
        def post_process_fn(task_dirs):
            import os
            import glob
            for task_dir in task_dirs:
                files = glob.glob(os.path.join(task_dir, '*.out'))  # files to merge
                with open(os.path.join(task_dir, 'merged.out'), 'w') as fp:
                    for file in files:
                        with open(file, 'r') as f:
                            lines = f.readlines()
                            if len(lines) > 0:
                                fp.write(lines[-1])
    """

    delimiter: str = '$$'
    """
    Delimiter used for template variables, '$$' by default.
    """

    shuffle: bool = False
    """
    Shuffle the combinations of system_files, product_vars and broadcast_vars.
    """
@dataclass
class AnywareContext(BaseCllContext):
    config: AnywareContextConfig

@dataclass
class AnywareOutput(ICllExploreOutput):
    output_dirs: List[Artifact]
    def get_model_devi_dataset(self) -> List[Artifact]:
        return self.output_dirs
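
# NOTE: the AnywareInput dataclass is referenced by anyware() below but its
# definition was missing from this section; the fields here are reconstructed
# from how `input` is used in anyware() and may differ from the upstream source.
@dataclass
class AnywareInput:
    config: AnywareConfig
    type_map: List[str]
    mass_map: List[float]
    new_system_files: List[Artifact]
    dp_models: Mapping[str, List[Artifact]]
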
async def anyware(input: AnywareInput, ctx: AnywareContext) -> AnywareOutput:
    executor = ctx.resource_manager.default_executor
    work_dir = os.path.join(executor.work_dir, ctx.path_prefix)

    # prefer newly generated system files over the ones defined in the config
    if len(input.new_system_files) > 0:
        data_files = input.new_system_files
    else:
        data_files = ctx.resource_manager.resolve_artifacts(input.config.system_files)
    assert len(data_files) > 0, 'no data files found'

    task_artifacts = executor.run_python_fn(make_anyware_task_dirs)(
        work_dir=work_dir,
        data_files=data_files,
        dp_models={k: [m.url for m in v] for k, v in input.dp_models.items()},
        type_map=input.type_map,
        mass_map=input.mass_map,
        product_vars=input.config.product_vars,
        broadcast_vars=input.config.broadcast_vars,
        template_files=input.config.template_files,
        template_delimiter=input.config.delimiter,
        system_file_name=input.config.system_file_name,
        system_file_format=input.config.system_file_format,
        shuffle=input.config.shuffle,
    )

    steps = []
    for task_artifact in task_artifacts:
        steps.append(BashStep(
            cwd=task_artifact.url, cmd=input.config.submit_script, checkpoint='submit'))

    # split steps into at most `concurrency` groups so that no more than
    # `concurrency` jobs run at the same time
    jobs = []
    for steps_group in list_split(steps, ctx.config.concurrency):
        if not steps_group:
            continue
        script = BashScript(
            template=ctx.config.script_template,
            steps=steps_group,
        )
        job = executor.submit(script.render(), cwd=work_dir)
        jobs.append(job)
    await gather_jobs(jobs, max_tries=2)

    if input.config.post_process_fn:
        executor.run_python_fn(run_post_process_fn)(
            post_process_fn=input.config.post_process_fn,
            task_dirs=[task.url for task in task_artifacts]
        )

    return AnywareOutput(output_dirs=task_artifacts)
def run_post_process_fn(post_process_fn: str, task_dirs: List[str]):
    # execute the user-provided source text, then call the `post_process_fn`
    # it defines with the list of task directories
    _locals = {}
    exec(post_process_fn, None, _locals)
    _locals['post_process_fn'](task_dirs)
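
# Usage sketch (hypothetical, for illustration only): the function body is
# passed as source text, exactly as it would appear in the YAML config:
#
#   run_post_process_fn(
#       post_process_fn='def post_process_fn(task_dirs):\n    print(task_dirs)\n',
#       task_dirs=['./tasks/000000'],
#   )
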
def make_anyware_task_dirs(work_dir: str,
                           data_files: List[Artifact],
                           dp_models: Mapping[str, List[str]],
                           type_map: List[str],
                           mass_map: List[float],
                           product_vars: Mapping[str, List[str]],
                           broadcast_vars: Mapping[str, List[str]],
                           template_files: Mapping[str, str],
                           template_delimiter: str,
                           system_file_name: str,
                           system_file_format: str,
                           shuffle: bool,
                           ):
    class _Template(Template):
        delimiter = template_delimiter
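    # e.g. with the default delimiter '$$':
    #   _Template('coord $$SYSTEM_FILE').substitute(SYSTEM_FILE='system.xyz')
    #   returns 'coord system.xyz'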
    # handle data files
    systems_dir = os.path.join(work_dir, 'systems')
    os.makedirs(systems_dir, exist_ok=True)

    atoms_list = artifacts_to_ase_atoms_v2(data_files)

    system_artifacts = []
    for i, (artifact, atoms) in enumerate(atoms_list):
        ancestor = artifact.attrs['ancestor']
        data_file = os.path.join(systems_dir, f'{ancestor}-{i:06d}-{system_file_name}')
        if system_file_format == 'cp2k-inc':
            with open(data_file, 'w') as fp:
                dump_coord_n_cell(fp, atoms)
        elif system_file_format == 'lammps-data':
            # specorder keeps the LAMMPS atom type ids aligned with type_map
            ase.io.write(data_file, atoms, format=system_file_format, specorder=type_map)  # type: ignore
        else:
            ase.io.write(data_file, atoms, format=system_file_format)  # type: ignore
        system_artifacts.append(Artifact(url=data_file, attrs=artifact.attrs))

    if shuffle:
        random.shuffle(system_artifacts)

    # handle task dirs
    combination_fields: List[str] = ['SYSTEM_FILE']
    combination_values: List[List[Any]] = [system_artifacts]
    for k, v in product_vars.items():
        combination_fields.append(k)
        if shuffle:
            v = list(v)  # shuffle a copy so the caller's config is not mutated
            random.shuffle(v)
        combination_values.append(v)
    combinations = itertools.product(*combination_values)
    combinations = list(map(list, combinations))

    # broadcast_vars are assigned cyclically: task i takes value i % len(values)
    combination_fields.extend(broadcast_vars.keys())
    for i, combination in enumerate(combinations):
        for _vars in broadcast_vars.values():
            combination.append(_vars[i % len(_vars)])
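
    # Worked example (hypothetical values): with a single system file,
    # product_vars = {'TEMP': ['300', '400'], 'PRESS': ['1', '10']} and
    # broadcast_vars = {'SEED': ['1111', '2222']}, the Cartesian product
    # yields 4 combinations and SEED is assigned cyclically:
    #   000000: TEMP=300 PRESS=1   SEED=1111
    #   000001: TEMP=300 PRESS=10  SEED=2222
    #   000002: TEMP=400 PRESS=1   SEED=1111
    #   000003: TEMP=400 PRESS=10  SEED=2222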
    task_artifacts = []
    tasks_base_dir = os.path.join(work_dir, 'tasks')
    for i, combination in enumerate(combinations):
        task_dir = os.path.join(tasks_base_dir, f'{i:06d}')
        os.makedirs(task_dir, exist_ok=True)
        template_vars = dict(zip(combination_fields, combination))

        # link system_file to task_dir;
        # get_ln_cmd is used instead of os.symlink because os.symlink raises
        # an error if the link already exists
        system_artifact: Artifact = template_vars.pop('SYSTEM_FILE')
        system_file = os.path.join(task_dir, system_file_name)
        os.system(get_ln_cmd(system_artifact.url, system_file))
        template_vars['SYSTEM_FILE'] = system_file_name

        # dp models variables
        dp_vars = _get_dp_models_variables(dp_models)

        # generate template files
        for k, v in template_files.items():
            template_file_path = os.path.join(task_dir, k)
            dump_text(_Template(v).substitute(**template_vars, **dp_vars), template_file_path, encoding='utf-8')

        task_artifacts.append(Artifact(url=task_dir,
                                       attrs={**system_artifact.attrs, 'source': system_artifact.url},
                                       format=DataFormat.ANYWARE_OUTPUT_DIR))
    return task_artifacts