romtools.vector_space.utils.truncater
Constructing a basis via POD typically entails computing the SVD of a snapshot matrix, $$ \mathbf{U} ,\mathbf{\Sigma} = \mathrm{svd}(\mathbf{S})$$ and then selecting the first $K$ left singular vectors (i.e., the first $K$ columns of $\mathbf{U}$). Typically, $K$ is determined through the decay of the singular values.
The truncater classes are designed to truncate a basis. We provide concrete implementations that truncate to a specified number of basis vectors or according to the decay of the singular values.
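The POD construction described above can be sketched with plain NumPy. This is a minimal illustration, not the romtools workflow itself: the random snapshot matrix and the choice `K = 2` are assumptions made only for the example.

```python
import numpy as np

# Hypothetical snapshot matrix: 6 degrees of freedom, 4 snapshots.
rng = np.random.default_rng(0)
snapshots = rng.standard_normal((6, 4))

# Thin SVD: the columns of U are the left singular vectors and sigma
# holds the singular values in descending order.
U, sigma, _ = np.linalg.svd(snapshots, full_matrices=False)

K = 2  # illustrative basis size
basis = U[:, :K]  # first K left singular vectors

print(basis.shape)                              # (6, 2)
print(np.allclose(basis.T @ basis, np.eye(K)))  # True: columns are orthonormal
```

The truncaters documented below encapsulate the choice of `K`, so the slicing step does not have to be hard-coded.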
    #
    # ************************************************************************
    #
    #                         ROM Tools and Workflows
    # Copyright 2019 National Technology & Engineering Solutions of Sandia,LLC
    #                              (NTESS)
    #
    # Under the terms of Contract DE-NA0003525 with NTESS, the
    # U.S. Government retains certain rights in this software.
    #
    # ROM Tools and Workflows is licensed under BSD-3-Clause terms of use:
    #
    # Redistribution and use in source and binary forms, with or without
    # modification, are permitted provided that the following conditions
    # are met:
    #
    # 1. Redistributions of source code must retain the above copyright
    # notice, this list of conditions and the following disclaimer.
    #
    # 2. Redistributions in binary form must reproduce the above copyright
    # notice, this list of conditions and the following disclaimer in the
    # documentation and/or other materials provided with the distribution.
    #
    # 3. Neither the name of the copyright holder nor the names of its
    # contributors may be used to endorse or promote products derived
    # from this software without specific prior written permission.
    #
    # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
    # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
    # COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
    # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
    # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
    # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
    # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    # STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
    # IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
    # POSSIBILITY OF SUCH DAMAGE.
    #
    # Questions? Contact Eric Parish (ejparis@sandia.gov)
    #
    # ************************************************************************
    #

    '''
    Constructing a basis via POD typically entails computing the SVD of a snapshot matrix,
    $$ \\mathbf{U} ,\\mathbf{\\Sigma} = \\mathrm{svd}(\\mathbf{S})$$
    and then selecting the first $K$ left singular vectors (i.e., the first $K$
    columns of $\\mathbf{U}$). Typically, $K$ is determined through the decay of
    the singular values.

    The truncater class is designed to truncate a basis.
    We provide concrete implementations that truncate based on a specified number
    of basis vectors and the decay of the singular values.
    '''

    from typing import Protocol
    import numpy as np
    import romtools.linalg.linalg as la


    class LeftSingularVectorTruncater(Protocol):
        '''
        Interface for the Truncater class.
        '''

        def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray:
            '''
            Truncate left singular vectors
            '''
            ...


    class NoOpTruncater():
        '''
        No op implementation

        This class conforms to `LeftSingularVectorTruncater` protocol.
        '''
        def __init__(self) -> None:
            pass

        def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray:  # pylint: disable=unused-argument
            return basis


    class BasisSizeTruncater():
        '''
        Truncates to a specified number of singular vectors, as specified in the constructor

        This class conforms to `LeftSingularVectorTruncater` protocol.
        '''
        def __init__(self, basis_dimension: int) -> None:
            '''
            Constructor for the BasisSizeTruncater class.

            Args:
                basis_dimension (int): The desired dimension of the truncated basis.
            '''
            # Check if basis dimension is less than or equal to zero
            if basis_dimension <= 0:
                raise ValueError('Given basis dimension is <= 0: ', basis_dimension)

            self.__basis_dimension = basis_dimension

        def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray:  # pylint: disable=unused-argument
            '''
            Truncate the basis based on the specified dimension.

            Args:
                basis (np.ndarray): The original basis matrix.
                singular_values (np.ndarray): The array of singular values associated with the basis matrix.

            Returns:
                np.ndarray: The truncated basis matrix with the specified dimension.
            '''
            # Check if basis dimension is larger than array and give error.
            if self.__basis_dimension > np.shape(basis)[1]:
                raise ValueError('Given basis dimension is greater than size of basis array: ',
                                 self.__basis_dimension, ' > ', np.shape(basis)[1])

            return basis[:, :self.__basis_dimension]


    class EnergyBasedTruncater():
        '''
        Truncates based on the decay of singular values, i.e., will define $K$ to
        be the number of singular values such that the cumulative energy retained
        is greater than some threshold.

        This class conforms to `LeftSingularVectorTruncater` protocol.
        '''
        def __init__(self, threshold: float) -> None:
            '''
            Constructor for the EnergyBasedTruncater class.

            Args:
                threshold (float): The cumulative energy threshold.
            '''
            self.__energy_threshold_ = threshold

        def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray:
            '''
            Truncate the basis based on the energy threshold.

            Args:
                basis (np.ndarray): The original basis matrix.
                singular_values (np.ndarray): The array of singular values associated with the basis matrix.

            Returns:
                np.ndarray: The truncated basis matrix based on the energy threshold.
            '''
            energy = np.cumsum(singular_values**2)/np.sum(singular_values**2)
            basis_dimension = la.argmax(energy > self.__energy_threshold_) + 1
            return basis[:, 0:basis_dimension]
class LeftSingularVectorTruncater(Protocol):
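Interface for the truncater classes. A conforming class provides a `truncate(basis, singular_values)` method that returns the truncated left singular vectors.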
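Because the interface is a `typing.Protocol`, a class conforms simply by providing a matching `truncate` method; no inheritance is needed. The sketch below illustrates this with a hypothetical `RelativeToleranceTruncater`, which is not part of romtools, and with a local stand-in for the protocol so that the example runs without the library installed. It assumes the singular values are sorted in descending order.

```python
from typing import Protocol
import numpy as np


class LeftSingularVectorTruncater(Protocol):
    # Local stand-in that mirrors the interface documented above.
    def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray:
        ...


class RelativeToleranceTruncater:
    '''
    Hypothetical truncater (illustration only): keeps the vectors whose
    singular value is at least `tol` times the largest singular value.
    '''
    def __init__(self, tol: float) -> None:
        self._tol = tol

    def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray:
        # Singular values are assumed sorted in descending order, so the
        # retained vectors are the leading columns.
        cutoff = self._tol * singular_values[0]
        keep = int(np.count_nonzero(singular_values >= cutoff))
        return basis[:, :keep]


def apply_truncater(truncater: LeftSingularVectorTruncater,
                    basis: np.ndarray,
                    singular_values: np.ndarray) -> np.ndarray:
    # Accepts any object that structurally matches the protocol.
    return truncater.truncate(basis, singular_values)


basis = np.eye(4)
sigma = np.array([10.0, 1.0, 0.5, 0.001])
truncated = apply_truncater(RelativeToleranceTruncater(1e-2), basis, sigma)
print(truncated.shape)  # (4, 3)
```

Structural typing keeps user-defined truncation rules decoupled from the library: a static type checker can verify conformance without the class importing anything from romtools.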
class NoOpTruncater():
No-op implementation: `truncate` returns the basis unchanged.
This class conforms to the `LeftSingularVectorTruncater` protocol.
class BasisSizeTruncater():
Truncates the basis to the number of singular vectors specified in the constructor.
This class conforms to the `LeftSingularVectorTruncater` protocol.
def __init__(self, basis_dimension: int) -> None
Constructor for the BasisSizeTruncater class.
Arguments:
- basis_dimension (int): The desired dimension of the truncated basis.
Raises:
- ValueError: If basis_dimension is less than or equal to zero.
def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray
Truncate the basis based on the specified dimension.
Arguments:
- basis (np.ndarray): The original basis matrix.
- singular_values (np.ndarray): The array of singular values associated with the basis matrix (not used by this truncater).
Returns:
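np.ndarray: The truncated basis matrix with the specified dimension.
Raises:
- ValueError: If the requested basis dimension is greater than the number of columns of basis.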
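The size-based truncation amounts to a bounds check followed by a column slice. The function below is a stand-in written for illustration, not the romtools class: `truncate_to_size` is a hypothetical name, and it folds the constructor check and the `truncate` check into one place.

```python
import numpy as np


def truncate_to_size(basis: np.ndarray, basis_dimension: int) -> np.ndarray:
    # Stand-in for the documented behaviour (hypothetical helper).
    if basis_dimension <= 0:
        raise ValueError(f'Given basis dimension is <= 0: {basis_dimension}')
    n_columns = basis.shape[1]
    if basis_dimension > n_columns:
        raise ValueError('Given basis dimension is greater than size of basis array: '
                         f'{basis_dimension} > {n_columns}')
    # Keep the leading columns only.
    return basis[:, :basis_dimension]


basis = np.arange(12.0).reshape(4, 3)
small = truncate_to_size(basis, 2)
print(small.shape)  # (4, 2)

try:
    truncate_to_size(basis, 5)  # more columns requested than available
except ValueError as err:
    print('rejected:', err)
```

Note that NumPy basic slicing returns a view, so the truncated array shares memory with the input; take a copy first if the result will be modified in place.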
class EnergyBasedTruncater():
Truncates based on the decay of the singular values: $K$ is the smallest number of singular values for which the cumulative energy retained, $\sum_{i=1}^{K} \sigma_i^2 / \sum_{i} \sigma_i^2$, is greater than the threshold.
This class conforms to the `LeftSingularVectorTruncater` protocol.
def __init__(self, threshold: float) -> None
Constructor for the EnergyBasedTruncater class.
Arguments:
- threshold (float): The cumulative energy threshold, expressed as a fraction of the total energy (the sum of the squared singular values).
def truncate(self, basis: np.ndarray, singular_values: np.ndarray) -> np.ndarray
Truncate the basis based on the energy threshold.
Arguments:
- basis (np.ndarray): The original basis matrix.
- singular_values (np.ndarray): The array of singular values associated with the basis matrix.
Returns:
np.ndarray: The truncated basis matrix based on the energy threshold.
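A small worked example may help make the energy criterion concrete. The helper below, `energy_based_dimension`, is a hypothetical name that reproduces the two-line computation from the source, with `np.argmax` standing in for `romtools.linalg`'s `argmax`; that the two behave the same on a single process is an assumption.

```python
import numpy as np


def energy_based_dimension(singular_values: np.ndarray, threshold: float) -> int:
    # Cumulative fraction of the total energy (sum of squared singular values).
    energy = np.cumsum(singular_values**2) / np.sum(singular_values**2)
    # Index of the first entry above the threshold, converted to a count.
    # np.argmax is used here in place of romtools' la.argmax (assumption).
    return int(np.argmax(energy > threshold)) + 1


sigma = np.array([10.0, 5.0, 1.0, 0.1])
# Cumulative energy fractions: approximately 0.7936, 0.9920, 0.9999, 1.0

print(energy_based_dimension(sigma, 0.99))   # 2
print(energy_based_dimension(sigma, 0.999))  # 3

# If no entry exceeds the threshold, np.argmax of an all-False array is 0,
# so this sketch returns a single basis vector.
print(energy_based_dimension(sigma, 1.5))    # 1
```

The last case suggests keeping the threshold strictly below 1: with the `np.argmax` stand-in, an unreachable threshold silently yields a one-vector basis rather than the full basis.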