Source code for ai2_kit.domain.data
from ai2_kit.core.artifact import ArtifactDict, Artifact
from typing import List, Tuple, Optional
from ase import Atoms
import ase.io
import os
[docs]
class DataFormat:
# customize data format
CP2K_OUTPUT_DIR = 'cp2k/output_dir'
VASP_OUTPUT_DIR = 'vasp/output_dir'
LAMMPS_OUTPUT_DIR = 'lammps/output_dir'
DEEPMD_OUTPUT_DIR = 'deepmd/output_dir'
ANYWARE_OUTPUT_DIR = 'anyware/output_dir'
DEEPMD_MODEL = 'deepmd/model'
DEEPMD_NPY = 'deepmd/npy'
LASP_LAMMPS_OUT_DIR ='lasp+lammps/output_dir'
# data format of dpdata
CP2K_OUTPUT = 'cp2k/output'
VASP_XML = 'vasp/xml'
# data format of ase
EXTXYZ = 'extxyz'
VASP_POSCAR = 'vasp/poscar'
[docs]
def get_data_format(artifact: dict) -> Optional[str]:
"""
Get (or guess) data type from artifact dict
Note: The reason of using dict instead of Artifact is Artifact is not pickleable
"""
url = artifact.get('url')
assert isinstance(url, str), f'url must be str, got {type(url)}'
file_name = os.path.basename(url)
format = artifact.get('format')
if format and isinstance(format, str):
return format # TODO: validate format
if file_name.endswith('.xyz'):
return DataFormat.EXTXYZ
if 'POSCAR' in file_name:
return DataFormat.VASP_POSCAR
return None
[docs]
def artifacts_to_ase_atoms(artifacts: List[ArtifactDict], type_map: List[str]) -> List[Tuple[ArtifactDict, Atoms]]:
"""
Read ase atoms list from artifacts
Deprecated since it is not recommended to use ArtifactDict
"""
results = []
for a in artifacts:
data_format = get_data_format(a) # type: ignore
url = a['url']
if data_format in [DataFormat.VASP_POSCAR, 'vasp']:
atoms_list = ase.io.read(url, ':', format='vasp')
elif data_format in [DataFormat.EXTXYZ, 'extxyz']:
atoms_list = ase.io.read(url, ':', format='extxyz')
elif data_format is not None:
atoms_list = ase.io.read(url, ':', format=data_format)
else:
raise ValueError(f'unsupported data format: {data_format}')
results.extend((a, atoms) for atoms in atoms_list)
return results
[docs]
def artifacts_to_ase_atoms_v2(artifacts: List[Artifact]) -> List[Tuple[Artifact, Atoms]]:
results = []
for a in artifacts:
data_format = get_data_format(a.to_dict()) # type: ignore
url = a.url
if data_format in [DataFormat.VASP_POSCAR, 'vasp']:
atoms_list = ase.io.read(url, ':', format='vasp')
elif data_format in [DataFormat.EXTXYZ, 'extxyz']:
atoms_list = ase.io.read(url, ':', format='extxyz')
elif data_format is not None:
atoms_list = ase.io.read(url, ':', format=data_format)
else:
raise ValueError(f'unsupported data format: {data_format}')
results.extend((a, atoms) for atoms in atoms_list)
return results
[docs]
def ase_atoms_to_cp2k_input_data(atoms: Atoms) -> Tuple[List[str], List[List[float]]]:
coords = [atom.symbol + ' ' + ' '.join(str(x) for x in atom.position) for atom in atoms] # type: ignore
cell = [list(row) for row in atoms.cell] # type: ignore
return (coords, cell)
[docs]
def convert_to_lammps_input_data(systems: List[ArtifactDict], base_dir: str, type_map: List[str]):
data_files = []
atoms_list = artifacts_to_ase_atoms(systems, type_map=type_map)
for i, (artifact, atoms) in enumerate(atoms_list):
data_file = os.path.join(base_dir, f'{i:06d}.lammps.data')
ase.io.write(data_file, atoms, format='lammps-data', specorder=type_map) # type: ignore
data_files.append({
'url': data_file,
'attrs': artifact['attrs'],
})
return data_files