# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
#
# Copyright (c) 2023 Authors and contributors
# (see the AUTHORS.rst file for the full list of names)
#
# Released under the BSD 3-Clause "New" or "Revised" License
# SPDX-License-Identifier: BSD-3-Clause
"""
Module for splitting lists of :py:class:`TensorMap` objects into multiple
:py:class:`TensorMap` objects along a given axis.
"""
from typing import List, Optional, Tuple, Union
import metatensor
import numpy as np
from metatensor import Labels, TensorMap
def split_data(
tensors: Union[List[TensorMap], TensorMap],
axis: str,
names: Union[List[str], str],
n_groups: int,
group_sizes: Optional[Union[List[int], List[float]]] = None,
seed: Optional[int] = None,
) -> Tuple[List[List[TensorMap]], List[Labels]]:
"""
Splits a list of :py:class:`TensorMap` objects into multiple
:py:class:`TensorMap` objects along a given axis.
For either the "samples" or "properties" `axis`, the unique indices for the
specified metadata `names` are found. If `seed` is set, the indices are
shuffled. Then, they are divided into `n_groups`, where the sizes of the
groups are specified by the `group_sizes` argument.
These grouped indices are then used to split the list of input tensors. The
split tensors, along with the grouped labels, are returned. The tensors are
returned as a list of list of :py:class:`TensorMap` objects.
Each list in the returned :py:class:`list` of :py:class:`list` corresponds
to the split :py:class:`TensorMap` at the same position in the input
`tensors` list. Each nested list contains :py:class:`TensorMap` objects that
share no common indices for the specified `axis` and `names`. However, the
metadata on all other axes (including the keys) will be equivalent.
The passed list of :py:class:`TensorMap` objects in `tensors` must have the
same set of unique indices for the specified `axis` and `names`. For
instance, if passing an input and output tensor for splitting (i.e. as used
in supervised machine learning), the output tensor must have structure
indices 0 -> 10 if the input tensor does.
:param tensors: a single :py:class:`TensorMap` or a :py:class:`list` of
:py:class:`TensorMap` objects, each of which will be split into `n_groups`
new :py:class:`TensorMap` objects.
:param axis: a :py:class:`str` equal to either "samples" or "properties".
This is the axis along which the input :py:class:`TensorMap` objects
will be split.
:param names: a :py:class:`str` or :py:class:`list` of :py:class:`str`
indicating the samples/properties names by which the `tensors` will be
split.
:param n_groups: an :py:class:`int` indicating how many new
:py:class:`TensorMap` objects each of the tensors passed in `tensors`
will be split into. If `group_sizes` is None (default), `n_groups` is
used to split the data into groups of equal size (to the nearest
integer) according to the unique metadata for the specified `axis` and
`names`.
:param group_sizes: an ordered :py:class:`list` of :py:class:`int` or
:py:class:`float` specifying the size of each group that the input
:py:class:`TensorMap` objects will be split into. A :py:class:`list` of
:py:class:`int` is interpreted as absolute group sizes, whereas a
:py:class:`list` of :py:class:`float` is interpreted as relative group
sizes. In the former case, the sum of the list must be less than or
equal to the total number of unique indices present in the input
`tensors` for the chosen `axis` and `names`; in the latter, the sum must
be less than or equal to 1.
:param seed: an :py:class:`int` that seeds the numpy random number
generator. Used to control shuffling of the unique indices, which
dictate the data that ends up in each of the split output tensors. If
None (default), no shuffling of the indices occurs. If an
:py:class:`int` is passed, the indices are shuffled using this value as
the random seed.
:return split_tensors: :py:class:`list` of :py:class:`list` of
:py:class:`TensorMap`. The ``i`` th element in the list contains
`n_groups` :py:class:`TensorMap` objects corresponding to the split
``i`` th :py:class:`TensorMap` of the input list `tensors`.
:return grouped_labels: a :py:class:`list` of :py:class:`Labels` giving the
unique indices, for the specified `axis` and `names`, that are present in
each of the returned groups of :py:class:`TensorMap`. The length of this
list is `n_groups`.
Examples
--------
Split a TensorMap `tensor` into 2 new TensorMaps along the "samples" axis
for the "structure" metadata. Without specifying `group_sizes`, the data
will be split equally by structure index. If the number of unique structure
indices present in the input data is not exactly divisible by `n_groups`,
the group sizes will be rounded to the nearest integer. Without specifying
`seed`, no shuffling of the structure indices will occur and they
will be grouped in lexicographical order. For instance, if the input tensor
has structure indices 0 -> 9 (inclusive), the first new tensor will contain
only structure indices 0 -> 4 (inc.) and the second will contain only 5 -> 9
(inc).
.. code-block:: python
from equisolve.utils import split_data
[[new_tensor_1, new_tensor_2]], grouped_labels = split_data(
tensors=tensor,
axis="samples",
names=["structure"],
n_groups=2,
)
Split 2 tensors corresponding to input and output data into train and test
data, with a relative 80:20 ratio. If both input and output tensors contain
structure indices 0 -> 9 (inclusive), the `in_train` and `out_train` tensors
will contain structure indices 0 -> 7 (inc.) and the `in_test` and
`out_test` tensors will contain structure indices 8 -> 9 (inc.). As we want
to specify relative group sizes, we will pass `group_sizes` as a list of
float. Specifying the `seed` will shuffle the structure indices
before the groups are made.
.. code-block:: python
from equisolve.utils import split_data
[[in_train, in_test], [out_train, out_test]], grouped_labels = split_data(
tensors=[input, output],
axis="samples",
names=["structure"],
n_groups=2, # for train-test split
group_sizes=[0.8, 0.2], # relative; an 80:20 train-test split
seed=100,
)
Split 2 tensors corresponding to input and output data into train, test, and
validation data. If input and output tensors have the same 10 structure
indices, we can split such that the train, test, and val tensors have 7,
2, and 1 structures in each, respectively. We want to specify absolute
group sizes, so we will pass a list of int. Specifying the `seed` will
shuffle the structure indices before they are grouped.
.. code-block:: python
import metatensor
from equisolve.utils import split_data
# Find the unique structure indices in the input tensor
unique_structure_indices = metatensor.unique_metadata(
tensor=input, axis="samples", names=["structure"],
)
# They run from 0 -> 9 (inclusive)
unique_structure_indices
>>> Labels(
[(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,)],
dtype=[('structure', '<i4')]
)
# Verify that the output has the same unique structure indices
assert unique_structure_indices == metatensor.unique_metadata(
tensor=output, axis="samples", names=["structure"],
)
>>> True
# Split the data by structure index, with an absolute split of 7, 2, 1
# for the train, test, and validation tensors, respectively
(
[
[in_train, in_test, in_val],
[out_train, out_test, out_val]
]
), grouped_labels = split_data(
tensors=[input, output],
axis="samples",
names=["structure"],
n_groups=3, # for train-test-validation
group_sizes=[7, 2, 1], # absolute; 7, 2, 1 for train, test, val
seed=100,
)
# Inspect the grouped structure indices
grouped_labels
>>> [
Labels(
[(3,), (7,), (1,), (8,), (0,), (9,), (2,)],
dtype=[('structure', '<i4')]
),
Labels([(4,), (6,)], dtype=[('structure', '<i4')]),
Labels([(5,)], dtype=[('structure', '<i4')]),
]
"""
# Check input args and parse `tensors` and `names` into lists
tensors = [tensors] if isinstance(tensors, TensorMap) else tensors
names = [names] if isinstance(names, str) else names
_check_args(tensors, axis, names, n_groups, group_sizes, seed)
# Get array of unique indices to split by for each tensor in `tensors`
unique_idxs_list = [
metatensor.unique_metadata(tensor, axis, names) for tensor in tensors
]
# Check that the unique indices are equivalent for all input tensors
_check_labels_equivalent(unique_idxs_list)
unique_idxs = unique_idxs_list[0]
# Shuffle the unique indices according to the random seed if specified
if seed is not None:
rng = np.random.default_rng(seed)
shuffled_values = unique_idxs.values.copy()
rng.shuffle(shuffled_values)
unique_idxs = Labels(names=unique_idxs.names, values=shuffled_values)
# Must be at least as many unique indices as groups
n_indices = len(unique_idxs)
if n_indices < n_groups:
raise ValueError(
f"the number of groups specified ({n_groups}) is greater than the"
f" number of unique metadata indices ({n_indices}) for the"
f" chosen axis {axis} and names {names}: {unique_idxs}"
)
# Get group sizes
group_sizes = _get_group_sizes(n_groups, len(unique_idxs), group_sizes)
# The sum of the absolute group sizes must be less than or equal to the
# number of unique indices
if n_indices < sum(group_sizes):
raise ValueError(
f"the sum of the absolute group sizes ({sum(group_sizes)}) is greater than "
f"the number of unique metadata indices ({n_indices}) for the chosen "
f"axis {axis} and names {names}: {unique_idxs}"
)
# Group the indices according to the group sizes
grouped_labels = _group_indices(unique_idxs, group_sizes)
# Split each of the input TensorMaps
split_tensors = []
for tensor in tensors:
split_tensors.append(metatensor.split(tensor, axis, grouped_labels))
return split_tensors, grouped_labels
def _get_group_sizes(
n_groups: int,
n_indices: int,
group_sizes: Optional[Union[List[float], List[int]]] = None,
) -> np.ndarray:
"""
Parses the `group_sizes` arg from :py:func:`split_data` and returns an array
of group sizes in absolute terms. If `group_sizes` is None, the group sizes
returned are (to the nearest integer) evenly distributed across the number
of unique indices; i.e. if there are 12 unique indices (`n_indices=12`), and
`n_groups` is 3, the group sizes returned will be np.array([4, 4, 4]). If
`group_sizes` is specified as a list of floats (i.e. relative sizes, whose
sum is <= 1), the group sizes returned are converted to absolute sizes, i.e.
multiplied by `n_indices`. If `group_sizes` is specified as a list of int,
no conversion is performed. A cascade round is used to make sure that the
group sizes are integers, with the sum of the list preserved and the
rounding error minimized.
:param n_groups: an :py:class:`int`, the number of groups to split the data
into.
:param n_indices: an :py:class:`int`, the number of unique indices present
in the input data for the specified `axis` and `names`.
:param group_sizes: a :py:class:`list` of :py:class:`float` or
:py:class:`int` indicating the absolute or relative group sizes,
respectively.
:return: a :py:class:`numpy.ndarray` of :py:class:`int` indicating the
absolute group sizes.
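A small sketch of the expected behaviour, with the values below worked out
by hand from the rules above rather than taken from the package tests:

.. code-block:: python

    # None: 10 indices split evenly into 3 groups, then cascade-rounded
    _get_group_sizes(n_groups=3, n_indices=10)
    # -> array([3, 4, 3])

    # list of float: relative sizes are scaled by the number of indices
    _get_group_sizes(n_groups=2, n_indices=10, group_sizes=[0.8, 0.2])
    # -> array([8, 2])

    # list of int: absolute sizes are passed through unchanged
    _get_group_sizes(n_groups=2, n_indices=10, group_sizes=[7, 3])
    # -> array([7, 3])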
"""
if group_sizes is None: # equally sized groups
group_sizes = np.array([1 / n_groups] * n_groups) * n_indices
elif np.all([isinstance(size, int) for size in group_sizes]): # absolute
group_sizes = np.array(group_sizes)
else: # relative; list of float
group_sizes = np.array(group_sizes) * n_indices
# The group sizes may not be integers. Use cascade rounding to round them
# all to integers whilst attempting to minimize rounding error.
group_sizes = _cascade_round(group_sizes)
return group_sizes
def _cascade_round(array: np.ndarray) -> np.ndarray:
"""
Given an array of floats that sum to an integer, this rounds the floats
and returns an array of integers with the same sum.
Adapted from https://jsfiddle.net/cd8xqy6e/.
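A minimal hand-worked sketch of the behaviour (the built-in :py:func:`round`
used internally rounds ties to the nearest even integer):

.. code-block:: python

    import numpy as np

    # rounding each element independently would give [3, 3, 3] and lose one
    # count; cascade rounding keeps the total at 10
    _cascade_round(np.array([3.4, 3.3, 3.3]))
    # -> array([3, 4, 3])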
"""
# Check type
if not isinstance(array, np.ndarray):
raise TypeError("must pass `array` as a numpy array.")
# Check sum
mod = np.sum(array) % 1
if not np.isclose(round(mod) - mod, 0):
raise ValueError("elements of `array` must sum to an integer.")
float_tot, integer_tot = 0, 0
rounded_array = []
for element in array:
new_int = round(element + float_tot) - integer_tot
float_tot += element
integer_tot += new_int
rounded_array.append(new_int)
# Check that the sum is preserved
assert round(np.sum(array)) == round(np.sum(rounded_array))
return np.array(rounded_array)
def _group_indices(indices: Labels, group_sizes: List[int]) -> List[Labels]:
"""
Splits `indices` into smaller groups according to the sizes specified in
`group_sizes`, and returns them as a list of :py:class:`Labels` objects.
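A minimal sketch, assuming a hypothetical ``indices`` object with a single
"structure" dimension covering values 0 to 4:

.. code-block:: python

    import numpy as np
    from metatensor import Labels

    indices = Labels(
        names=["structure"], values=np.arange(5, dtype=np.int32).reshape(-1, 1)
    )
    grouped = _group_indices(indices, group_sizes=[3, 2])
    # grouped[0] contains structure indices 0, 1, 2
    # grouped[1] contains structure indices 3, 4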
"""
# Group the indices
grouped_labels_values = []
prev_size = 0
for size in group_sizes:
grouped_labels_values.append(
indices.values[prev_size : prev_size + size].tolist()
)
prev_size += size
return [
Labels(
names=indices.names,
values=np.array(grouped_labels_values[i], dtype=indices.values.dtype),
)
for i in range(len(grouped_labels_values))
]
def _check_labels_equivalent(
labels_list: List[Labels],
):
"""
Checks that all Labels objects in the input list `labels_list` are
equivalent in names and values.
"""
if len(labels_list) <= 1:
return
# Define reference Labels object as the first in the list
ref_labels = labels_list[0]
for label_i in range(1, len(labels_list)):
test_label = labels_list[label_i]
if not np.array_equal(ref_labels, test_label):
raise ValueError(
"Labels objects in `labels_list` are not equivalent:"
f" {ref_labels} != {test_label}"
)
def _check_args(
tensors: List[TensorMap],
axis: str,
names: List[str],
n_groups: int,
group_sizes: Optional[Union[List[float], List[int]]],
seed: Optional[int],
):
"""Checks the input args for :py:func:`split_data`."""
# Check tensors passed as a list
if not isinstance(tensors, list):
raise TypeError(
f"`tensors` must be a list of metatensor `TensorMap`, got {type(tensors)}"
)
# Check all tensors in the list are TensorMaps
for tensor in tensors:
if not isinstance(tensor, TensorMap):
raise TypeError(
"`tensors` must be a list of metatensor `TensorMap`,"
f" got {type(tensors)}"
)
# Check axis
if not isinstance(axis, str):
raise TypeError(f"`axis` must be passed as a str, got {type(axis)}")
if axis not in ["samples", "properties"]:
raise ValueError(
f"`axis` must be passsed as either 'samples' or 'properties', got {axis}"
)
# Check names
if not isinstance(names, list):
raise TypeError(f"`names` must be a list of str, got {type(names)}")
if not all([isinstance(name, str) for name in names]):
raise TypeError(f"`names` must be a list of str, got {type(names)}")
for tensor in tensors:
tmp_names = (
tensor.samples_names if axis == "samples" else tensor.properties_names
)
for name in names:
if name not in tmp_names:
raise ValueError(
f"the passed `TensorMap` objects have {axis} names {tmp_names}"
f" that do not match the one passed in `names` {names}"
)
# Check n_groups
if not isinstance(n_groups, int):
raise TypeError(f"`n_groups` must be passed as an int, got {type(n_groups)}")
if not n_groups > 0:
raise ValueError(f"`n_groups` must be greater than 0, got {n_groups}")
# Check group_sizes
if group_sizes is not None:
if not isinstance(group_sizes, list):
raise TypeError(
"`group_sizes` must be passed as a list of float or int,"
f" got {type(group_sizes)}"
)
if len(group_sizes) != n_groups:
raise ValueError(
"if specifying `group_sizes`, you must pass a list whose"
" number of elements equal to `n_groups`"
)
for size in group_sizes:
if not isinstance(size, (int, float)):
raise TypeError(
"`group_sizes` must be passed as a list of float or int,"
f" got {type(group_sizes)}"
)
if not size > 0:
raise ValueError(
"all elements of `group_sizes` must be greater than 0,"
f" got {group_sizes}"
)
if np.all([isinstance(size, float) for size in group_sizes]):
if np.sum(group_sizes) > 1:
raise ValueError(
"if specifying `group_sizes` as a list of float, the sum of"
" the list must be less than or equal to 1"
)
# Check seed
if seed is not None:
if not isinstance(seed, int):
raise TypeError(f"`seed` must be passed as an `int`, got {type(seed)}")