Source code for metatrain.utils.data.writers.metatensor

from pathlib import Path
from typing import Dict, List, Optional, Union

import metatensor.torch as mts
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import ModelCapabilities, System

from .writers import Writer


class MetatensorWriter(Writer):
    """
    Write predictions to Metatensor files (.mts).

    Predictions passed to :py:meth:`write` are accumulated in memory and only
    written to disk by :py:meth:`finish`, which saves one ``.mts`` file per
    target. The systems passed to :py:meth:`write` are accumulated as well, but
    they are not written to the output files.
    """

    def __init__(
        self,
        filename: Union[str, Path],
        capabilities: Optional[ModelCapabilities] = None,
        append: Optional[bool] = False,  # unused here, but matches base signature
    ):
        super().__init__(filename, capabilities, append)
        # every system seen by ``write``, in order (not used by ``finish``)
        self._systems: List[System] = []
        # one ``{target name: TensorMap}`` dictionary per call to ``write``
        self._preds: List[Dict[str, TensorMap]] = []

    def write(self, systems: List[System], predictions: Dict[str, TensorMap]):
        """
        Accumulate systems and predictions to write them all at once in ``finish``.

        :param systems: systems of the current batch.
        :param predictions: predictions of the current batch, keyed by target
            name.
        """
        # nothing is written yet: all batches are joined in ``finish``
        self._systems.extend(systems)
        self._preds.append(predictions)

    def finish(self):
        """
        Write all accumulated predictions to Metatensor files.

        One file is written per target, named ``<stem>_<target>.mts`` and placed
        in the same directory as ``filename``. Predictions are moved to CPU and
        converted to ``float64`` before saving. Nothing is written if
        :py:meth:`write` was never called.
        """
        if not self._preds:
            # no batches were written: there is nothing to concatenate or save
            return

        # concatenate per-batch TensorMaps into full ones
        predictions = _concatenate_tensormaps(self._preds)

        # write out .mts files (one file per target), next to ``filename``
        # rather than in the current working directory
        path = Path(self.filename)
        for prediction_name, prediction_tmap in predictions.items():
            output_path = path.parent / f"{path.stem}_{prediction_name}.mts"
            mts.save(
                str(output_path),
                prediction_tmap.to("cpu").to(torch.float64),
            )
def _concatenate_tensormaps(
    tensormap_dict_list: List[Dict[str, TensorMap]],
) -> Dict[str, TensorMap]:
    """
    Join a list of per-batch prediction dictionaries into a single dictionary.

    Concatenating TensorMaps is tricky, because the model does not know the
    global index of the systems it is predicting. For example, if a model
    predicts 3 batches of 4 systems each, the "system" sample labels will be
    [0, 1, 2, 3] for each of the three batches. Because of this, a plain join
    would not achieve the desired result ([0, 1, 2, ..., 10, 11]). Here, we fix
    this by offsetting the "system" labels of every batch by the number of
    systems found in all previous batches, and then joining along the samples
    axis.

    :param tensormap_dict_list: one ``{target name: TensorMap}`` dictionary per
        batch, in prediction order. The target names of the first dictionary
        determine which targets are joined.
    :return: ``{target name: TensorMap}`` with all batches joined along the
        samples axis; empty if ``tensormap_dict_list`` is empty.
    :raises ValueError: if a block's samples have no "system" dimension.
    """
    if len(tensormap_dict_list) == 0:
        return {}

    system_offset = 0
    shifted_dicts: List[Dict[str, TensorMap]] = []
    for tensormap_dict in tensormap_dict_list:
        shifted_dicts.append(
            {
                name: _shift_system_labels(tensormap, system_offset)
                for name, tensormap in tensormap_dict.items()
            }
        )
        # the offset must cover every system of this batch, across all targets
        # and all blocks, otherwise labels of consecutive batches collide
        system_offset += _count_systems(tensormap_dict)

    return {
        target: mts.join(
            [shifted[target] for shifted in shifted_dicts],
            axis="samples",
            remove_tensor_name=True,
        )
        for target in shifted_dicts[0].keys()
    }


def _count_systems(tensormap_dict: Dict[str, TensorMap]) -> int:
    """
    Return one plus the largest "system" sample label found in any block of any
    TensorMap in ``tensormap_dict``, or 0 if every block is empty.

    NOTE(review): systems that have no samples in any block cannot be seen here;
    if such systems are the last ones of a batch they are not counted.
    """
    n_systems = 0
    for tensormap in tensormap_dict.values():
        for _key, block in tensormap.items():
            systems = block.samples.column("system")
            # skip empty blocks: taking the max of an empty tensor raises
            if systems.numel() > 0:
                n_systems = max(n_systems, int(systems.max()) + 1)
    return n_systems


def _shift_system_labels(tensormap: TensorMap, offset: int) -> TensorMap:
    """
    Return a new TensorMap equal to ``tensormap`` but with ``offset`` added to
    the "system" dimension of every block's samples. The input is not modified.
    """
    new_blocks = []
    for _key, block in tensormap.items():
        where_system = block.samples.names.index("system")
        # clone: ``samples.values`` is the Labels' own tensor, so shifting it in
        # place would silently modify the caller's predictions
        new_samples_values = block.samples.values.clone()
        new_samples_values[:, where_system] += offset
        new_block = TensorBlock(
            values=block.values,
            samples=Labels(block.samples.names, values=new_samples_values),
            components=block.components,
            properties=block.properties,
        )
        # gradient samples reference rows of the parent block (through their
        # "sample" dimension), so they are attached without any shift
        for gradient_name, gradient_block in block.gradients():
            new_block.add_gradient(gradient_name, gradient_block)
        new_blocks.append(new_block)

    # keys are unchanged, so the original key Labels are reused as-is
    return TensorMap(keys=tensormap.keys, blocks=new_blocks)