"""Optimization methods and configs used by inverse VI workflows."""
import copy
from dataclasses import dataclass
from typing import Optional

import numpy as np
@dataclass
class VIGradientOptimizerConfig:
    """Configuration for gradient-based (first-order) VI optimization."""
gradient_method: str = 'standard'
gradient_norm_tolerance: float = 1e-5
max_iterations: int = 1000
max_log_std_update: float = 0.5
min_variational_std: float = 1e-6
max_variational_std: float = 1e6
@dataclass
class VINewtonOptimizerConfig:
    """Configuration for Newton-type (second-order) VI optimization."""
gradient_norm_tolerance: float = 5e-5
max_iterations: int = 1000
max_log_std_update: float = 0.5
min_variational_std: float = 1e-8
max_variational_std: float = 1e6
newton_metric: str = 'standard'
newton_regularization: float = 1e-2
newton_hessian_type: str = 'diagonal'
@dataclass
class VILegacyLineSearchConfig:
    """Configuration for the legacy ('basic') line search."""
initial_step_size: float = 1e-2
max_step_size: float = np.inf
step_size_growth_factor: float = 1.01
step_size_decay_factor: float = 3.0
max_step_size_decrease_trys: int = 5
relaxation_parameter: float = 2.05
line_search_objective: str = 'elbo'
line_search_sample_growth_factor: float = 1.0
log_std_learning_rate_factor: float = 0.1
@dataclass
class VIStochasticNonmonotoneLineSearchConfig:
    """Configuration for the stochastic nonmonotone (Armijo-type) line search."""
initial_step_size: float = 1e-2
max_step_size: float = np.inf
step_size_growth_factor: float = 1.05
step_size_decay_factor: float = 2.0
max_step_size_decrease_trys: int = 5
relaxation_parameter: float = 3.05
line_search_objective: str = 'elbo'
line_search_nonmonotone_window: int = 5
line_search_armijo_coefficient: float = 1e-6
line_search_uncertainty_sigma: float = 4.0
line_search_sample_growth_factor: float = 1.0
log_std_learning_rate_factor: float = 1.0
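
# Usage sketch (illustrative only; this helper is hypothetical and not called
# anywhere in the module): the configs above are plain dataclasses, so fields
# can be overridden at construction time or copied with updates via
# dataclasses.replace.
def _example_line_search_config() -> None:
    from dataclasses import replace
    config = VIStochasticNonmonotoneLineSearchConfig(initial_step_size=1e-1)
    widened = replace(config, line_search_nonmonotone_window=10)
    assert widened.initial_step_size == 1e-1
    assert widened.line_search_nonmonotone_window == 10
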
def _normalize_optimization_method(optimization_method: str) -> str:
    """Map an optimization-method name to its canonical form."""
method = optimization_method.strip().lower()
if method not in ('gradient', 'newton'):
raise ValueError(
f"Unsupported optimization_method '{optimization_method}'. "
"Supported options are 'gradient' and 'newton'."
)
return method
def _normalize_line_search_objective(line_search_objective: str) -> str:
    """Map a line-search objective name to its canonical form."""
objective = line_search_objective.strip().lower()
if objective not in ('elbo', 'mse'):
raise ValueError(
f"Unsupported line_search_objective '{line_search_objective}'. "
"Supported options are 'elbo' and 'mse'."
)
return objective
def _normalize_line_search_method(line_search_method: str) -> str:
    """Map a line-search method name or alias to its canonical form."""
method = line_search_method.strip().lower()
if method in ('legacy', 'basic'):
return 'legacy'
if method in ('stochastic_nonmonotone', 'stochastic', 'advanced'):
return 'stochastic_nonmonotone'
raise ValueError(
f"Unsupported line_search_method '{line_search_method}'. "
"Supported options are 'legacy' and 'stochastic_nonmonotone'."
)
def _normalize_newton_metric(newton_metric: str) -> str:
    """Map a Newton metric name or alias to its canonical form."""
metric = newton_metric.strip().lower()
if metric in ('standard', 'euclidean'):
return 'standard'
if metric in ('natural',):
return 'natural'
raise ValueError(
f"Unsupported newton_metric '{newton_metric}'. "
"Supported options are 'standard' and 'natural'."
)
def _normalize_newton_hessian_type(newton_hessian_type: str) -> str:
    """Map a Newton Hessian-type name or alias to its canonical form."""
hessian_type = newton_hessian_type.strip().lower()
if hessian_type in ('diagonal', 'diag'):
return 'diagonal'
if hessian_type in ('full', 'dense'):
return 'full'
raise ValueError(
f"Unsupported newton_hessian_type '{newton_hessian_type}'. "
"Supported options are 'diagonal' and 'full'."
)
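
# Illustrative sketch (a hypothetical helper, not part of the module API): the
# normalizers above are whitespace- and case-insensitive and map aliases to
# canonical names; unsupported names raise ValueError listing the options.
def _example_normalization() -> None:
    assert _normalize_newton_hessian_type(' Diag ') == 'diagonal'
    assert _normalize_line_search_method('ADVANCED') == 'stochastic_nonmonotone'
    assert _normalize_newton_metric('euclidean') == 'standard'
    try:
        _normalize_optimization_method('bfgs')
    except ValueError:
        pass  # 'bfgs' is unsupported; only 'gradient' and 'newton' are accepted
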
def _resolve_optimizer_config(optimizer_method: str,
                              optimizer_config,
                              default_gradient_config: Optional[VIGradientOptimizerConfig] = None,
                              default_newton_config: Optional[VINewtonOptimizerConfig] = None):
    """Return (normalized_method, deep-copied config) for the optimizer.

    A None optimizer_config resolves to the method's default config; an
    explicit config must be an instance of the method's expected config type.
    """
normalized_optimizer_method = _normalize_optimization_method(optimizer_method)
gradient_config = copy.deepcopy(default_gradient_config) if default_gradient_config is not None else (
VIGradientOptimizerConfig()
)
newton_config = copy.deepcopy(default_newton_config) if default_newton_config is not None else (
VINewtonOptimizerConfig()
)
optimizer_type_by_method = {
'gradient': VIGradientOptimizerConfig,
'newton': VINewtonOptimizerConfig,
}
default_config_by_method = {
'gradient': gradient_config,
'newton': newton_config,
}
if optimizer_config is None:
resolved_optimizer_config = copy.deepcopy(default_config_by_method[normalized_optimizer_method])
else:
expected_optimizer_type = optimizer_type_by_method[normalized_optimizer_method]
if not isinstance(optimizer_config, expected_optimizer_type):
raise TypeError(
f"optimizer_config for optimizer_method='{normalized_optimizer_method}' "
f"must be of type {expected_optimizer_type.__name__}."
)
resolved_optimizer_config = copy.deepcopy(optimizer_config)
return normalized_optimizer_method, resolved_optimizer_config
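
# Usage sketch (illustrative only; this helper is hypothetical): resolving
# with optimizer_config=None yields a deep copy of the method's default
# config, so callers may mutate the result freely; an explicit config of the
# wrong type raises TypeError.
def _example_resolve_optimizer_config() -> None:
    method, config = _resolve_optimizer_config('Newton', None)
    assert method == 'newton'
    assert isinstance(config, VINewtonOptimizerConfig)
    try:
        _resolve_optimizer_config('newton', VIGradientOptimizerConfig())
    except TypeError:
        pass  # a gradient config cannot be paired with the Newton optimizer
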
def _resolve_line_search_config(line_search_method: str,
                                line_search_config,
                                default_legacy_config: Optional[VILegacyLineSearchConfig] = None,
                                default_stochastic_config: Optional[VIStochasticNonmonotoneLineSearchConfig] = None):
    """Return (normalized_method, deep-copied config) for the line search.

    A None line_search_config resolves to the method's default config; an
    explicit config must be an instance of the method's expected config type.
    """
normalized_line_search_method = _normalize_line_search_method(line_search_method)
legacy_config = (
copy.deepcopy(default_legacy_config)
if default_legacy_config is not None else VILegacyLineSearchConfig()
)
stochastic_config = (
copy.deepcopy(default_stochastic_config)
if default_stochastic_config is not None else VIStochasticNonmonotoneLineSearchConfig()
)
line_search_type_by_method = {
'legacy': VILegacyLineSearchConfig,
'stochastic_nonmonotone': VIStochasticNonmonotoneLineSearchConfig,
}
default_config_by_method = {
'legacy': legacy_config,
'stochastic_nonmonotone': stochastic_config,
}
    if line_search_config is None:
        resolved_line_search_config = copy.deepcopy(default_config_by_method[normalized_line_search_method])
else:
expected_line_search_type = line_search_type_by_method[normalized_line_search_method]
if not isinstance(line_search_config, expected_line_search_type):
raise TypeError(
f"line_search_config for line_search_method='{normalized_line_search_method}' "
f"must be of type {expected_line_search_type.__name__}."
)
resolved_line_search_config = copy.deepcopy(line_search_config)
return normalized_line_search_method, resolved_line_search_config
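
# Usage sketch (illustrative only; this helper is hypothetical): aliases
# resolve before the config lookup, so 'stochastic' and 'advanced' both
# select the nonmonotone line search.
def _example_resolve_line_search_config() -> None:
    method, config = _resolve_line_search_config('advanced', None)
    assert method == 'stochastic_nonmonotone'
    assert isinstance(config, VIStochasticNonmonotoneLineSearchConfig)
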
class SteepestDescentSolver:
    """Generic steepest-descent solver."""

    def step(self, gradient: np.ndarray) -> np.ndarray:
        """Return the descent direction: the gradient with non-finite entries zeroed."""
        return np.nan_to_num(gradient, nan=0.0, posinf=0.0, neginf=0.0)
class NewtonSolver:
"""Generic Newton solver supporting diagonal or dense Hessians."""
def __init__(self, regularization: float, hessian_type: str = 'diagonal'):
self.regularization = regularization
self.hessian_type = _normalize_newton_hessian_type(hessian_type)
    def _project_hessian(self, hessian: np.ndarray) -> np.ndarray:
        """Project a Hessian onto well-conditioned positive-definite matrices.

        Non-finite entries are zeroed, eigenvalues are replaced by
        max(|eigenvalue|, regularization), and a 1D input is treated as the
        diagonal of a diagonal Hessian.
        """
        hessian = np.nan_to_num(hessian, nan=0.0, posinf=0.0, neginf=0.0)
        if hessian.ndim == 1:
            projected_diagonal = np.maximum(np.abs(hessian), self.regularization)
            return np.diag(projected_diagonal)
if hessian.ndim != 2 or hessian.shape[0] != hessian.shape[1]:
raise ValueError("Hessian must be a 1D diagonal or a square 2D matrix.")
sym_hessian = 0.5 * (hessian + hessian.T)
eigenvalues, eigenvectors = np.linalg.eigh(sym_hessian)
projected_eigenvalues = np.maximum(np.abs(eigenvalues), self.regularization)
return (eigenvectors @ np.diag(projected_eigenvalues)) @ eigenvectors.T
    def step(self,
             gradient: np.ndarray,
             hessian: np.ndarray) -> np.ndarray:
        """Return the Newton direction obtained from the projected Hessian."""
        gradient = np.nan_to_num(gradient, nan=0.0, posinf=0.0, neginf=0.0)
        projected_hessian = self._project_hessian(hessian)
        if self.hessian_type == 'full':
            return np.linalg.solve(projected_hessian, gradient)
        # Diagonal mode: solve against the diagonal of the projected Hessian
        # only; the projection guarantees strictly positive diagonal entries.
        return gradient / np.diag(projected_hessian)
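
# End-to-end sketch on a hypothetical quadratic (illustrative only, not part
# of the VI workflows): SteepestDescentSolver returns the sanitized gradient
# itself, while NewtonSolver rescales it by the regularized inverse curvature.
# For an indefinite Hessian the eigenvalue projection flips negative curvature
# to positive, so the Newton direction remains a descent direction.
def _example_solver_step() -> None:
    hessian = np.diag(np.array([4.0, 1.0, 0.25]))
    gradient = np.array([2.0, -1.0, 0.5])
    assert np.array_equal(SteepestDescentSolver().step(gradient), gradient)
    newton_direction = NewtonSolver(regularization=1e-2, hessian_type='full').step(gradient, hessian)
    # For a well-conditioned SPD Hessian the projection is a no-op, so the
    # Newton direction is the coordinate-wise inverse-curvature rescaling.
    assert np.allclose(newton_direction, gradient / np.diag(hessian))
    indefinite = np.array([[0.0, 1.0], [1.0, 0.0]])  # eigenvalues +1 and -1
    projected = NewtonSolver(regularization=1e-2, hessian_type='full')._project_hessian(indefinite)
    assert np.allclose(projected, np.eye(2))  # |eigenvalues| yield the identity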