# Source code for ai2_kit.tool.ase
from ai2_kit.core.util import ensure_dir, expand_globs, slice_from_str, SAMPLE_METHOD, list_sample
from ai2_kit.core.log import get_logger
from ai2_kit.domain.cp2k import dump_coord_n_cell
from typing import List, Union, Optional
from ase import Atoms
import ase.io
logger = get_logger(__name__)
# [docs]
class AseTool:
    """
    Chainable helper that holds a list of ASE ``Atoms`` frames.

    Methods that load or transform frames return ``self`` so that calls can be chained;
    the ``write*`` methods write files and return nothing useful for chaining.
    """

    def __init__(self, atoms_list: Optional[List[Atoms]] = None):
        """
        :param atoms_list: initial frames, an empty list is used when omitted.
            Note that the given list is stored as is (not copied).
        """
        self._atoms_list: List[Atoms] = [] if atoms_list is None else atoms_list

    def read(self, *file_path_or_glob: str, **kwargs):
        """
        read frames from files and append them to the frames already loaded

        :param file_path_or_glob: one or more paths or glob patterns, resolved by expand_globs
        :param kwargs: other arguments for ase.io.read, `index` defaults to ':' (all frames)
        :raises FileNotFoundError: if no file matches the given paths or patterns
        """
        files = expand_globs(file_path_or_glob)
        if len(files) == 0:
            raise FileNotFoundError(f'No file found for {file_path_or_glob}')
        for file in files:
            self._read(file, **kwargs)
        return self

    def write(self, filename: str, **kwargs):
        """
        write all frames to a single file

        :param filename: the file to write to
        :param kwargs: other arguments for ase.io.write, `format` also selects the
            custom writers 'lammps-dump-text' and 'cp2k-inc' (see _write)
        """
        # NOTE(review): ensure_dir presumably creates the parent directory of the path - confirm.
        # The path is passed through format(i=0) here while the unformatted name is used for writing.
        ensure_dir(filename.format(i=0))
        self._write(filename, self._atoms_list, **kwargs)

    def set_cell(self, cell, scale_atoms=False, apply_constraint=True):
        """
        set the same cell on every frame, arguments are forwarded to Atoms.set_cell

        :param cell: the new cell
        :param scale_atoms: forwarded to Atoms.set_cell
        :param apply_constraint: forwarded to Atoms.set_cell
        """
        for atoms in self._atoms_list:
            atoms.set_cell(cell, scale_atoms=scale_atoms, apply_constraint=apply_constraint)
        return self

    def set_by_ref(self, ref_file: str, **kwargs):
        """
        copy cell and pbc from a reference structure to every frame

        This is best effort: a ValueError raised while setting cell or pbc on a frame
        is logged as a warning and the frame is left as it was for that property.

        :param ref_file: the file to read the reference structure from
        :param kwargs: other arguments for ase.io.read, `index` defaults to 0 (first frame)
        """
        kwargs.setdefault('index', 0)
        ref_atoms = ase.io.read(ref_file, **kwargs)
        assert isinstance(ref_atoms, Atoms), 'Only support single frame reference'
        for atoms in self._atoms_list:
            try:
                atoms.set_cell(ref_atoms.get_cell())
            except ValueError as e:
                # keep going, but do not hide the failure from the user
                logger.warning('set_by_ref: fail to set cell from %s: %s', ref_file, e)
            try:
                atoms.set_pbc(ref_atoms.get_pbc())
            except ValueError as e:
                logger.warning('set_by_ref: fail to set pbc from %s: %s', ref_file, e)
        return self

    def slice(self, expr: str):
        """
        slice systems by python slice expression, for example
        `10:`, `:10`, `::2`, etc

        :param expr: the slice expression, parsed by slice_from_str
        """
        s = slice_from_str(expr)
        self._atoms_list = self._atoms_list[s]
        return self

    def sample(self, size: int, method: SAMPLE_METHOD = 'even', **kwargs):
        """
        sample data

        :param size: size of sample, if size is larger than data size, return all data
        :param method: method to sample, can be 'even', 'random', 'truncate', default is 'even'
        :param seed: seed for random sample, only used when method is 'random'

        Note that by default the seed is length of input list,
        if you want to generate different sample each time, you should set random seed manually
        """
        self._atoms_list = list_sample(self._atoms_list, size, method, **kwargs)
        return self

    def to_dpdata(self, labeled=False):
        """
        convert to dpdata format and use dpdata tool to handle

        :param labeled: if True, use dpdata.LabeledSystem, else use dpdata.System
        """
        # imported lazily so that dpdata is only needed when this method is used
        from .dpdata import dpdata, DpdataTool
        System = dpdata.LabeledSystem if labeled else dpdata.System
        systems = [System(atom, fmt='ase/structure') for atom in self._atoms_list]
        return DpdataTool(systems=systems)

    def delete_atoms(self, id: Union[int, List[int]], start_id=0):
        """
        delete atoms by id or list of id

        All ids are validated before any frame is modified, and an id given more than
        once is deleted only once.

        :param id: id or list of id
        :param start_id: the start id of first item, for example, in LAMMPS the id of first item is 1, in ASE it is 0
        """
        ids = [id] if isinstance(id, int) else id
        # validate up front so that an invalid id never leaves the frames half edited
        for i in ids:
            assert i >= start_id, f'Invalid id {i}'
        # de-duplicate (deleting the same index twice would remove a neighbour atom),
        # and delete from the highest index down so that pending indices do not shift
        indices = sorted({i - start_id for i in ids}, reverse=True)
        for atoms in self._atoms_list:
            for index in indices:
                del atoms[index]
        return self

    @property
    def write_each_frame(self):
        """deprecated alias of write_frames"""
        # Logger.warn is a deprecated alias, use warning instead
        logger.warning('write_each_frame has been deprecated, use write_frames instead')
        return self.write_frames

    def write_frames(self, filename: str, **kwargs):
        """
        write each frame to a separate file, useful to write to format only support single frame, POSCAR for example

        :param filename: the filename template, use {i} to represent the index, for example, 'frame_{i}.xyz'
        :param kwargs: other arguments for ase.io.write
        """
        for i, atoms in enumerate(self._atoms_list):
            _filename = filename.format(i=i)
            ensure_dir(_filename)
            self._write(_filename, atoms, **kwargs)

    def write_dplr_lammps_data(self, filename: str,
                               type_map: List[str], sel_type: List[int],
                               sys_charge_map: List[float], model_charge_map: List[float]):
        """
        write atoms to LAMMPS data file for DPLR

        the naming convention of params follows Deepmd-Kit's
        about dplr: https://docs.deepmodeling.com/projects/deepmd/en/master/model/dplr.html
        about fitting tensor: https://docs.deepmodeling.com/projects/deepmd/en/master/model/train-fitting-tensor.html

        :param filename: the filename of LAMMPS data file, use {i} to represent the index, for example, 'frame_{i}.lammps.data'
        :param type_map: the type map of atom type, for example, [O,H]
        :param sel_type: the selected type of atom, for example, [0] means atom type 0, aka O is selected
        :param sys_charge_map: the charge map of atom in system, for example, [6, 1]
        :param model_charge_map: the charge map of atom in model, for example, [-8]
        """
        from ai2_kit.domain.dpff import dump_dplr_lammps_data
        for i, atoms in enumerate(self._atoms_list):
            _filename = filename.format(i=i)
            # call ensure_dir for every file (same as write_frames),
            # as {i} may be a part of the directory name
            ensure_dir(_filename)
            with open(_filename, 'w') as f:
                dump_dplr_lammps_data(f, atoms=atoms, type_map=type_map, sel_type=sel_type,
                                      sys_charge_map=sys_charge_map, model_charge_map=model_charge_map)

    def _read(self, filename: str, **kwargs):
        """read frames from a single file with ase.io.read and append them, `index` defaults to ':'"""
        kwargs.setdefault('index', ':')
        data = ase.io.read(filename, **kwargs)
        # ase.io.read returns a single object instead of a list when one frame is selected
        if not isinstance(data, list):
            data = [data]
        self._atoms_list += data

    def _write(self, filename: str, atoms_list, **kwargs):
        """
        write frames to a file, dispatch to custom writers for formats not handled by ase

        :param filename: the file to write to
        :param atoms_list: a list of Atoms, or a single Atoms (as passed by write_frames)
        :param kwargs: other arguments for ase.io.write; format='lammps-dump-text' and
            format='cp2k-inc' (or a filename ends with '.inc' when format is not set)
            are handled by this class instead of ase
        """
        # for a single Atoms this checks the number of atoms instead of number of frames
        assert len(atoms_list) > 0, 'No atoms to write'
        fmt = kwargs.get('format', None)
        if fmt == 'lammps-dump-text':
            return self._write_lammps_dump_text(filename, atoms_list, **kwargs)
        if fmt == 'cp2k-inc' or (fmt is None and filename.endswith('.inc')):
            if not isinstance(atoms_list, list):
                atoms_list = [atoms_list]
            assert len(atoms_list) == 1, 'cp2k-inc only support single frame'
            return self._write_cp2k_inc(filename, atoms_list[0])
        return ase.io.write(filename, atoms_list, **kwargs)

    def _write_cp2k_inc(self, filename: str, atoms: Atoms):
        """write a single frame to filename via dump_coord_n_cell"""
        with open(filename, 'w') as f:
            dump_coord_n_cell(f, atoms)

    def _write_lammps_dump_text(self, filename: str, atoms_list, **kwargs):
        """
        write frames as a LAMMPS text dump, the frame index is used as timestep

        :param filename: the file to write to
        :param atoms_list: a list of Atoms, or a single Atoms
        :param kwargs: only `type_map` is used: when given, the atom type is the 1-based
            position of the symbol in type_map, else the atomic number is used
        """
        if not isinstance(atoms_list, list):
            atoms_list = [atoms_list]
        type_map = kwargs.get('type_map')
        lines = []
        for i, atoms in enumerate(atoms_list):
            lines += [
                "ITEM: TIMESTEP", str(i),
                "ITEM: NUMBER OF ATOMS", str(len(atoms)),
            ]
            # It is kind of hacky to set boundary and cell in this way.
            # Only the diagonal of the cell is written and all tilt factors are written as 0.
            # TODO: refactor it in the future.
            pbc = atoms.get_pbc()
            boundary = ' '.join(['pp' if b else 'ff' for b in pbc])
            cell = atoms.get_cell()
            lines += [
                f"ITEM: BOX BOUNDS xy xz yz {boundary}",
                "%.16e %.16e %.16e" % (0, cell[0][0], 0),
                "%.16e %.16e %.16e" % (0, cell[1][1], 0),
                "%.16e %.16e %.16e" % (0, cell[2][2], 0),
                # TODO: support more properties, forces for example
                "ITEM: ATOMS id type x y z",
            ]
            for j, atom in enumerate(atoms):
                pos = atom.position
                no = atom.number if type_map is None else type_map.index(atom.symbol) + 1
                # atom id in the dump is 1-based
                lines.append("%3d %2s %10.5f %10.5f %10.5f" % (j + 1, no, pos[0], pos[1], pos[2]))
        with open(filename, 'w') as f:
            f.write('\n'.join(lines))