#
# ************************************************************************
#
# ROM Tools and Workflows
# Copyright 2019 National Technology & Engineering Solutions of Sandia,LLC
# (NTESS)
#
# Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.
#
# ROM Tools and Workflows is licensed under BSD-3-Clause terms of use:
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
# IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Questions? Contact Eric Parish (ejparis@sandia.gov)
#
# ************************************************************************
#
'''
Constructing a basis via POD typically entails computing the SVD of a snapshot matrix,
.. math::
\\mathbf{U} ,\\mathbf{\\Sigma} = \\mathrm{svd}(\\mathbf{S})
and then selecting the first :math:`K` left singular vectors (i.e., the first :math:`K`
columns of :math:`\\mathbf{U}`). Typically, :math:`K` is determined through the decay of
the singular values.
The truncater class is desined to truncate a basis.
We provide concrete implementations that truncate based on a specified number
of basis vectors and the decay of the singular values
'''
from typing import Protocol
import numpy as np
import romtools.linalg.linalg as la
[docs]
class LeftSingularVectorTruncater(Protocol):
'''
Interface for the Truncater class.
'''
[docs]
def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray:
'''
Truncate left singular vectors
'''
...
[docs]
class NoOpTruncater():
'''
No op implementation
This class conforms to `LeftSingularVectorTruncater` protocol.
'''
def __init__(self) -> None:
pass
def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray: # pylint: disable=unused-argument
return basis
[docs]
class BasisSizeTruncater():
'''
Truncates to a specified number of singular vectors, as specified in the constructor
This class conforms to `LeftSingularVectorTruncater` protocol.
'''
def __init__(self, basis_dimension: int) -> None:
'''
Constructor for the BasisSizeTruncater class.
Args:
basis_dimension (int): The desired dimension of the truncated basis.
'''
# Check if basis dimension is less than or equal to zero
if basis_dimension <= 0:
raise ValueError('Given basis dimension is <= 0: ', basis_dimension)
self.__basis_dimension = basis_dimension
self.__singular_values = None
[docs]
def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray: # pylint: disable=unused-argument
'''
Truncate the basis based on the specified dimension.
Args:
basis (np.ndarray): The original basis matrix.
singular_values (np.ndarray): The array of singular values associated with the basis matrix.
Returns:
np.ndarray: The truncated basis matrix with the specified dimension.
'''
# Check if basis dimension is larger than array and give error.
self.__singular_values = singular_values
if self.__basis_dimension > np.shape(basis)[1]:
raise ValueError('Given basis dimension is greater than size of basis array: ',
self.__basis_dimension, ' > ', np.shape(basis)[1])
return basis[:, :self.__basis_dimension]
[docs]
def get_energy(self):
'''
Returns:
float: The energy criteria corresponding to the truncated basis
'''
if self.__singular_values is None:
raise ValueError('Error, singular values not yet initialized. Must call truncate before calling get_energy')
energy = np.cumsum(self.__singular_values**2)/(np.sum(self.__singular_values**2) + 1.e-30)
energy = energy[self.__basis_dimension-1]
return energy
[docs]
class EnergyBasedTruncater():
'''
Truncates based on the decay of singular values, i.e., will define :math:`K` to
be the number of singular values such that the cumulative energy retained
is greater than some threshold.
This class conforms to `LeftSingularVectorTruncater` protocol.
'''
def __init__(self, threshold: float) -> None:
'''
Constructor for the EnergyTruncater class.
Args:
threshold (float): The cumulative energy threshold.
'''
self.__energy_threshold_ = threshold
self.__energy = None
[docs]
def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray:
'''
Truncate the basis based on the energy threshold.
Args:
basis (np.ndarray): The original basis matrix.
singular_values (np.ndarray): The array of singular values associated with the basis matrix.
Returns:
np.ndarray: The truncated basis matrix based on the energy threshold.
'''
energy = np.cumsum(singular_values**2)/(np.sum(singular_values**2) + 1.e-30)
print(energy > self.__energy_threshold_)
basis_dimension = la.argmax(energy > self.__energy_threshold_) + 1
self.__energy = energy[basis_dimension - 1]
return basis[:, 0:basis_dimension]
[docs]
def get_energy(self):
'''
Returns:
float: The energy criteria corresponding to the truncated basis
'''
if self.__energy is None:
raise ValueError('Error, energy not yet computed. Must call truncate before calling get_energy')
return self.__energy