Source code for smartsim.ml.tf.data

# BSD 2-Clause License
#
# Copyright (c) 2021-2024, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import typing as t

import numpy as np
from tensorflow import keras

from smartsim.ml import DataDownloader

if t.TYPE_CHECKING:
    import numpy.typing as npt


class _TFDataGenerationCommon(DataDownloader, keras.utils.Sequence):
    def __getitem__(
        self, index: int
    ) -> t.Tuple[np.ndarray, np.ndarray]:  # type: ignore[type-arg]
        if len(self) < 1:
            raise ValueError(
                "Not enough samples in generator for one batch. Please "
                "run init_samples() or initialize generator with init_samples=True"
            )
        # Generate indices of the batch
        indices = self.indices[index * self.batch_size : (index + 1) * self.batch_size]

        # Generate data
        xval, yval = self._data_generation(indices)

        if yval is not None:
            return xval, yval

        return xval

    def on_epoch_end(self) -> None:
        """Callback called at the end of each training epoch

        If `self.shuffle` is set to `True`, data is shuffled.
        """
        if self.shuffle:
            np.random.shuffle(self.indices)

    def _data_generation(
        self, indices: "npt.NDArray[t.Any]"
    ) -> t.Tuple["npt.NDArray[t.Any]", "npt.NDArray[t.Any]"]:
        # Initialization
        if self.samples is None:
            raise ValueError("No samples loaded for data generation")

        xval = self.samples[indices]

        if self.need_targets:
            yval = t.cast("npt.NDArray[t.Any]", self.targets)[indices]
            if self.num_classes is not None:
                yval = keras.utils.to_categorical(yval, num_classes=self.num_classes)
        elif self.autoencoding:
            yval = xval
        else:
            return xval  # type: ignore[no-any-return]

        return xval, yval


[docs]class StaticDataGenerator(_TFDataGenerationCommon): """A class to download a dataset from the DB. Details about parameters and features of this class can be found in the documentation of ``DataDownloader``, of which it is just a TensorFlow-specialized sub-class with dynamic=False. """ def __init__(self, **kwargs: t.Any) -> None: dynamic = kwargs.pop("dynamic", False) kwargs["dynamic"] = False super().__init__(**kwargs) if dynamic: msg = ( "Static data generator cannot be started with dynamic=True, " "setting it to False" ) self.log(msg)
[docs]class DynamicDataGenerator(_TFDataGenerationCommon): """A class to download batches from the DB. Details about parameters and features of this class can be found in the documentation of ``DataDownloader``, of which it is just a TensorFlow-specialized sub-class with dynamic=True. """ def __init__(self, **kwargs: t.Any) -> None: dynamic = kwargs.pop("dynamic", True) kwargs["dynamic"] = True super().__init__(**kwargs) if not dynamic: msg = ( "Dynamic data generator cannot be started with dynamic=False," " setting it to True" ) self.log(msg)
[docs] def on_epoch_end(self) -> None: """Callback called at the end of each training epoch Update data (the DB is queried for new batches) and if `self.shuffle` is set to `True`, data is also shuffled. """ self.update_data()