"""DiskDatasetWriter: write systems and predictions into a zip-archive dataset."""

import zipfile
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union

import metatomic.torch as mta
import numpy as np
import torch
from metatensor.torch import TensorMap
from metatomic.torch import ModelCapabilities, System

from .writers import Writer, _split_tensormaps


class DiskDatasetWriter(Writer):
    """Write (system, predictions) pairs into a zip archive ("disk dataset").

    Each written sample is stored under its own ``"<index>/"`` folder inside
    the zip file: the system as ``system.mta`` (float64, CPU) and each target
    TensorMap as ``<target_name>.mts`` (serialized buffer saved via NumPy).
    """

    def __init__(
        self,
        path: Union[str, Path],
        capabilities: Optional[
            ModelCapabilities
        ] = None,  # unused, but matches base signature
        append: Optional[bool] = False,  # if True, open zip in append mode
    ):
        """Open the zip archive at ``path``.

        :param path: destination zip file.
        :param capabilities: unused; accepted to match the base ``Writer``
            signature.
        :param append: if True, open the archive in append mode and continue
            numbering after the samples already stored in it.
        """
        super().__init__(filename=path, capabilities=capabilities, append=append)
        mode: Literal["w", "a"] = "a" if append else "w"
        self.zip_file = zipfile.ZipFile(path, mode)
        # When appending to an existing archive we must continue the sample
        # numbering where it left off: restarting at 0 would add duplicate
        # member names ("0/system.mta", ...) to the zip.
        if append:
            existing_indices = [
                int(top)
                for top in {
                    name.split("/", 1)[0]
                    for name in self.zip_file.namelist()
                    if "/" in name
                }
                if top.isdigit()
            ]
            self.index = max(existing_indices) + 1 if existing_indices else 0
        else:
            self.index = 0

    def write(self, systems: List[System], predictions: Dict[str, TensorMap]):
        """
        Write each (system, predictions) pair into the zip under a new folder
        "<index>/".

        :param systems: systems to store.
        :param predictions: predictions for all systems in the batch; split
            per-system before writing (unless the batch has a single system).
        """
        if len(systems) == 1:
            # Avoid reindexing samples
            split_predictions = [predictions]
        else:
            split_predictions = _split_tensormaps(
                systems, predictions, istart_system=self.index
            )
        for system, preds in zip(systems, split_predictions):
            # system: stored on CPU in double precision for portability
            with self.zip_file.open(f"{self.index}/system.mta", "w") as f:
                mta.save(f, system.to("cpu").to(torch.float64))
            # each target
            for target_name, tensor_map in preds.items():
                with self.zip_file.open(f"{self.index}/{target_name}.mts", "w") as f:
                    buf = tensor_map.to("cpu").to(torch.float64)
                    # metatensor.torch.save_buffer returns a torch.Tensor buffer
                    buffer = buf.save_buffer()
                    np.save(f, buffer.numpy())
            self.index += 1

    def finish(self):
        """Flush and close the underlying zip archive."""
        self.zip_file.close()