Source code for ai2_kit.main
from fire import Fire
import warnings
import logging
import os
from ai2_kit.core.cmd import CmdGroup
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class AlgorithmGroup:
    """
    Algorithms for specific domains.
    """

    def proton_transfer(self):
        """
        Proton transfer analysis toolkit.
        """
        # Imported lazily: the module is only loaded when this command is selected.
        from ai2_kit.algorithm import proton_transfer
        return proton_transfer.cli_entry

    def aosa(self):
        """
        Amorphous oxides structure analysis toolkit.
        """
        # Imported lazily: the module is only loaded when this command is selected.
        from ai2_kit.algorithm import aos_analysis
        return aos_analysis.cli_entry

    def reweighting(self):
        """
        Reweighting toolkit.
        """
        # Unlike the other commands, this returns a fresh tool instance whose
        # methods become the sub-commands.
        from ai2_kit.algorithm.reweighting import ReweightingTool
        return ReweightingTool()
class WorkflowGroup:
    """
    Workflows for specific domains.
    """

    @property
    def cll_mlp_training(self):
        """
        CLL MLP training workflow (``ai2_kit.workflow.cll_mlp.run_workflow``).
        """
        # Imported lazily: the workflow module is only loaded when selected.
        from ai2_kit.workflow.cll_mlp import run_workflow
        return run_workflow

    @property
    def fep_mlp_training(self):
        """
        FEP MLP training workflow (``ai2_kit.workflow.fep_mlp.run_workflow``).
        """
        # Imported lazily: the workflow module is only loaded when selected.
        from ai2_kit.workflow.fep_mlp import run_workflow
        return run_workflow
class FeatureGroup:
    """
    Feature tools for specific domains.
    """

    @property
    def catalysis(self):
        """
        Catalyst specific tools.
        """
        # Imported lazily: the module is only loaded when this command is selected.
        from ai2_kit.feat.catalysis import CmdEntries
        return CmdEntries

    @property
    def cat(self):
        """
        Shortcut for catalysis.
        """
        return self.catalysis

    @property
    def spectr(self):
        """
        Spectrum specific tools.
        """
        # Imported lazily: the module is only loaded when this command is selected.
        from ai2_kit.feat.spectrum import CmdEntries
        return CmdEntries

    @property
    def nmrnet(self):
        """
        NMRNet specific tools.

        Raises:
            ImportError: if ``ai2_kit.algorithm.uninmr`` (or something it
                imports) cannot be imported. A hint naming the required
                packages is logged before the error is re-raised.
        """
        try:
            from ai2_kit.algorithm.uninmr import CmdEntries
        except ImportError:
            # Log through the module logger rather than the root-level
            # `logging.info` (which implicitly calls basicConfig when logging is
            # unconfigured), and at error level so the hint accompanying the
            # failure is not filtered out when LOG_LEVEL is raised above INFO.
            logger.error('In order to use nmrnet, you need to ensure the following packages are installed: '
                         '"rdkit", "scipy" and "unicore"')
            raise
        return CmdEntries
# Root command tree handed to Fire by `main` / `rpc_main`. Each key is
# presumably exposed by CmdGroup as a top-level sub-command name.
# NOTE(review): `ToolGroup` is neither defined nor imported in the visible
# source of this module; unless it is provided elsewhere, evaluating this
# statement raises NameError at import time. Confirm where it comes from.
ai2_kit = CmdGroup({
    'workflow': WorkflowGroup(),
    'algorithm': AlgorithmGroup(),
    'tool': ToolGroup(),
    'feat': FeatureGroup(),
}, doc="Welcome to use ai2-kit!")
def _setup_logging():
level_name = os.environ.get('LOG_LEVEL', 'INFO')
level = logging._nameToLevel.get(level_name, logging.INFO)
logging.basicConfig(format='%(asctime)s %(name)s: %(message)s', level=level)
logging.getLogger('transitions.core').setLevel(logging.WARNING)
def main():
    """
    CLI entry point.

    Configures logging, then lets Fire dispatch the command line against the
    ``ai2_kit`` command tree.
    """
    _setup_logging()
    Fire(component=ai2_kit)
def rpc_main():
    """
    RPC-style CLI entry point.

    Same as ``main`` except that the ``ai2_kit`` command tree is first wrapped
    by ``fire_rpc.make_fire_cmd``. ai2_kit's own ``json_dumps`` is passed as its
    ``json_dumps`` hook, presumably to serialize command results.
    """
    from fire_rpc import make_fire_cmd
    from ai2_kit.core.util import json_dumps

    _setup_logging()
    rpc_cmd = make_fire_cmd(ai2_kit, json_dumps=json_dumps)
    Fire(rpc_cmd)
# Run the CLI when this module is executed as a script (e.g. `python -m ai2_kit.main`).
if __name__ == '__main__':
    main()