Source code for metatrain.utils.data.get_dataset

from typing import Dict, List, Tuple

from metatensor.torch import TensorMap
from omegaconf import DictConfig

from .dataset import Dataset, DiskDataset
from .readers import read_extra_data, read_systems, read_targets
from .target_info import TargetInfo


def get_dataset(
    options: DictConfig,
) -> Tuple[Dataset, Dict[str, TargetInfo], Dict[str, TargetInfo]]:
    """
    Gets a dataset given a configuration dictionary.

    The systems and targets in the dataset are read from one or more files, as
    specified in ``options``. If ``options["systems"]["read_from"]`` ends with
    ``.zip``, a :py:class:`DiskDataset` is opened on that path and the targets
    (and optional extra data) are selected from it by name. Otherwise the
    systems, targets and optional extra data are read with their respective
    readers and assembled into an in-memory :py:class:`Dataset`.

    :param options: the configuration options for the dataset. This configuration
        dictionary must contain keys for both the systems and targets in the
        dataset. An optional ``extra_data`` section may also be given.
    :raises ValueError: if a key used for extra data is also used as a target key.
    :returns: A tuple ``(dataset, target_info, extra_data_info)`` with the
        ``Dataset`` object, a ``Dict[str, TargetInfo]`` containing additional
        information (units, physical quantities, ...) on the targets in the
        dataset, and the equivalent ``Dict[str, TargetInfo]`` for the extra data.
        The last dictionary is empty when ``options`` has no ``extra_data``
        section.
    """
    systems_path = options["systems"]["read_from"]
    extra_data_info_dictionary: Dict[str, TargetInfo] = {}

    if systems_path.endswith(".zip"):
        # disk dataset: targets and extra data are selected by name via ``fields``
        extra_data_conf = options.get("extra_data", {})
        # Validate before opening the dataset. Otherwise overlapping names would
        # silently end up as duplicate entries in ``fields``.
        _check_disjoint_keys(options["targets"], extra_data_conf)
        dataset = DiskDataset(
            systems_path,
            fields=[*options["targets"], *extra_data_conf],
        )
        target_info_dictionary = dataset.get_target_info(options["targets"])
        if "extra_data" in options:
            extra_data_info_dictionary = dataset.get_target_info(
                options["extra_data"]
            )
    else:
        systems = read_systems(
            filename=systems_path,
            reader=options["systems"]["reader"],
        )
        targets, target_info_dictionary = read_targets(conf=options["targets"])

        extra_data: Dict[str, List[TensorMap]] = {}
        if "extra_data" in options:
            extra_data, extra_data_info_dictionary = read_extra_data(
                conf=options["extra_data"]
            )
        # Check the keys actually produced by the readers. They all share one
        # namespace in the dict passed to ``Dataset.from_dict``, so an overlap
        # would make extra data silently overwrite a target.
        _check_disjoint_keys(targets, extra_data)

        dataset = Dataset.from_dict({"system": systems, **targets, **extra_data})

    return dataset, target_info_dictionary, extra_data_info_dictionary


def _check_disjoint_keys(target_keys, extra_data_keys) -> None:
    """
    Raise ``ValueError`` if any extra-data key is also a target key.

    :param target_keys: iterable over the target names (e.g. a dict or config
        mapping, whose iteration yields its keys).
    :param extra_data_keys: iterable over the extra-data names.
    :raises ValueError: if the two sets of keys intersect.
    """
    intersecting_keys = set(target_keys) & set(extra_data_keys)
    if intersecting_keys:
        raise ValueError(
            f"Extra data keys {intersecting_keys} overlap with target keys. "
            "Please use unique keys for targets and extra data."
        )