Source code for metatrain.utils.data.writers.writers

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union

import metatensor.torch as mts
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import ModelCapabilities, System


class Writer(ABC):
    """Abstract base class for objects that write model predictions to a file.

    Subclasses implement :py:meth:`write`, which may be called repeatedly, and
    :py:meth:`finish`, which is called once all writes are done.

    :param filename: destination path, stored unchanged as ``self.filename``.
    :param capabilities: optional model capabilities, stored unchanged as
        ``self.capabilities``.
    :param append: stored unchanged as ``self.append``. Presumably selects
        whether to append to an existing file, with ``None`` leaving the choice
        to the subclass (NOTE(review): confirm against the subclasses).
    """

    def __init__(
        self,
        filename: Union[str, Path],
        capabilities: Optional[ModelCapabilities] = None,
        append: Optional[bool] = None,
    ):
        self.filename = filename
        self.capabilities = capabilities
        self.append = append

    @abstractmethod
    def write(self, systems: List[System], predictions: Dict[str, TensorMap]):
        """Write a batch of systems together with their predictions.

        :param systems: the systems of this batch (one call may carry several).
        :param predictions: predictions for the batch, one ``TensorMap`` per
            name (presumably the target/output name; confirm with callers).
        """
        ...

    @abstractmethod
    def finish(self):
        """Called after all :py:meth:`write` calls have been made.

        This method is abstract, so every concrete subclass must define it.
        Writers with nothing to flush or close can implement it as a no-op.
        """
        ...
def _offset_system_labels(labels: Labels, offset: int) -> Labels:
    """Return ``labels`` with ``offset`` added to its ``"system"`` dimension.

    Labels without a ``"system"`` dimension (e.g. gradient samples made only of
    ``"sample"``) and a zero offset are returned unchanged.
    """
    names = labels.names
    if offset == 0 or "system" not in names:
        return labels
    # clone so the labels of the input block are never modified in place
    values = labels.values.clone()
    values[:, names.index("system")] += offset
    return Labels(names, values)


def _offset_block_systems(block: TensorBlock, offset: int) -> TensorBlock:
    """Rebuild ``block`` (and its gradients) with system indices shifted by
    ``offset``.

    Only the ``"system"`` sample dimension is shifted. In gradient blocks the
    first sample dimension is ``"sample"``, a row index into the parent block's
    values, and must stay untouched or the gradient would point at the wrong
    (or a non-existent) row.
    """
    new_block = TensorBlock(
        samples=_offset_system_labels(block.samples, offset),
        components=block.components,
        properties=block.properties,
        values=block.values,
    )
    for parameter, gradient in block.gradients():
        new_block.add_gradient(
            parameter,
            TensorBlock(
                samples=_offset_system_labels(gradient.samples, offset),
                components=gradient.components,
                properties=gradient.properties,
                values=gradient.values,
            ),
        )
    return new_block


def _split_tensormaps(
    systems: List[System],
    batch_predictions: Dict[str, TensorMap],
    istart_system: Optional[int] = 0,
) -> List[Dict[str, TensorMap]]:
    """
    Split batched predictions into one dictionary of predictions per system.

    Every ``TensorMap`` in ``batch_predictions`` is split along its ``"system"``
    sample dimension, using the batch-local indices ``0 .. len(systems) - 1``.
    In the output, the ``"system"`` dimension of both values and gradient
    samples is shifted by ``istart_system``, so that the indices can be global
    across batches.

    :param systems: systems of the batch; only their count is used.
    :param batch_predictions: batched predictions, one ``TensorMap`` per name.
    :param istart_system: offset added to the ``"system"`` indices; ``None`` is
        treated as 0.
    :return: a list with one entry per system, each mapping the same names as
        ``batch_predictions`` to that system's ``TensorMap``.
    """
    offset = 0 if istart_system is None else int(istart_system)

    if len(systems) == 0:
        return []
    if len(batch_predictions) == 0:
        # nothing to split, but keep one (empty) dict per system
        return [{} for _ in systems]

    device = next(iter(batch_predictions.values()))[0].values.device
    selections = [
        Labels("system", torch.tensor([[i]], device=device))
        for i in range(len(systems))
    ]

    split_by_name = {
        name: mts.split(tensormap, "samples", selections)
        for name, tensormap in batch_predictions.items()
    }

    out_tensormaps: List[Dict[str, TensorMap]] = []
    for i in range(len(systems)):
        per_system: Dict[str, TensorMap] = {}
        for name, pieces in split_by_name.items():
            piece = pieces[i]
            per_system[name] = TensorMap(
                keys=piece.keys,
                blocks=[
                    _offset_block_systems(block, offset) for block in piece.blocks()
                ],
            )
        out_tensormaps.append(per_system)
    return out_tensormaps