# Source code for trainstation.cross_validation

"""
Optimizer with cross validation score
"""

import numpy as np
from sklearn.model_selection import KFold, ShuffleSplit
from typing import Any, Dict, Tuple
from .base_optimizer import BaseOptimizer
from .optimizer import Optimizer
from .tools import ScatterData


# Mapping from the user-facing method name to the scikit-learn splitter class
# that implements it; used by CrossValidationEstimator to build its splitter.
validation_methods = dict([
    ('k-fold', KFold),
    ('shuffle-split', ShuffleSplit),
])


class CrossValidationEstimator(BaseOptimizer):
    r"""
    This class provides an optimizer with cross validation for solving the
    linear :math:`\boldsymbol{A}\boldsymbol{x} = \boldsymbol{y}` problem.
    Cross-validation (CV) scores are calculated by splitting the available
    reference data in multiple different ways.  It also produces the
    finalized model (using the full input data) for which the CV score is an
    estimation of its performance.

    Warning
    -------
    Repeatedly setting up a :class:`CrossValidationEstimator` and training
    *without* changing the seed for the random number generator will yield
    identical or correlated results; to avoid this please specify a different
    seed when setting up multiple :class:`CrossValidationEstimator`
    instances.

    Parameters
    ----------
    fit_data : tuple(numpy.ndarray, numpy.ndarray)
        the first element of the tuple represents the fit matrix `A`
        (`N, M` array) while the second element represents the vector of
        target values `y` (`N` array); here `N` (=rows of `A`, elements
        of `y`) equals the number of target values and `M` (=columns of `A`)
        equals the number of parameters
    fit_method : str
        method to be used for training; possible choice are "ardr",
        "bayesian-ridge", "elasticnet", "lasso", "least-squares", "omp",
        "rfe", "ridge", "split-bregman"
    standardize : bool
        if True the fit matrix and target values are standardized before
        fitting, meaning columns in the fit matrix and the target values are
        rescaled to have a standard deviation of 1.0.
    validation_method : str
        method to use for cross-validation; possible choices are
        "shuffle-split", "k-fold"
    n_splits : int
        number of times the fit data set will be split for the
        cross-validation
    check_condition : bool
        if True the condition number will be checked (this can be slightly
        more time consuming for larger matrices)
    seed : int
        seed for pseudo random number generator

    Attributes
    ----------
    scatter_data_train : ScatterData
        contains target and predicted values from each individual training
        set in the cross-validation split
    scatter_data_validation : ScatterData
        contains target and predicted values from each individual validation
        set in the cross-validation split
    """

    def __init__(self,
                 fit_data: Tuple[np.ndarray, np.ndarray],
                 fit_method: str = 'least-squares',
                 standardize: bool = True,
                 validation_method: str = 'k-fold',
                 n_splits: int = 10,
                 check_condition: bool = True,
                 seed: int = 42,
                 **kwargs) -> None:

        super().__init__(fit_data, fit_method, standardize, check_condition,
                         seed)

        # fail early on an unknown validation method with an actionable
        # message listing the supported choices
        if validation_method not in validation_methods:
            msg = ['Validation method not available']
            msg += ['Please choose one of the following:']
            for key in validation_methods:
                msg += [' * ' + key]
            raise ValueError('\n'.join(msg))
        self._validation_method = validation_method
        self._n_splits = n_splits

        # split remaining kwargs into fit kwargs and splitter kwargs
        self._set_kwargs(kwargs)

        # data set splitting object
        self._splitter = validation_methods[validation_method](
            n_splits=self.n_splits,
            random_state=seed,
            **self._split_kwargs)

        # populated by validate() / train()
        self.scatter_data_train = None
        self.scatter_data_validation = None
        self.model_splits = None
[docs] def train(self) -> None: """ Constructs the final model using all input data available. """ opt = Optimizer((self._A, self._y), self.fit_method, standardize=self.standardize, train_size=1.0, check_condition=self._check_condition, **self._fit_kwargs) opt.train() self.model = opt.model
[docs] def validate(self) -> None: """ Runs validation. """ self.scatter_data_train = ScatterData() self.scatter_data_validation = ScatterData() self.model_splits = [] for train_set, test_set in self._splitter.split(self._A): opt = Optimizer((self._A, self._y), self.fit_method, standardize=self.standardize, train_set=train_set, test_set=test_set, check_condition=self._check_condition, **self._fit_kwargs) opt.train() self.model_splits.append(opt.model) self.scatter_data_train += opt.scatter_data_train self.scatter_data_validation += opt.scatter_data_test
def _set_kwargs(self, kwargs: dict) -> None: """ Sets up fit_kwargs and split_kwargs. Different split methods need different keywords. """ self._fit_kwargs = {} self._split_kwargs = {} if self.validation_method == 'k-fold': self._split_kwargs['shuffle'] = True # default True for key, val in kwargs.items(): if key in ['shuffle']: self._split_kwargs[key] = val else: self._fit_kwargs[key] = val elif self.validation_method == 'shuffle-split': for key, val in kwargs.items(): if key in ['test_size', 'train_size']: self._split_kwargs[key] = val else: self._fit_kwargs[key] = val @property def summary(self) -> Dict[str, Any]: """ Comprehensive information about the optimizer """ info = super().summary # add model metrics info = {**info, **self.model.to_dict()} # Add class specific data info['validation_method'] = self.validation_method info['n_splits'] = self.n_splits info['rmse_train'] = self.rmse_train info['rmse_train_final'] = self.rmse_train_final info['rmse_train_splits'] = self.rmse_train_splits info['rmse_validation'] = self.rmse_validation info['R2_validation'] = self.R2_validation info['rmse_validation_splits'] = self.rmse_validation_splits info['scatter_data_train'] = self.scatter_data_train info['scatter_data_validation'] = self.scatter_data_validation # add kwargs used for fitting and splitting info = {**info, **self._fit_kwargs, **self._split_kwargs} return info def __repr__(self) -> str: kwargs = dict() kwargs['fit_method'] = self.fit_method kwargs['validation_method'] = self.validation_method kwargs['n_splits'] = self.n_splits kwargs['seed'] = self.seed kwargs = {**kwargs, **self._fit_kwargs, **self._split_kwargs} return 'CrossValidationEstimator((A, y), {})'.format( ', '.join('{}={}'.format(*kwarg) for kwarg in kwargs.items())) @property def validation_method(self) -> str: """ Validation method name """ return self._validation_method @property def n_splits(self) -> int: """ Number of splits (folds) used for cross-validation """ return self._n_splits 
@property def parameters_splits(self) -> np.ndarray: """ All parameters obtained during cross-validation """ if self.model_splits is None: return None else: return np.array([model.parameters for model in self.model_splits]) @property def n_nonzero_parameters_splits(self) -> np.ndarray: """ Number of non-zero parameters for each split """ if self.model_splits is None: return None else: return np.array([np.count_nonzero(p) for p in self.parameters_splits]) @property def rmse_train_final(self) -> float: """ Root mean squared error when using the full set of input data """ if self.model is None: return None else: return self.model.rmse_train @property def rmse_train(self) -> float: """ Average root mean squared training error obtained during cross-validation """ if self.model_splits is None: return None else: return np.mean(self.rmse_train_splits) @property def rmse_train_splits(self) -> np.ndarray: """ Root mean squared training errors obtained during cross-validation """ if self.model_splits is None: return None else: return np.array([model.rmse_train for model in self.model_splits]) @property def rmse_validation(self) -> float: """ Average root mean squared cross-validation error """ if self.model_splits is None: return None else: return np.mean(self.rmse_validation_splits) @property def R2_validation(self) -> float: """ Average R2 score for validation sets """ if self.model_splits is None: return None else: return np.mean([model.R2_test for model in self.model_splits]) @property def rmse_validation_splits(self) -> np.ndarray: """ Root mean squared validation errors obtained during cross-validation """ if self.model_splits is None: return None else: return np.array([model.rmse_test for model in self.model_splits]) @property def AIC(self) -> float: """ Akaike information criterion (AIC) for the model """ return self.model.AIC @property def BIC(self) -> float: """ Bayesian information criterion (BIC) for the model """ return self.model.BIC