import argparse
import itertools
import logging
import time
from pathlib import Path
from typing import Dict, Optional, Union
import numpy as np
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import AtomisticModel
from omegaconf import DictConfig, OmegaConf
from metatrain.cli.formatter import CustomHelpFormatter
from metatrain.utils.data import (
CollateFn,
Dataset,
TargetInfo,
get_dataset,
read_systems,
)
from metatrain.utils.data.writers import (
DiskDatasetWriter,
Writer,
get_writer,
)
from metatrain.utils.devices import pick_devices
from metatrain.utils.errors import ArchitectureError
from metatrain.utils.evaluate_model import evaluate_model
from metatrain.utils.io import load_model
from metatrain.utils.logging import MetricLogger
from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.omegaconf import expand_dataset_config
from metatrain.utils.per_atom import average_by_num_atoms
logger = logging.getLogger(__name__)
def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
"""Add the `eval_model` paramaters to an argparse (sub)-parser"""
if eval_model.__doc__ is not None:
description = eval_model.__doc__.split(r":param")[0]
else:
description = None
# If you change the synopsis of these commands or add new ones adjust the completion
# script at `src/metatrain/share/metatrain-completion.bash`.
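    # A typical invocation of this subcommand (assuming metatrain's `mtt` console
    # script) looks like:
    #   mtt eval model.pt options.yaml -o predictions.xyz -b 16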
parser = subparser.add_parser(
"eval",
description=description,
formatter_class=CustomHelpFormatter,
)
parser.set_defaults(callable="eval_model")
parser.add_argument(
"path",
type=str,
help="Saved exported (.pt) model to be evaluated.",
)
parser.add_argument(
"options",
type=str,
help="Eval options YAML file to define a dataset for evaluation.",
)
parser.add_argument(
"-e",
"--extensions-dir",
type=str,
required=False,
dest="extensions_directory",
default=None,
help=(
"path to a directory containing all extensions required by the exported "
"model"
),
)
parser.add_argument(
"-o",
"--output",
dest="output",
type=str,
required=False,
default="output.xyz",
help="filename of the predictions (default: %(default)s)",
)
parser.add_argument(
"-b",
"--batch-size",
dest="batch_size",
required=False,
type=int,
default=1,
help="batch size for evaluation (default: %(default)s)",
)
parser.add_argument(
"--check-consistency",
dest="check_consistency",
action="store_true",
help="whether to run consistency checks (default: %(default)s)",
)
def _prepare_eval_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for eval_model."""
args.options = OmegaConf.load(args.options)
    # Models for evaluation are already exported, so there is no need to pass
    # the `name` argument.
args.model = load_model(
path=args.__dict__.pop("path"),
extensions_directory=args.__dict__.pop("extensions_directory"),
)
def _eval_targets(
model: Union[AtomisticModel, torch.jit.RecursiveScriptModule],
dataset: Dataset,
options: Dict[str, TargetInfo],
batch_size: int = 1,
check_consistency: bool = False,
writer: Optional[Writer] = None,
) -> None:
"""
Evaluate `model` on `dataset`, accumulate RMSE/MAE, and (if `writer` is provided)
stream or buffer out per-sample writes.
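
    :param model: Exported model to evaluate.
    :param dataset: Dataset of systems (and, optionally, targets) to evaluate on.
    :param options: ``TargetInfo`` objects describing the targets to compute.
    :param batch_size: Batch size used to build the dataloader.
    :param check_consistency: Whether to run consistency checks during evaluation.
    :param writer: Writer used to store the predictions; if ``None``, no predictions
        are written out.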
"""
if len(dataset) == 0:
logging.info("This dataset is empty. No evaluation will be performed.")
return None
# Attach neighbor-lists
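    # (get_system_with_neighbor_lists attaches the lists to each system in place,
    # so the return value can be ignored)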
for sample in dataset:
system = sample["system"]
get_system_with_neighbor_lists(system, get_requested_neighbor_lists(model))
# Infer device/dtype
model_tensor = next(itertools.chain(model.parameters(), model.buffers()))
dtype = model_tensor.dtype
device = pick_devices(architecture_devices=model.capabilities().supported_devices)[
0
]
logging.info(f"Running on device {device} with dtype {dtype}")
model.to(dtype=dtype, device=device)
# DataLoader & metrics setup
if len(dataset) % batch_size != 0:
logging.debug(
f"The dataset size ({len(dataset)}) is not a multiple of the batch size "
f"({batch_size}). {len(dataset) // batch_size} batches will be "
f"constructed with a batch size of {batch_size}, and the last batch will "
f"have a size of {len(dataset) % batch_size}. This might lead to "
"inaccurate average timings."
)
# Create a dataloader
target_keys = list(model.capabilities().outputs.keys())
collate_fn = CollateFn(target_keys=target_keys)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False
)
rmse_acc = RMSEAccumulator()
mae_acc = MAEAccumulator()
# Warm-up
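    # Run a few batches before timing so that one-time costs (e.g. TorchScript
    # compilation, CUDA context creation) do not bias the timings measured below.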
cycled = itertools.cycle(dataloader)
for _ in range(10):
batch = next(cycled)
systems = [s.to(device=device, dtype=dtype) for s in batch[0]]
evaluate_model(
model,
systems,
options,
is_training=False,
check_consistency=check_consistency,
)
total_time = 0.0
timings_per_atom = []
# Main evaluation loop
for batch in dataloader:
systems, batch_targets, _ = batch
systems = [system.to(dtype=dtype, device=device) for system in systems]
batch_targets = {
k: v.to(device=device, dtype=dtype) for k, v in batch_targets.items()
}
start_time = time.time()
batch_predictions = evaluate_model(
model,
systems,
options,
is_training=False,
check_consistency=check_consistency,
)
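        # CUDA kernels are launched asynchronously; synchronize before stopping the
        # timer so that the measured time covers the full model evaluation.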
if torch.cuda.is_available():
torch.cuda.synchronize()
end_time = time.time()
# Update metrics
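        # (predictions and targets are divided by the number of atoms, so the
        # accumulated errors are per-atom errors)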
preds_per_atom = average_by_num_atoms(
batch_predictions, systems, per_structure_keys=[]
)
targ_per_atom = average_by_num_atoms(
batch_targets, systems, per_structure_keys=[]
)
rmse_acc.update(preds_per_atom, targ_per_atom)
mae_acc.update(preds_per_atom, targ_per_atom)
# Write out each sample if a writer is configured
if writer:
writer.write(systems, batch_predictions)
# Timing
time_taken = end_time - start_time
total_time += time_taken
timings_per_atom.append(time_taken / sum(len(system) for system in systems))
# Finish writer
if writer:
writer.finish()
# Finalize metrics and log
rmse_vals = rmse_acc.finalize(not_per_atom=["positions_gradients"])
mae_vals = mae_acc.finalize(not_per_atom=["positions_gradients"])
metrics = {**rmse_vals, **mae_vals}
metric_logger = MetricLogger(
log_obj=logger, dataset_info=model.capabilities(), initial_metrics=metrics
)
metric_logger.log(metrics)
# Log timings
timings_per_atom = np.array(timings_per_atom)
mean_per_atom = np.mean(timings_per_atom)
std_per_atom = np.std(timings_per_atom)
logging.info(
f"evaluation time: {total_time:.2f} s "
f"[{1000.0 * mean_per_atom:.4f} ± "
f"{1000.0 * std_per_atom:.4f} ms per atom]"
)
def eval_model(
model: Union[AtomisticModel, torch.jit.RecursiveScriptModule],
options: DictConfig,
output: Union[Path, str] = "output.xyz",
batch_size: int = 1,
check_consistency: bool = False,
append: Optional[bool] = None,
) -> None:
"""
Evaluate an exported model on a given data set.
    If ``options`` contains a ``targets`` sub-section, RMSE and MAE values will be
    reported. If this sub-section is missing, only an xyz file containing the
    properties the model was trained on is written.
    :param model: Saved model to be evaluated.
    :param options: DictConfig defining the test dataset used for the evaluation.
    :param output: Path to save the predicted values.
    :param batch_size: Batch size to use for the evaluation.
    :param check_consistency: Whether to run consistency checks during model evaluation.
    :param append: If ``True``, open the output file in append mode.
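
    Example (a minimal sketch; the file names are placeholders and the exact
    ``options`` layout depends on your dataset configuration, see
    ``expand_dataset_config`` for the accepted schema)::

        from omegaconf import OmegaConf

        from metatrain.utils.io import load_model

        model = load_model("model.pt")
        options = OmegaConf.create(
            {"systems": {"read_from": "dataset.xyz", "reader": "ase"}}
        )
        eval_model(model, options, output="predictions.xyz", batch_size=8)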
"""
logging.info("Setting up evaluation set.")
output = Path(output) if isinstance(output, str) else output
options_list = expand_dataset_config(options)
for i, options in enumerate(options_list):
idx_suffix = f"_{i}" if len(options_list) > 1 else ""
extra_log_message = f" with index {i}" if len(options_list) > 1 else ""
logging.info(f"Evaluating dataset{extra_log_message}")
filename = f"{output.stem}{idx_suffix}{output.suffix}"
# pick the right writer
writer = get_writer(filename, capabilities=model.capabilities(), append=append)
# build the dataset & target-info
if hasattr(options, "targets"):
eval_dataset, eval_info_dict, _ = get_dataset(options)
eval_systems = (
[d.system for d in eval_dataset]
if not isinstance(writer, DiskDatasetWriter)
else None
)
else:
if isinstance(writer, DiskDatasetWriter):
raise ValueError(
"Writing to DiskDataset is not allowed without explicitly"
" defining targets in the input file."
)
eval_systems = read_systems(
filename=options["systems"]["read_from"],
reader=options["systems"]["reader"],
)
# FIXME: this works only for energy models
eval_targets: Dict[str, TensorMap] = {}
eval_info_dict = {}
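            # request strain gradients only if every system has a non-zero cell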
do_strain_grad = all(
not torch.all(system.cell == 0) for system in eval_systems
)
layout = _get_energy_layout(do_strain_grad) # TODO: layout from the user
for key in model.capabilities().outputs.keys():
eval_info_dict[key] = TargetInfo(
quantity=model.capabilities().outputs[key].quantity,
unit=model.capabilities().outputs[key].unit,
# TODO: allow the user to specify whether per-atom or not
layout=layout,
)
eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets})
# run evaluation & writing
try:
# we always let the writer handle I/O, so we never need return_predictions
# here
_eval_targets(
model=model,
dataset=eval_dataset,
options=eval_info_dict,
batch_size=batch_size,
check_consistency=check_consistency,
writer=writer,
)
except Exception as e:
raise ArchitectureError(f"Evaluation failed: {e}") from e
        # no post-call write_predictions is necessary; the writer has already
        # handled all output
def _get_energy_layout(strain_gradient: bool) -> TensorMap:
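    """Build an empty ``TensorMap`` describing the layout of an energy target.

    The layout contains a single energy block with a ``positions`` gradient
    (forces) and, if ``strain_gradient`` is ``True``, an additional ``strain``
    gradient (virials/stress). The values are empty: only the metadata matters here.
    """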
block = TensorBlock(
# float64: otherwise metatensor can't serialize
values=torch.empty(0, 1, dtype=torch.float64),
samples=Labels(
names=["system"],
values=torch.empty((0, 1), dtype=torch.int32),
),
components=[],
properties=Labels.range("energy", 1),
)
position_gradient_block = TensorBlock(
# float64: otherwise metatensor can't serialize
values=torch.empty(0, 3, 1, dtype=torch.float64),
samples=Labels(
names=["sample", "atom"],
values=torch.empty((0, 2), dtype=torch.int32),
),
components=[
Labels(
names=["xyz"],
values=torch.arange(3, dtype=torch.int32).reshape(-1, 1),
),
],
properties=Labels.range("energy", 1),
)
block.add_gradient("positions", position_gradient_block)
if strain_gradient:
strain_gradient_block = TensorBlock(
# float64: otherwise metatensor can't serialize
values=torch.empty(0, 3, 3, 1, dtype=torch.float64),
samples=Labels(
names=["sample", "atom"],
values=torch.empty((0, 2), dtype=torch.int32),
),
components=[
Labels(
names=["xyz_1"],
values=torch.arange(3, dtype=torch.int32).reshape(-1, 1),
),
Labels(
names=["xyz_2"],
values=torch.arange(3, dtype=torch.int32).reshape(-1, 1),
),
],
properties=Labels.range("energy", 1),
)
block.add_gradient("strain", strain_gradient_block)
energy_layout = TensorMap(
keys=Labels.single(),
blocks=[block],
)
return energy_layout