# Source code for ai2_kit.workflow.cll_mlp

from ai2_kit.core.executor import BaseExecutorConfig
from ai2_kit.core.artifact import ArtifactMap
from ai2_kit.core.log import get_logger
from ai2_kit.core.util import load_yaml_files, merge_dict
from ai2_kit.core.resource_manager import ResourceManager
from ai2_kit.core.checkpoint import set_checkpoint_dir, apply_checkpoint
from ai2_kit.core.pydantic import BaseModel
from ai2_kit.domain import (
    deepmd,
    iface,
    lammps,
    lasp,
    selector,
    cp2k,
    vasp,
    constant as const,
    updater,
    anyware,

    lammps as _lammps,
    lasp as _lasp,
    cp2k as _cp2k,
    vasp as _vasp,
    anyware as _anyware,
)




from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from fire import Fire

import asyncio
import itertools
import copy
import os

# module-level logger, obtained from ai2_kit's get_logger with this module's name
logger = get_logger(__name__)


class CllWorkflowExecutorConfig(BaseExecutorConfig):
    """Executor config extended with a ``context`` section for the CLL workflow steps."""

    class Context(BaseModel):
        """Per-step context configs, grouped as train / explore / label."""

        class Train(BaseModel):
            # context config of the deepmd training step (required field)
            deepmd: deepmd.CllDeepmdContextConfig

        class Explore(BaseModel):
            # each explore backend is optional and defaults to None
            lammps: Optional[_lammps.CllLammpsContextConfig] = None
            lasp: Optional[_lasp.CllLaspContextConfig] = None
            anyware: Optional[_anyware.AnywareContextConfig] = None

        class Label(BaseModel):
            # each label backend is optional and defaults to None
            cp2k: Optional[_cp2k.CllCp2kContextConfig] = None
            vasp: Optional[_vasp.CllVaspContextConfig] = None

        train: Train
        explore: Explore
        label: Label

    context: Context
class WorkflowConfig(BaseModel):
    """Schema of the ``workflow`` section; the training loop parses it anew in every iteration."""

    class General(BaseModel):
        type_map: List[str]
        mass_map: List[float]
        # entries are looked up in type_map by precondition(); required in 'dpff' mode
        sel_type: Optional[List[str]] = None
        # the training loop stops once the iteration index reaches this value
        max_iters: int = 1
        mode: iface.TRAINING_MODE = 'default'
        # when True, systems reported by the previous selector output are
        # handed to the next explore step as new system files
        update_explore_systems: bool = False

    class Label(BaseModel):
        # each label backend is optional and defaults to None
        cp2k: Optional[_cp2k.CllCp2kInputConfig] = None
        vasp: Optional[_vasp.CllVaspInputConfig] = None

    class Train(BaseModel):
        # input config of the deepmd training step (required field)
        deepmd: deepmd.CllDeepmdInputConfig

    class Explore(BaseModel):
        # each explore backend is optional and defaults to None
        lammps: Optional[_lammps.CllLammpsInputConfig] = None
        lasp: Optional[_lasp.CllLaspInputConfig] = None
        anyware: Optional[_anyware.AnywareConfig] = None

    class Select(BaseModel):
        model_devi: selector.CllModelDeviSelectorInputConfig

    class Update(BaseModel):
        # provides the `table` and `passing_rate_threshold` read at the end of each iteration
        walkthrough: updater.CllWalkthroughUpdaterInputConfig

    general: General
    train: Train
    explore: Explore
    select: Select
    label: Label
    update: Update
class CllWorkflowConfig(BaseModel):
    """Top-level config: executors by name, artifacts, and the still-unparsed workflow section."""

    # executor name -> executor config (including the per-step context configs)
    executors: Dict[str, CllWorkflowExecutorConfig]
    artifacts: ArtifactMap
    workflow: Any  # Keep it raw here, it should be parsed later in iteration
def run_workflow(*config_files,
                 executor: Optional[str] = None,
                 path_prefix: Optional[str] = None,
                 checkpoint: Optional[str] = None):
    """
    Run Closed-Loop Learning (CLL) workflow to train Machine Learning Potential (MLP) models.

    Args:
        config_files: path of config files, should be yaml files, can be multiple, support glob pattern
        executor: name of executor, should be defined in config `executors` section
        path_prefix: path prefix for output
        checkpoint: checkpoint file

    Returns:
        Whatever ``cll_mlp_training_workflow`` returns once the event loop finishes.

    Raises:
        ValueError: if `executor` or `path_prefix` is not given, or if `executor`
            is not defined in the `executors` section of the loaded config.
    """
    # Validate plain arguments first, so that a bad invocation fails before
    # the checkpoint directory is set or any config file is read.
    if executor is None:
        raise ValueError('executor should not be empty')
    if path_prefix is None:
        raise ValueError('path_prefix should not be empty')

    if checkpoint is not None:
        set_checkpoint_dir(checkpoint)

    config_data = load_yaml_files(*config_files)
    config = CllWorkflowConfig.parse_obj(config_data)

    # this check needs the parsed config, so it cannot be moved earlier
    if executor not in config.executors:
        raise ValueError(f'executor {executor} is not found')

    iface.init_artifacts(config.artifacts)
    resource_manager = ResourceManager(
        executor_configs=config.executors,
        artifacts=config.artifacts,
        default_executor=executor,
    )
    return asyncio.run(cll_mlp_training_workflow(config, resource_manager, executor, path_prefix))
async def cll_mlp_training_workflow(config: CllWorkflowConfig, resource_manager: ResourceManager, executor: str, path_prefix: str):
    """
    Drive the label -> train -> explore -> select -> update loop.

    Every iteration re-parses the raw workflow config, runs each step through
    ``apply_checkpoint`` under the key ``iters-NNN/<step>``, and feeds the
    output of one step into the next one.

    The loop ends when the iteration index reaches ``general.max_iters``, or
    when (after the first iteration) the label step yields no labeled systems.

    Args:
        config: parsed top-level config; ``config.workflow`` is still raw
        resource_manager: passed to the context object of every step
        executor: key into ``config.executors`` used to pick the context config
        path_prefix: base directory; each iteration works under ``<path_prefix>/iters-NNN``

    Raises:
        ValueError: if no label / explore backend is enabled in both the
            workflow config and the context config, or if the train / select
            config is falsy.
    """
    context_config = config.executors[executor].context
    # work on a copy: the raw config is replaced by merged versions in the update step
    raw_workflow_config = copy.deepcopy(config.workflow)

    # output of each step
    label_output: Optional[iface.ICllLabelOutput] = None
    selector_output: Optional[iface.ICllSelectorOutput] = None
    train_output: Optional[iface.ICllTrainOutput] = None
    explore_output: Optional[iface.ICllExploreOutput] = None

    # cursor of update table
    update_cursor = 0

    # Start iteration
    for i in itertools.count(0):

        # parse workflow config
        workflow_config = WorkflowConfig.parse_obj(raw_workflow_config)
        shared_vars = precondition(workflow_config)

        if i >= workflow_config.general.max_iters:
            logger.info(f'Iteration {i} exceeds max_iters, stop iteration.')
            break

        # shortcut for type_map and mass_map
        type_map = workflow_config.general.type_map
        mass_map = workflow_config.general.mass_map

        # decide path prefix for each iteration
        iter_path_prefix = os.path.join(path_prefix, f'iters-{i:03d}')
        # prefix of checkpoint
        cp_prefix = f'iters-{i:03d}'

        # label: cp2k takes precedence over vasp when both are configured;
        # on the first iteration there is no selector output, so system_files is empty
        if workflow_config.label.cp2k and context_config.label.cp2k:
            cp2k_input = cp2k.CllCp2kInput(
                config=workflow_config.label.cp2k,
                mode=workflow_config.general.mode,
                type_map=type_map,
                system_files=[] if selector_output is None else selector_output.get_model_devi_dataset(),
                initiated=i > 0,
            )
            cp2k_context = cp2k.CllCp2kContext(
                config=context_config.label.cp2k,
                path_prefix=os.path.join(iter_path_prefix, 'label-cp2k'),
                resource_manager=resource_manager,
            )
            label_output = await apply_checkpoint(f'{cp_prefix}/label-cp2k')(cp2k.cll_cp2k)(cp2k_input, cp2k_context)
        elif workflow_config.label.vasp and context_config.label.vasp:
            vasp_input = vasp.CllVaspInput(
                config=workflow_config.label.vasp,
                type_map=type_map,
                system_files=[] if selector_output is None else selector_output.get_model_devi_dataset(),
                initiated=i > 0,
            )
            vasp_context = vasp.CllVaspContext(
                config=context_config.label.vasp,
                path_prefix=os.path.join(iter_path_prefix, 'label-vasp'),
                resource_manager=resource_manager,
            )
            label_output = await apply_checkpoint(f'{cp_prefix}/label-vasp')(vasp.cll_vasp)(vasp_input, vasp_context)
        else:
            raise ValueError('No label method is specified')

        # return if no new data is generated
        if i > 0 and len(label_output.get_labeled_system_dataset()) == 0:
            logger.info("No new data is generated, stop iteration.")
            break

        # train: the previous iteration's dataset and models are carried over
        # via train_output; both are empty on the first iteration
        if workflow_config.train.deepmd:
            deepmd_input = deepmd.CllDeepmdInput(
                config=workflow_config.train.deepmd,
                mode=workflow_config.general.mode,
                type_map=type_map,
                old_dataset=[] if train_output is None else train_output.get_training_dataset(),
                new_dataset=label_output.get_labeled_system_dataset(),
                sel_type=shared_vars.dp_sel_type,
                previous=[] if train_output is None else train_output.get_mlp_models(),
            )
            deepmd_context = deepmd.CllDeepmdContext(
                path_prefix=os.path.join(iter_path_prefix, 'train-deepmd'),
                config=context_config.train.deepmd,
                resource_manager=resource_manager,
            )
            train_output = await apply_checkpoint(f'{cp_prefix}/train-deepmd')(deepmd.cll_deepmd)(deepmd_input, deepmd_context)
        else:
            raise ValueError('No train method is specified')

        # explore: optionally seed with systems reported by the previous selector run
        new_explore_system_files = []
        if workflow_config.general.update_explore_systems and selector_output is not None:
            new_explore_system_files = selector_output.get_new_explore_systems()

        # backend precedence is lammps, then lasp, then anyware
        if workflow_config.explore.lammps and context_config.explore.lammps:
            lammps_input = lammps.CllLammpsInput(
                config=workflow_config.explore.lammps,
                mode=workflow_config.general.mode,
                type_map=type_map,
                mass_map=mass_map,
                dp_models={'': train_output.get_mlp_models()},
                preset_template='default',
                new_system_files=new_explore_system_files,
                dp_modifier=shared_vars.dp_modifier,
                dp_sel_type=shared_vars.dp_sel_type,
            )
            lammps_context = lammps.CllLammpsContext(
                path_prefix=os.path.join(iter_path_prefix, 'explore-lammps'),
                config=context_config.explore.lammps,
                resource_manager=resource_manager,
            )
            explore_output = await apply_checkpoint(f'{cp_prefix}/explore-lammps')(lammps.cll_lammps)(lammps_input, lammps_context)
        elif workflow_config.explore.lasp and context_config.explore.lasp:
            lasp_input = lasp.CllLaspInput(
                config=workflow_config.explore.lasp,
                type_map=type_map,
                mass_map=mass_map,
                models=train_output.get_mlp_models(),
                new_system_files=new_explore_system_files,
            )
            lasp_context = lasp.CllLaspContext(
                config=context_config.explore.lasp,
                path_prefix=os.path.join(iter_path_prefix, 'explore-lasp'),
                resource_manager=resource_manager,
            )
            explore_output = await apply_checkpoint(f'{cp_prefix}/explore-lasp')(lasp.cll_lasp)(lasp_input, lasp_context)
        elif workflow_config.explore.anyware and context_config.explore.anyware:
            anyware_input = anyware.AnywareInput(
                config=workflow_config.explore.anyware,
                type_map=type_map,
                mass_map=mass_map,
                new_system_files=new_explore_system_files,
                dp_models={'': train_output.get_mlp_models()},
            )
            anyware_context = anyware.AnywareContext(
                config=context_config.explore.anyware,
                path_prefix=os.path.join(iter_path_prefix, 'explore-anyware'),
                resource_manager=resource_manager,
            )
            explore_output = await apply_checkpoint(f'{cp_prefix}/explore-anyware')(anyware.anyware)(anyware_input, anyware_context)
        else:
            raise ValueError('No explore method is specified')

        # select
        if workflow_config.select.model_devi:
            selector_input = selector.CllModelDeviSelectorInput(
                config=workflow_config.select.model_devi,
                model_devi_data=explore_output.get_model_devi_dataset(),
                model_devi_file=const.MODEL_DEVI_OUT,
                type_map=type_map,
            )
            selector_context = selector.CllModelDevSelectorContext(
                path_prefix=os.path.join(iter_path_prefix, 'selector-model-devi'),
                resource_manager=resource_manager,
            )
            selector_output = await apply_checkpoint(f'{cp_prefix}/selector-model-devi')(selector.cll_model_devi_selector)(selector_input, selector_context)
        else:
            raise ValueError('No select method is specified')

        # Update
        update_config = workflow_config.update.walkthrough
        # nothing to update because the table is empty
        if not update_config.table:
            continue
        # keep using the latest config when it reaches the end of the table
        if update_cursor >= len(update_config.table):
            continue
        # move cursor to next row if the passing rate exceeds the threshold;
        # the row is merged onto a fresh copy of the ORIGINAL workflow config,
        # not onto the previously merged one.
        # NOTE(review): nesting reconstructed from flattened source; the cursor
        # increment is assumed to belong to this branch, per the comment above it.
        if selector_output.get_passing_rate() > update_config.passing_rate_threshold:
            raw_workflow_config = merge_dict(copy.deepcopy(
                config.workflow), update_config.table[update_cursor])
            update_cursor += 1
@dataclass
class SharedVars():
    """Values derived once from the workflow config and reused by several steps."""

    # the `modifier` entry of the deepmd input template's `model` section (set in dpff mode)
    dp_modifier: Optional[dict] = None
    # indices of the `sel_type` entries within `type_map` (set in dpff mode)
    dp_sel_type: Optional[List[int]] = None
def precondition(workflow_cfg: WorkflowConfig) -> SharedVars:
    """
    Check the workflow config so that errors are raised early, and extract
    variables that may be shared by multiple steps.

    The known shared variables are:
        dp_modifier: the `modifier` section of the deepmd input template,
            which includes vars sys_charge_map, model_charge_map, ewald_h, ewald_beta
        dp_sel_type: indices of `sel_type` in `type_map`,
            which is supposed to be used in dplr/dpff mode

    Args:
        workflow_cfg: parsed workflow config of the current iteration

    Returns:
        SharedVars whose fields are filled in for dpff mode and left as None otherwise.

    Raises:
        AssertionError: if a mode-specific requirement is not met. The checks
            are raised explicitly instead of using `assert`, so they are still
            enforced when Python runs with -O; the exception type and the
            messages are the same as before.
        ValueError: if an entry of `sel_type` is not in `type_map` (from list.index).
    """
    shared_vars = SharedVars()

    mode = workflow_cfg.general.mode
    type_map = workflow_cfg.general.type_map
    sel_type = workflow_cfg.general.sel_type

    if mode == 'dpff':
        if sel_type is None:
            raise AssertionError('sel_type should be specified in general config for dpff mode')
        shared_vars.dp_sel_type = [type_map.index(t) for t in sel_type]

    deepmd_cfg = workflow_cfg.train.deepmd
    if deepmd_cfg is not None:
        if mode == 'dpff':
            modifier = deepmd_cfg.input_template['model'].get('modifier')
            if modifier is None:
                raise AssertionError('modifier should be specified in deepmd input template for dpff mode')
            shared_vars.dp_modifier = modifier
        # NOTE(review): the message mentions fep-pka too, but only fep-redox is
        # checked here (the lammps check below covers both) -- confirm intent.
        elif mode == 'fep-redox':
            if deepmd_cfg.input_template['model']['fitting_net']['numb_fparam'] != 1:
                raise AssertionError('numb_fparam should be 1 for fep-redox/fep-pka mode')

    lammps_cfg = workflow_cfg.explore.lammps
    if lammps_cfg is not None:
        if mode == 'dpff':
            lammps_cfg.assert_var('EFIELD')
            lammps_cfg.assert_var('KMESH')
            # explore_vars takes precedence; broadcast_vars is the fallback
            efield = lammps_cfg.explore_vars.get('EFIELD', lammps_cfg.broadcast_vars.get('EFIELD'))
            if not all(isinstance(item, list) for item in efield):  # type: ignore
                raise AssertionError('EFIELD should be a list of vector')
        elif mode in ['fep-redox', 'fep-pka']:
            lammps_cfg.assert_var('LAMBDA_f')

    return shared_vars
if __name__ == '__main__':
    # use python-fire to parse command line arguments
    Fire(run_workflow)