class ASEWriter(Writer):
    """Write systems and predictions to an ASE-compatible XYZ file.

    Systems and predictions are buffered in memory by :meth:`write` and only
    written to disk, in a single ``ase.io.write`` call, by :meth:`finish`.
    """

    def __init__(
        self,
        filename: Union[str, Path],
        # Forwarded to the base class. ``finish`` reads ``self.capabilities``
        # when gradients are present (presumably set by the base ``__init__``
        # -- confirm against ``Writer``).
        capabilities: Optional[ModelCapabilities] = None,
        # Not consulted by this writer; kept to match the base signature.
        append: Optional[bool] = False,
    ):
        """
        :param filename: path of the XYZ file to write.
        :param capabilities: model capabilities; ``capabilities.outputs`` is
            used in ``finish`` to map gradient names to external names.
        :param append: accepted for signature compatibility with the base
            class; this writer does not use it.
        """
        super().__init__(filename, capabilities, append)
        self._first = True
        # Buffers filled by ``write`` and consumed by ``finish``.
        self._systems: List[System] = []
        self._preds: List[Dict[str, TensorMap]] = []

    def write(self, systems: List[System], predictions: Dict[str, TensorMap]):
        """
        Accumulate systems and predictions to write them all at once in ``finish``.

        :param systems: systems to buffer; each one is moved to the CPU and
            converted to ``torch.float64`` before being stored.
        :param predictions: predictions for ``systems``; they are split with
            ``_split_tensormaps`` and the resulting pieces are buffered.
        """
        self._systems.extend(
            [system.to("cpu").to(torch.float64) for system in systems]
        )
        self._preds.extend(_split_tensormaps(systems, predictions))

    def finish(self):
        """
        Write all accumulated systems and predictions to the XYZ file.

        Does nothing if no system was accumulated. Per-atom quantities (blocks
        with an ``"atom"`` sample dimension, and forces) are stored in
        ``atoms.arrays``; everything else is stored in ``atoms.info``.

        :raises ValueError: if a prediction ``TensorMap`` does not have exactly
            one block, or if a virial/stress has to be written for a system
            whose cell is all zeros or has zero volume.
        """
        if not self._systems:
            return

        systems = self._systems
        predictions_by_structure = self._preds

        frames = []
        for system, system_predictions in zip(systems, predictions_by_structure):
            info = {}
            arrays = {}
            for target_name, target_map in system_predictions.items():
                if len(target_map.keys) != 1:
                    raise ValueError(
                        "Only single-block `TensorMap`s can be "
                        "written to xyz files for the moment."
                    )
                block = target_map.block()

                if "atom" in block.samples.names:
                    # save inside arrays
                    values = block.values.detach().cpu().numpy()
                    # reshaping because `arrays` only accepts 2D arrays
                    arrays[target_name] = values.reshape(values.shape[0], -1)
                else:
                    # save inside info
                    if block.values.numel() == 1:
                        info[target_name] = block.values.item()
                    else:
                        # squeeze the sample dimension, which corresponds to
                        # the system
                        info[target_name] = (
                            block.values.detach().cpu().numpy().squeeze(0)
                        )

                for gradient_name, gradient_block in block.gradients():
                    # we assume that gradients are always an array, never a scalar
                    internal_name = f"{target_name}_{gradient_name}_gradients"
                    external_name = to_external_name(
                        internal_name, self.capabilities.outputs
                    )

                    if "forces" in external_name:
                        # forces are minus the gradient; squeeze the property
                        # dimension
                        arrays[external_name] = (
                            -gradient_block.values.detach().cpu().squeeze(-1).numpy()
                        )
                    elif "virial" in external_name:
                        # in this case, we write both the virial and the stress
                        external_name_virial = external_name
                        external_name_stress = external_name.replace(
                            "virial", "stress"
                        )
                        # squeeze the property dimension
                        strain_derivatives = (
                            gradient_block.values.detach().cpu().squeeze(-1).numpy()
                        )
                        if not torch.any(system.cell != 0):
                            raise ValueError(
                                "stresses cannot be written for non-periodic systems."
                            )
                        # The determinant is negative for left-handed cells;
                        # the volume is its absolute value. Using the signed
                        # determinant would flip the sign of the stress.
                        cell_volume = abs(torch.det(system.cell).item())
                        if cell_volume == 0:
                            raise ValueError(
                                "stresses cannot be written for "
                                "systems with zero volume."
                            )
                        info[external_name_virial] = -strain_derivatives
                        info[external_name_stress] = strain_derivatives / cell_volume
                    else:
                        # squeeze the property dimension
                        info[external_name] = (
                            gradient_block.values.detach().cpu().squeeze(-1).numpy()
                        )

            atoms = ase.Atoms(
                symbols=system.types.numpy(),
                positions=system.positions.detach().numpy(),
                info=info,
            )

            # assign cell and pbcs: a system counts as periodic as soon as its
            # cell has any non-zero entry
            if torch.any(system.cell != 0):
                atoms.pbc = True
                atoms.cell = system.cell.detach().cpu().numpy()

            # assign arrays
            for array_name, array in arrays.items():
                atoms.arrays[array_name] = array

            frames.append(atoms)

        ase.io.write(self.filename, frames)