"""
Optimizer with cross validation score
"""
import numpy as np
from sklearn.model_selection import KFold, ShuffleSplit
from typing import Any, Dict, Tuple
from .base_optimizer import BaseOptimizer
from .optimizer import Optimizer
from .tools import ScatterData
# Registry mapping the user-facing validation method name to the
# scikit-learn splitter class that implements it.
validation_methods = dict([
    ('k-fold', KFold),
    ('shuffle-split', ShuffleSplit),
])
class CrossValidationEstimator(BaseOptimizer):
    r"""
    This class provides an optimizer with cross validation for solving the
    linear :math:`\boldsymbol{A}\boldsymbol{x} = \boldsymbol{y}` problem.
    Cross-validation (CV) scores are calculated by splitting the
    available reference data in multiple different ways. It also produces
    the finalized model (using the full input data) for which the CV score
    is an estimation of its performance.

    Warning
    -------
    Repeatedly setting up a :class:`CrossValidationEstimator` and training
    *without* changing the seed for the random number generator will yield
    identical or correlated results, to avoid this please specify a different
    seed when setting up multiple :class:`CrossValidationEstimator` instances.

    Parameters
    ----------
    fit_data : tuple(numpy.ndarray, numpy.ndarray)
        the first element of the tuple represents the fit matrix `A`
        (`N, M` array) while the second element represents the vector
        of target values `y` (`N` array); here `N` (=rows of `A`,
        elements of `y`) equals the number of target values and `M`
        (=columns of `A`) equals the number of parameters
    fit_method : str
        method to be used for training; possible choices are
        "ardr", "bayesian-ridge", "elasticnet", "lasso", "least-squares",
        "omp", "rfe", "ridge", "split-bregman"
    standardize : bool
        if True the fit matrix and target values are standardized before
        fitting, meaning columns in the fit matrix and the target values are
        rescaled to have a standard deviation of 1.0.
    validation_method : str
        method to use for cross-validation; possible choices are
        "shuffle-split", "k-fold"
    n_splits : int
        number of times the fit data set will be split for the
        cross-validation
    check_condition : bool
        if True the condition number will be checked
        (this can be slightly more time consuming for larger
        matrices)
    seed : int
        seed for pseudo random number generator
    **kwargs
        additional keyword arguments; ``shuffle`` (for "k-fold") and
        ``test_size``/``train_size`` (for "shuffle-split") are handed to the
        data set splitter, everything else is forwarded to the
        :class:`Optimizer` instances used for fitting

    Raises
    ------
    ValueError
        if `validation_method` is not one of the available methods

    Attributes
    ----------
    model
        final model obtained by :func:`train` using all input data
    model_splits : list
        models obtained for the individual splits during :func:`validate`
    scatter_data_train : ScatterData
        contains target and predicted values from each individual
        training set in the cross-validation split
    scatter_data_validation : ScatterData
        contains target and predicted values from each individual
        validation set in the cross-validation split
    """

    def __init__(self,
                 fit_data: Tuple[np.ndarray, np.ndarray],
                 fit_method: str = 'least-squares',
                 standardize: bool = True,
                 validation_method: str = 'k-fold',
                 n_splits: int = 10,
                 check_condition: bool = True,
                 seed: int = 42,
                 **kwargs) -> None:

        super().__init__(fit_data, fit_method, standardize, check_condition, seed)

        if validation_method not in validation_methods:
            msg = ['Validation method not available']
            msg += ['Please choose one of the following:']
            msg += [' * ' + key for key in validation_methods]
            raise ValueError('\n'.join(msg))
        self._validation_method = validation_method
        self._n_splits = n_splits
        self._set_kwargs(kwargs)

        # scikit-learn's KFold refuses a random_state when shuffling is
        # switched off, so only hand over the seed if the splitter shuffles
        # ("shuffle-split" always does; "k-fold" only if shuffle is truthy).
        shuffles = (validation_method != 'k-fold'
                    or self._split_kwargs.get('shuffle', True))
        random_state = seed if shuffles else None

        # data set splitting object
        self._splitter = validation_methods[validation_method](
            n_splits=self.n_splits, random_state=random_state,
            **self._split_kwargs)

        self.scatter_data_train = None
        self.scatter_data_validation = None
        self.model_splits = None

    def train(self) -> None:
        """ Constructs the final model using all input data available. """
        opt = Optimizer((self._A, self._y), self.fit_method,
                        standardize=self.standardize,
                        train_size=1.0,
                        check_condition=self._check_condition,
                        **self._fit_kwargs)
        opt.train()
        self.model = opt.model

    def validate(self) -> None:
        """
        Runs validation.

        One optimizer is trained per split generated by the splitter; the
        resulting models and scatter data are collected in
        :attr:`model_splits`, :attr:`scatter_data_train` and
        :attr:`scatter_data_validation`. Results of earlier calls are
        discarded.
        """
        self.scatter_data_train = ScatterData()
        self.scatter_data_validation = ScatterData()
        self.model_splits = []
        for train_set, test_set in self._splitter.split(self._A):
            opt = Optimizer((self._A, self._y), self.fit_method,
                            standardize=self.standardize,
                            train_set=train_set,
                            test_set=test_set,
                            check_condition=self._check_condition,
                            **self._fit_kwargs)
            opt.train()
            self.model_splits.append(opt.model)
            self.scatter_data_train += opt.scatter_data_train
            self.scatter_data_validation += opt.scatter_data_test

    def _set_kwargs(self, kwargs: dict) -> None:
        """
        Sets up fit_kwargs and split_kwargs.
        Different split methods need different keywords.

        Parameters
        ----------
        kwargs
            keyword arguments to be distributed between the splitter
            (``_split_kwargs``) and the optimizer (``_fit_kwargs``)
        """
        # keywords understood by the respective splitter
        split_keys = {
            'k-fold': ('shuffle',),
            'shuffle-split': ('test_size', 'train_size'),
        }.get(self.validation_method)

        self._fit_kwargs = {}
        self._split_kwargs = {}
        if split_keys is None:
            # unknown method: nothing is distributed (as before)
            return

        if self.validation_method == 'k-fold':
            self._split_kwargs['shuffle'] = True  # default True

        for key, val in kwargs.items():
            if key in split_keys:
                self._split_kwargs[key] = val
            else:
                self._fit_kwargs[key] = val

    @property
    def summary(self) -> Dict[str, Any]:
        """ Comprehensive information about the optimizer """
        info = super().summary

        # add model metrics (only available once the final model is trained)
        if self.model is not None:
            info = {**info, **self.model.to_dict()}

        # Add class specific data
        info['validation_method'] = self.validation_method
        info['n_splits'] = self.n_splits
        info['rmse_train'] = self.rmse_train
        info['rmse_train_final'] = self.rmse_train_final
        info['rmse_train_splits'] = self.rmse_train_splits
        info['rmse_validation'] = self.rmse_validation
        info['R2_validation'] = self.R2_validation
        info['rmse_validation_splits'] = self.rmse_validation_splits
        info['scatter_data_train'] = self.scatter_data_train
        info['scatter_data_validation'] = self.scatter_data_validation

        # add kwargs used for fitting and splitting
        info = {**info, **self._fit_kwargs, **self._split_kwargs}
        return info

    def __repr__(self) -> str:
        kwargs = dict()
        kwargs['fit_method'] = self.fit_method
        kwargs['validation_method'] = self.validation_method
        kwargs['n_splits'] = self.n_splits
        kwargs['seed'] = self.seed
        kwargs = {**kwargs, **self._fit_kwargs, **self._split_kwargs}
        return 'CrossValidationEstimator((A, y), {})'.format(
            ', '.join('{}={}'.format(*kwarg) for kwarg in kwargs.items()))

    @property
    def validation_method(self) -> str:
        """ Validation method name """
        return self._validation_method

    @property
    def n_splits(self) -> int:
        """ Number of splits (folds) used for cross-validation """
        return self._n_splits

    @property
    def parameters_splits(self) -> np.ndarray:
        """ All parameters obtained during cross-validation
        (None if no validation has been run) """
        if self.model_splits is None:
            return None
        return np.array([model.parameters for model in self.model_splits])

    @property
    def n_nonzero_parameters_splits(self) -> np.ndarray:
        """ Number of non-zero parameters for each split
        (None if no validation has been run) """
        if self.model_splits is None:
            return None
        return np.array([np.count_nonzero(p) for p in self.parameters_splits])

    @property
    def rmse_train_final(self) -> float:
        """ Root mean squared error when using the full set of input data
        (None if the final model has not been trained) """
        if self.model is None:
            return None
        return self.model.rmse_train

    @property
    def rmse_train(self) -> float:
        """ Average root mean squared training error obtained during
        cross-validation (None if no validation has been run) """
        if self.model_splits is None:
            return None
        return np.mean(self.rmse_train_splits)

    @property
    def rmse_train_splits(self) -> np.ndarray:
        """ Root mean squared training errors obtained during
        cross-validation (None if no validation has been run) """
        if self.model_splits is None:
            return None
        return np.array([model.rmse_train for model in self.model_splits])

    @property
    def rmse_validation(self) -> float:
        """ Average root mean squared cross-validation error
        (None if no validation has been run) """
        if self.model_splits is None:
            return None
        return np.mean(self.rmse_validation_splits)

    @property
    def R2_validation(self) -> float:
        """ Average R2 score for validation sets
        (None if no validation has been run) """
        if self.model_splits is None:
            return None
        return np.mean([model.R2_test for model in self.model_splits])

    @property
    def rmse_validation_splits(self) -> np.ndarray:
        """ Root mean squared validation errors obtained during
        cross-validation (None if no validation has been run) """
        if self.model_splits is None:
            return None
        return np.array([model.rmse_test for model in self.model_splits])

    @property
    def AIC(self) -> float:
        """ Akaike information criterion (AIC) for the model
        (None if the final model has not been trained) """
        if self.model is None:
            return None
        return self.model.AIC

    @property
    def BIC(self) -> float:
        """ Bayesian information criterion (BIC) for the model
        (None if the final model has not been trained) """
        if self.model is None:
            return None
        return self.model.BIC