# Source code for metatrain.utils.data.writers.ase

from pathlib import Path
from typing import Dict, List, Optional, Union

import ase
import ase.io
import torch
from metatensor.torch import TensorMap
from metatomic.torch import ModelCapabilities, System

from metatrain.utils.external_naming import to_external_name

from .writers import Writer, _split_tensormaps


class ASEWriter(Writer):
    """Write systems and predictions to an ASE-compatible XYZ file.

    Everything passed to :py:meth:`write` is accumulated in memory, then written
    to ``filename`` with a single :py:func:`ase.io.write` call in
    :py:meth:`finish`.
    """

    def __init__(
        self,
        filename: Union[str, Path],
        capabilities: Optional[
            ModelCapabilities
        ] = None,  # required only when predictions contain gradients
        append: Optional[bool] = False,  # unused, but matches base signature
    ):
        super().__init__(filename, capabilities, append)
        self._first = True
        self._systems: List[System] = []
        self._preds: List[Dict[str, TensorMap]] = []

    def write(self, systems: List[System], predictions: Dict[str, TensorMap]):
        """
        Accumulate systems and predictions to write them all at once in ``finish``.

        Systems are moved to CPU and converted to ``float64`` before being stored;
        predictions are split into one ``Dict[str, TensorMap]`` per system.
        """
        self._systems.extend(
            [system.to("cpu").to(torch.float64) for system in systems]
        )
        self._preds.extend(_split_tensormaps(systems, predictions))

    def finish(self):
        """
        Write all accumulated systems and predictions to the XYZ file.

        Does nothing if no system was accumulated.

        :raises ValueError: if a prediction has more than one block; if a virial
            is written for a system with an all-zero cell or zero volume; or if a
            prediction has gradients but the writer has no capabilities.
        """
        if not self._systems:
            return

        frames = [
            self._to_atoms(system, system_predictions)
            for system, system_predictions in zip(self._systems, self._preds)
        ]
        ase.io.write(self.filename, frames)

    def _to_atoms(
        self, system: System, system_predictions: Dict[str, TensorMap]
    ) -> ase.Atoms:
        """Build one ``ase.Atoms`` frame from a system and its predictions."""
        info: Dict = {}
        arrays: Dict = {}
        for target_name, target_map in system_predictions.items():
            if len(target_map.keys) != 1:
                raise ValueError(
                    "Only single-block `TensorMap`s can be "
                    "written to xyz files for the moment."
                )
            block = target_map.block()
            self._store_values(target_name, block, info, arrays)
            for gradient_name, gradient_block in block.gradients():
                self._store_gradient(
                    system, target_name, gradient_name, gradient_block, info, arrays
                )

        atoms = ase.Atoms(
            symbols=system.types.numpy(),
            positions=system.positions.detach().numpy(),
            info=info,
        )
        # any non-zero cell entry marks the system as periodic
        if torch.any(system.cell != 0):
            atoms.pbc = True
            atoms.cell = system.cell.detach().cpu().numpy()
        for array_name, array in arrays.items():
            atoms.arrays[array_name] = array
        return atoms

    @staticmethod
    def _store_values(target_name: str, block, info: Dict, arrays: Dict):
        """Store a block's values in ``arrays`` (per-atom) or ``info`` (per-system)."""
        if "atom" in block.samples.names:
            values = block.values.detach().cpu().numpy()
            # `arrays` only accepts 2D arrays: flatten all non-sample dimensions
            arrays[target_name] = values.reshape(values.shape[0], -1)
        elif block.values.numel() == 1:
            info[target_name] = block.values.item()
        else:
            # squeeze the sample dimension, which corresponds to the system
            info[target_name] = block.values.detach().cpu().numpy().squeeze(0)

    def _store_gradient(
        self,
        system: System,
        target_name: str,
        gradient_name: str,
        gradient_block,
        info: Dict,
        arrays: Dict,
    ):
        """Store one gradient block as forces, as virial + stress, or as-is.

        The external name is obtained from ``to_external_name`` using the model
        capabilities; names containing "forces" are negated and stored in
        ``arrays``, names containing "virial" produce both a virial and a stress
        entry in ``info``, anything else is stored in ``info`` unchanged.
        """
        if self.capabilities is None:
            raise ValueError(
                "model capabilities are required to write gradients to xyz files"
            )
        internal_name = f"{target_name}_{gradient_name}_gradients"
        external_name = to_external_name(internal_name, self.capabilities.outputs)

        # we assume that gradients are always an array, never a scalar;
        # squeeze the trailing property dimension
        values = gradient_block.values.detach().cpu().squeeze(-1).numpy()

        if "forces" in external_name:
            arrays[external_name] = -values
        elif "virial" in external_name:
            # in this case, we write both the virial and the stress
            if not torch.any(system.cell != 0):
                raise ValueError("stresses cannot be written for non-periodic systems.")
            # the volume is |det(cell)| (as in ASE): the raw determinant is
            # negative for left-handed cells and would flip the stress sign
            cell_volume = abs(torch.det(system.cell).item())
            if cell_volume == 0:
                raise ValueError(
                    "stresses cannot be written for systems with zero volume."
                )
            info[external_name] = -values
            info[external_name.replace("virial", "stress")] = values / cell_volume
        else:
            info[external_name] = values