import random
from typing import Dict, List, Optional, Tuple
import metatensor.torch as mts
import numpy as np
import torch
from metatensor.torch import TensorBlock, TensorMap
from metatomic.torch import System, register_autograd_neighbors
from scipy.spatial.transform import Rotation
from . import torch_jit_script_unless_coverage
from .data import TargetInfo
[docs]
def get_random_rotation():
    """
    Draw a random 3D rotation.

    :return: a :class:`scipy.spatial.transform.Rotation` obtained from
        :meth:`Rotation.random`.
    """
    random_rotation = Rotation.random()
    return random_rotation
[docs]
def get_random_inversion():
    """
    Randomly decide whether an inversion should be applied.

    :return: ``1`` (keep orientation) or ``-1`` (invert), picked with
        :func:`random.choice`.
    """
    choices = [1, -1]
    return random.choice(choices)
[docs]
class RotationalAugmenter:
    """
    A class to apply random rotations and inversions to a set of systems and their
    targets.

    :param target_info_dict: A dictionary mapping target names to their corresponding
        :class:`TargetInfo` objects. This is used to determine the type of targets and
        how to apply the augmentations.
    :param extra_data_info_dict: An optional dictionary mapping extra-data names to
        their corresponding :class:`TargetInfo` objects. It plays the same role as
        ``target_info_dict`` for the ``extra_data`` argument of
        :meth:`apply_random_augmentations`. ``None`` is treated as an empty dict.
    :raises ValueError: If a Cartesian target or extra-data entry has more than two
        component axes (``rank > 2``).
    :raises ImportError: If any target or extra-data entry is spherical and the
        ``spherical`` package cannot be imported.
    """

    def __init__(
        self,
        target_info_dict: Dict[str, TargetInfo],
        extra_data_info_dict: Optional[Dict[str, TargetInfo]] = None,
    ):
        if extra_data_info_dict is None:
            extra_data_info_dict = {}
        all_infos = list(target_info_dict.values()) + list(
            extra_data_info_dict.values()
        )

        # Only Cartesian tensors of rank 1 and 2 are rotated by
        # `_apply_random_augmentations`; higher-rank entries would silently be
        # missing from its output, so reject them here (extra data included).
        for info in all_infos:
            if info.is_cartesian and len(info.layout.block(0).components) > 2:
                raise ValueError(
                    "RotationalAugmenter only supports Cartesian targets "
                    "with `rank<=2`."
                )

        self.target_info_dict = target_info_dict
        self.extra_data_info_dict = extra_data_info_dict

        # Wigner D machinery, only set up when something spherical must be rotated
        self.wigner = None
        self.complex_to_real_spherical_harmonics_transforms = {}

        spherical_layouts = [info.layout for info in all_infos if info.is_spherical]
        if len(spherical_layouts) > 0:
            try:
                import spherical
            except ImportError as e:
                # quaternionic is a dependency of spherical
                raise ImportError(
                    "To use spherical targets with nanoPET, please install the "
                    "`spherical` package with `pip install spherical`."
                ) from e

            # a spherical block of degree ell has 2 * ell + 1 entries along its
            # (single) component axis
            largest_l = max(
                (len(block.components[0]) - 1) // 2
                for layout in spherical_layouts
                for block in layout.blocks()
            )
            self.wigner = spherical.Wigner(largest_l)
            for ell in range(largest_l + 1):
                self.complex_to_real_spherical_harmonics_transforms[ell] = (
                    _complex_to_real_spherical_harmonics_transform(ell)
                )

    def apply_random_augmentations(
        self,
        systems: List[System],
        targets: Dict[str, TensorMap],
        extra_data: Optional[Dict[str, TensorMap]] = None,
    ) -> Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]:
        """
        Apply a random augmentation to a number of ``System`` objects and its targets.

        Each system gets its own random rotation multiplied by a random inversion
        sign; the same transformation is applied to that system's entries in every
        target and extra-data ``TensorMap``.

        :param systems: A list of :class:`System` objects to be augmented.
        :param targets: A dictionary mapping target names to their corresponding
            :class:`TensorMap` objects. These are the targets to be augmented.
        :param extra_data: An optional dictionary mapping extra-data names to their
            corresponding :class:`TensorMap` objects, augmented like the targets.
            Every name must be present in the ``extra_data_info_dict`` given at
            construction.
        :return: A tuple containing the augmented systems, targets and extra data
            (the latter is an empty dictionary when ``extra_data`` is ``None``).
        """
        rotations = [get_random_rotation() for _ in range(len(systems))]
        inversions = [get_random_inversion() for _ in range(len(systems))]
        transformations = [
            torch.from_numpy(rotation.as_matrix() * inversion)
            for rotation, inversion in zip(rotations, inversions)
        ]

        wigner_D_matrices: Dict[int, List[torch.Tensor]] = {}
        if self.wigner is not None:
            # scipy quaternions are (x, y, z, w); quaternionic expects (w, x, y, z)
            wigner_D_matrices_complex = [
                self.wigner.D(_scipy_quaternion_to_quaternionic(r.as_quat()))
                for r in rotations
            ]
            tensormap_dicts = (
                [targets] if extra_data is None else [targets, extra_data]
            )
            info_dicts = (
                [self.target_info_dict]
                if extra_data is None
                else [self.target_info_dict, self.extra_data_info_dict]
            )
            for tensormap_dict, info_dict in zip(tensormap_dicts, info_dicts):
                for name in tensormap_dict:
                    tensormap_info = info_dict[name]
                    if not tensormap_info.is_spherical:
                        continue
                    for block in tensormap_info.layout.blocks():
                        ell = (len(block.components[0]) - 1) // 2
                        if ell not in wigner_D_matrices:  # skip if already computed
                            wigner_D_matrices[ell] = (
                                self._compute_real_wigner_D_matrices(
                                    ell, wigner_D_matrices_complex
                                )
                            )

        return _apply_random_augmentations(
            systems, targets, transformations, wigner_D_matrices, extra_data=extra_data
        )

    def _compute_real_wigner_D_matrices(
        self, ell: int, wigner_D_matrices_complex
    ) -> List[torch.Tensor]:
        """
        Build the real-basis Wigner D matrices of degree ``ell``, one per rotation.

        :param ell: the degree of the matrices to build.
        :param wigner_D_matrices_complex: one result of ``self.wigner.D`` per
            rotation, indexed through ``self.wigner.Dindex``.
        :return: a list of ``(2 * ell + 1, 2 * ell + 1)`` tensors built from the
            real part of the basis-changed matrices (float64, from NumPy).
        """
        U = self.complex_to_real_spherical_harmonics_transforms[ell]
        matrices: List[torch.Tensor] = []
        for wigner_D_matrix_complex in wigner_D_matrices_complex:
            wigner_D_matrix = np.zeros(
                (2 * ell + 1, 2 * ell + 1), dtype=np.complex128
            )
            for mp in range(-ell, ell + 1):
                for m in range(-ell, ell + 1):
                    wigner_D_matrix[m + ell, mp + ell] = (
                        wigner_D_matrix_complex[self.wigner.Dindex(ell, m, mp)]
                    ).conj()
            # change of basis from complex to real spherical harmonics; the
            # result must be real up to round-off
            wigner_D_matrix = U.conj() @ wigner_D_matrix @ U.T
            assert np.allclose(wigner_D_matrix.imag, 0.0)
            matrices.append(torch.from_numpy(wigner_D_matrix.real))
        return matrices
def _apply_wigner_D_matrices(
    systems: List[System],
    target_tmap: TensorMap,
    transformations: List[torch.Tensor],
    wigner_D_matrices: Dict[int, List[torch.Tensor]],
) -> TensorMap:
    """
    Rotate (and, where needed, invert) every block of a spherical ``TensorMap``.

    Samples are split per system: ``len(system.positions)`` rows per system when
    the samples contain an ``atom`` dimension, one row per system otherwise. Each
    chunk is multiplied along its component axis by that system's Wigner D matrix
    of degree ``ell``, where ``2 * ell + 1`` is the size of the component axis.
    When the system's transformation has a negative determinant (an inversion),
    the chunk is first multiplied by ``(-1) ** ell * sigma``.

    NOTE(review): ``sigma`` is read from the second key dimension, which is
    assumed to hold the parity (``o3_sigma``) -- confirm against target layouts.

    :return: a new ``TensorMap`` with the same keys, samples, components and
        properties; gradients are not carried over.
    """
    new_blocks: List[TensorBlock] = []
    for key, block in target_tmap.items():
        sigma = int(key[1])
        ell = (len(block.components[0]) - 1) // 2

        values = block.values
        if "atom" in block.samples.names:
            split_values = torch.split(
                values, [len(system.positions) for system in systems]
            )
        else:
            split_values = torch.split(values, [1 for _ in systems])

        rotated_chunks: List[torch.Tensor] = []
        for v, transformation, wigner_D_matrix in zip(
            split_values, transformations, wigner_D_matrices[ell]
        ):
            # no in-place ops below, so `v` does not need to be copied
            new_v = v
            is_inverted = torch.det(transformation) < 0
            if is_inverted:
                new_v = new_v * (-1) ** ell * sigma
            # move the component axis last so a single matmul rotates every
            # sample and property at once, then move it back
            new_v = (new_v.transpose(1, 2) @ wigner_D_matrix.T).transpose(1, 2)
            rotated_chunks.append(new_v)

        new_blocks.append(
            TensorBlock(
                values=torch.cat(rotated_chunks),
                samples=block.samples,
                components=block.components,
                properties=block.properties,
            )
        )
    return TensorMap(
        keys=target_tmap.keys,
        blocks=new_blocks,
    )
@torch_jit_script_unless_coverage  # script for speed
def _apply_random_augmentations(
    systems: List[System],
    targets: Dict[str, TensorMap],
    transformations: List[torch.Tensor],
    wigner_D_matrices: Dict[int, List[torch.Tensor]],
    extra_data: Optional[Dict[str, TensorMap]] = None,
) -> Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]:
    """
    Apply one transformation per system to the systems, targets and extra data.

    ``transformations[i]`` is applied to ``systems[i]`` and to the samples of every
    ``TensorMap`` that belong to that system; ``wigner_D_matrices[ell][i]`` is the
    matching matrix for spherical blocks of degree ``ell``. Positions, cells,
    neighbor vectors and position gradients are stored as rows, hence the
    ``@ transformation.T``.

    How each ``TensorMap`` is handled depends on its layout:

    - scalar (single block, no components): values unchanged, ``positions`` and
      ``strain`` gradients transformed, any other gradient dropped;
    - spherical (every block has a single ``o3_mu`` component): transformed by
      ``_apply_wigner_D_matrices``;
    - Cartesian (single block whose first component name contains ``xyz``) of
      rank 1 or 2: rotated directly, gradients dropped.

    Entries matching none of these are left out of the returned dictionaries.

    :param systems: systems to transform.
    :param targets: targets to transform.
    :param transformations: one matrix per system.
    :param wigner_D_matrices: for each degree, one matrix per system.
    :param extra_data: optional extra data, transformed like the targets.
    :return: the new systems, targets and extra data (an empty dict when
        ``extra_data`` is ``None``).
    """
    # Apply the transformations to the systems
    new_systems: List[System] = []
    for system, transformation in zip(systems, transformations):
        new_system = System(
            positions=system.positions @ transformation.T,
            types=system.types,
            cell=system.cell @ transformation.T,
            pbc=system.pbc,
        )
        for options in system.known_neighbor_lists():
            neighbors = mts.detach_block(system.get_neighbor_list(options))
            # Build a fresh block instead of assigning into `neighbors.values`:
            # detached values may share storage with the input system's neighbor
            # list (`Tensor.detach` does not copy), which would then be rotated
            # in place as well.
            rotated_neighbors = TensorBlock(
                values=(neighbors.values.squeeze(-1) @ transformation.T).unsqueeze(
                    -1
                ),
                samples=neighbors.samples,
                components=neighbors.components,
                properties=neighbors.properties,
            )
            # the rotated vectors are consistent with the *new* positions and cell
            register_autograd_neighbors(new_system, rotated_neighbors)
            new_system.add_neighbor_list(options, rotated_neighbors)
        new_systems.append(new_system)

    # Apply the transformation to the targets and extra data
    new_targets: Dict[str, TensorMap] = {}
    new_extra_data: Dict[str, TensorMap] = {}
    for tensormap_dict, new_dict in zip(
        [targets, extra_data], [new_targets, new_extra_data]
    ):
        if tensormap_dict is None:
            continue
        # explicit Optional refinement (presumably needed by TorchScript)
        assert tensormap_dict is not None
        for name, original_tmap in tensormap_dict.items():
            is_scalar = False
            if len(original_tmap.blocks()) == 1:
                if len(original_tmap.block().components) == 0:
                    is_scalar = True
            is_cartesian = False
            if len(original_tmap.blocks()) == 1:
                if len(original_tmap.block().components) > 0:
                    if "xyz" in original_tmap.block().components[0].names[0]:
                        is_cartesian = True
            is_spherical = all(
                len(block.components) == 1 and block.components[0].names == ["o3_mu"]
                for block in original_tmap.blocks()
            )

            if is_scalar:
                # scalar values (e.g. energies) are invariant; only gradients change
                scalar_block = original_tmap.block()
                new_scalar_block = TensorBlock(
                    values=scalar_block.values,
                    samples=scalar_block.samples,
                    components=scalar_block.components,
                    properties=scalar_block.properties,
                )
                if scalar_block.has_gradient("positions"):
                    position_gradient_block = scalar_block.gradient("positions")
                    # one chunk of rows per system, assumed to follow system order
                    split_position_gradients = torch.split(
                        position_gradient_block.values.squeeze(-1),
                        [system.positions.shape[0] for system in systems],
                    )
                    position_gradients = torch.cat(
                        [
                            split_position_gradients[i] @ transformations[i].T
                            for i in range(len(systems))
                        ]
                    )
                    new_scalar_block.add_gradient(
                        "positions",
                        TensorBlock(
                            values=position_gradients.unsqueeze(-1),
                            samples=position_gradient_block.samples,
                            components=position_gradient_block.components,
                            properties=position_gradient_block.properties,
                        ),
                    )
                if scalar_block.has_gradient("strain"):
                    strain_gradient_block = scalar_block.gradient("strain")
                    # rank-2 quantity, one matrix per system: T @ G @ T^T
                    split_strain_gradients = torch.split(
                        strain_gradient_block.values.squeeze(-1), 1
                    )
                    strain_gradients = torch.stack(
                        [
                            transformations[i]
                            @ split_strain_gradients[i].squeeze(0)
                            @ transformations[i].T
                            for i in range(len(systems))
                        ],
                        dim=0,
                    )
                    new_scalar_block.add_gradient(
                        "strain",
                        TensorBlock(
                            values=strain_gradients.unsqueeze(-1),
                            samples=strain_gradient_block.samples,
                            components=strain_gradient_block.components,
                            properties=strain_gradient_block.properties,
                        ),
                    )
                new_dict[name] = TensorMap(
                    keys=original_tmap.keys,
                    blocks=[new_scalar_block],
                )
            elif is_spherical:
                new_dict[name] = _apply_wigner_D_matrices(
                    systems, original_tmap, transformations, wigner_D_matrices
                )
            elif is_cartesian:
                cartesian_block = original_tmap.block()
                rank = len(cartesian_block.components)
                # higher ranks are rejected by RotationalAugmenter's constructor
                # for targets; here they are simply not transformed
                if rank == 1 or rank == 2:
                    if "atom" in cartesian_block.samples.names:
                        split_values = torch.split(
                            cartesian_block.values,
                            [len(system.positions) for system in systems],
                        )
                    else:
                        split_values = torch.split(
                            cartesian_block.values, [1 for _ in systems]
                        )
                    rotated_chunks: List[torch.Tensor] = []
                    for values_i, transformation in zip(split_values, transformations):
                        if rank == 1:
                            # move the xyz axis last, rotate, move it back
                            rotated_chunks.append(
                                (values_i.transpose(1, 2) @ transformation.T).transpose(
                                    1, 2
                                )
                            )
                        else:
                            # T @ X @ T^T for every sample and property
                            rotated_chunks.append(
                                torch.einsum(
                                    "Aa,iabp,bB->iABp",
                                    transformation,
                                    values_i,
                                    transformation.T,
                                )
                            )
                    new_dict[name] = TensorMap(
                        keys=original_tmap.keys,
                        blocks=[
                            TensorBlock(
                                values=torch.cat(rotated_chunks),
                                samples=cartesian_block.samples,
                                components=cartesian_block.components,
                                properties=cartesian_block.properties,
                            )
                        ],
                    )
    return new_systems, new_targets, new_extra_data
def _complex_to_real_spherical_harmonics_transform(ell: int):
# Generates the transformation matrix from complex spherical harmonics
# to real spherical harmonics for a given l.
# Returns a transformation matrix of shape ((2l+1), (2l+1)).
if ell < 0 or not isinstance(ell, int):
raise ValueError("l must be a non-negative integer.")
# The size of the transformation matrix is (2l+1) x (2l+1)
size = 2 * ell + 1
U = np.zeros((size, size), dtype=complex)
for m in range(-ell, ell + 1):
m_index = m + ell # Index in the matrix
if m > 0:
# Real part of Y_{l}^{m}
U[m_index, ell + m] = 1 / np.sqrt(2) * (-1) ** m
U[m_index, ell - m] = 1 / np.sqrt(2)
elif m < 0:
# Imaginary part of Y_{l}^{|m|}
U[m_index, ell + abs(m)] = -1j / np.sqrt(2) * (-1) ** m
U[m_index, ell - abs(m)] = 1j / np.sqrt(2)
else: # m == 0
# Y_{l}^{0} remains unchanged
U[m_index, ell] = 1
return U
def _scipy_quaternion_to_quaternionic(q_scipy):
# This function convert a quaternion obtained from the scipy library to the format
# used by the quaternionic library.
# Note: 'xyzw' is the format used by scipy.spatial.transform.Rotation
# while 'wxyz' is the format used by quaternionic.
qx, qy, qz, qw = q_scipy
q_quaternion = np.array([qw, qx, qy, qz])
return q_quaternion