CrossValidationEstimator#

class trainstation.CrossValidationEstimator(fit_data, fit_method='least-squares', standardize=True, validation_method='k-fold', n_splits=10, check_condition=True, seed=42, **kwargs)[source]#

This class provides an optimizer with cross-validation for solving the linear problem \(\boldsymbol{A}\boldsymbol{x} = \boldsymbol{y}\). Cross-validation (CV) scores are calculated by splitting the available reference data into training and validation sets in multiple different ways. The class also produces the final model, trained on the full input data, whose performance the CV score estimates.
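The idea can be sketched with plain NumPy. This is an illustration of k-fold cross-validation for an ordinary least-squares fit, not the library's implementation; the helper names `rmse` and `kfold_cv_least_squares` and the synthetic data are assumptions made for the example.

```python
import numpy as np

def rmse(A, x, y):
    """Root mean squared error of the linear model A @ x against y."""
    return float(np.sqrt(np.mean((A @ x - y) ** 2)))

def kfold_cv_least_squares(A, y, n_splits=10, seed=42):
    """Hypothetical helper: returns (mean validation RMSE, final parameters)."""
    rng = np.random.default_rng(seed)
    # Shuffle the row indices once and cut them into n_splits folds
    folds = np.array_split(rng.permutation(len(y)), n_splits)
    scores = []
    for k, val in enumerate(folds):
        # Train on every fold except the k-th, validate on the k-th
        train = np.concatenate([f for i, f in enumerate(folds) if i != k])
        x_k, *_ = np.linalg.lstsq(A[train], y[train], rcond=None)
        scores.append(rmse(A[val], x_k, y[val]))
    # Final model: fit on all rows; the CV score estimates its performance
    x_final, *_ = np.linalg.lstsq(A, y, rcond=None)
    return float(np.mean(scores)), x_final

# Synthetic data (illustrative only): y = A @ x_true + small noise
rng = np.random.default_rng(0)
A = rng.normal(size=(200, 5))
x_true = np.array([1.0, -2.0, 0.5, 0.0, 3.0])
y = A @ x_true + 0.1 * rng.normal(size=200)

cv_rmse, x_final = kfold_cv_least_squares(A, y)
print(f"CV RMSE: {cv_rmse:.3f}")
```

In terms of the class documented here, the loop corresponds roughly to validate() and the final fit to train().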

Warning

Repeatedly setting up a CrossValidationEstimator and training without changing the seed for the random number generator will yield identical or correlated results. To avoid this, specify a different seed when setting up multiple CrossValidationEstimator instances.
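The effect of a fixed seed can be seen with a generic seeded generator. The sketch below uses NumPy's `default_rng`; the random number machinery used inside the class may differ, and `split_indices` is a hypothetical helper.

```python
import numpy as np

def split_indices(n_rows, n_splits, seed):
    """Hypothetical helper: shuffle row indices and cut them into folds."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n_rows), n_splits)

same_a = split_indices(20, 4, seed=42)
same_b = split_indices(20, 4, seed=42)
other = split_indices(20, 4, seed=7)

# Same seed: every fold contains exactly the same rows
identical = all(np.array_equal(a, b) for a, b in zip(same_a, same_b))
# Different seed: at least one fold differs
differs = any(not np.array_equal(a, b) for a, b in zip(same_a, other))
print(identical, differs)
```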

Parameters:
  • fit_data (tuple(numpy.ndarray, numpy.ndarray)) – the first element of the tuple represents the fit matrix A (N, M array) while the second element represents the vector of target values y (N array); here N (=rows of A, elements of y) equals the number of target values and M (=columns of A) equals the number of parameters

  • fit_method (str) – method to be used for training; possible choices are “ardr”, “bayesian-ridge”, “elasticnet”, “lasso”, “least-squares”, “omp”, “rfe”, “ridge”, “split-bregman”

  • standardize (bool) – if True, the fit matrix and target values are standardized before fitting, meaning that the columns of the fit matrix and the target values are rescaled to have a standard deviation of 1.0

  • validation_method (str) – method to use for cross-validation; possible choices are “shuffle-split”, “k-fold”

  • n_splits (int) – number of times the fit data set will be split for the cross-validation

  • check_condition (bool) – if True, the condition number will be checked (this can be slightly more time consuming for larger matrices)

  • seed (int) – seed for pseudo random number generator
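The standardize and check_condition options can be illustrated with NumPy. This is a sketch under assumptions: it rescales only (no centering, following the description above), uses the 2-norm condition number from `numpy.linalg.cond`, and works on made-up data; the library's internal procedure may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
# Two columns on very different scales (illustrative data)
A = rng.normal(size=(50, 2)) * np.array([1.0, 1000.0])
x_true = np.array([2.0, -0.003])
y = A @ x_true

# Rescale each column of A and the targets to unit standard deviation
col_std = A.std(axis=0)
y_std = y.std()
A_scaled = A / col_std
y_scaled = y / y_std

# Condition number before and after rescaling
cond_raw = np.linalg.cond(A)
cond_scaled = np.linalg.cond(A_scaled)

# Fit in the scaled space, then map the parameters back to original units
x_scaled, *_ = np.linalg.lstsq(A_scaled, y_scaled, rcond=None)
x = x_scaled * y_std / col_std

print(f"cond(A) = {cond_raw:.1f}, cond(A_scaled) = {cond_scaled:.2f}")
```

Because the scaling is undone at the end, `x` refers to the original, unscaled problem.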

scatter_data_train#

contains target and predicted values from each individual training set in the cross-validation split

Type:

ScatterData

scatter_data_validation#

contains target and predicted values from each individual validation set in the cross-validation split

Type:

ScatterData

property AIC: float#

Akaike information criterion (AIC) for the model

property BIC: float#

Bayesian information criterion (BIC) for the model
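The text above does not state the formulas used for these properties. As a rough guide, a common form for Gaussian errors (up to an additive constant) is AIC = N ln(MSE) + 2k and BIC = N ln(MSE) + k ln(N), with N target values and k non-zero parameters. The sketch below implements that textbook form; it is an assumption that the library matches it, and `information_criteria` is a hypothetical helper.

```python
import numpy as np

def information_criteria(A, y, x):
    """Hypothetical helper: Gaussian-error AIC and BIC up to a constant."""
    n = len(y)                          # number of target values
    k = int(np.count_nonzero(x))        # number of non-zero parameters
    mse = float(np.mean((A @ x - y) ** 2))
    aic = n * np.log(mse) + 2 * k
    bic = n * np.log(mse) + k * np.log(n)
    return aic, bic

# Illustrative data: 100 target values, 3 of 4 parameters non-zero
rng = np.random.default_rng(1)
A = rng.normal(size=(100, 4))
x = np.array([1.0, 0.0, -2.0, 0.5])
y = A @ x + 0.1 * rng.normal(size=100)

aic, bic = information_criteria(A, y, x)
print(f"AIC = {aic:.1f}, BIC = {bic:.1f}")
```

In this form the two criteria differ only in the penalty term, so BIC penalizes additional parameters more strongly once ln(N) exceeds 2.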

property R2_validation: float#

Average \(R^2\) score for validation sets
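For reference, the coefficient of determination is conventionally defined as one minus the ratio of the residual sum of squares to the total sum of squares. A minimal sketch of that standard definition follows; `r2_score` is a name chosen for this example, and averaging over validation sets would apply it once per split.

```python
import numpy as np

def r2_score(y_true, y_pred):
    """Standard coefficient of determination, 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return float(1.0 - ss_res / ss_tot)

y_true = np.array([1.0, 2.0, 3.0, 4.0])
r2_perfect = r2_score(y_true, y_true)                       # exact predictions
r2_mean = r2_score(y_true, np.full(4, y_true.mean()))       # always predict the mean
print(r2_perfect, r2_mean)
```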

property fit_method: str#

Fit method

get_contributions(A)#

Returns the average contribution for each row of A to the predicted values from each element of the parameter vector.

Parameters:

A (ndarray) – fit matrix where N (=rows of A, elements of y) equals the number of target values and M (=columns of A) equals the number of parameters

Return type:

ndarray
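One plausible reading of this description is the mean, over the rows of A, of the absolute value of each term A[i, j] * x[j] in the prediction. The sketch below follows that reading; the use of the absolute value and the helper name `contributions` are assumptions, not confirmed by the text above.

```python
import numpy as np

def contributions(A, parameters):
    """Hypothetical helper: mean |A[i, j] * x[j]| over rows i, one value per parameter."""
    return np.mean(np.abs(A * parameters), axis=0)

A = np.array([[1.0, 2.0],
              [3.0, -4.0]])
parameters = np.array([0.5, -1.0])

# Terms are [[0.5, -2.0], [1.5, 4.0]]; the row-averaged magnitudes are 1.0 and 3.0
contrib = contributions(A, parameters)
print(contrib)
```

The result has one entry per column of A, which makes it easy to compare how strongly each parameter influences the predictions.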

property n_nonzero_parameters: int#

Number of non-zero parameters

property n_nonzero_parameters_splits: ndarray#

Number of non-zero parameters for each split

property n_parameters: int#

Number of parameters (=columns in A matrix)

property n_splits: int#

Number of splits (folds) used for cross-validation

property n_target_values: int#

Number of target values (=rows in A matrix)

property parameters: ndarray#

Copy of parameter vector

property parameters_norm: float#

Norm of the parameter vector

property parameters_splits: ndarray#

All parameters obtained during cross-validation

property rmse_train: float#

Average root mean squared training error obtained during cross-validation

property rmse_train_final: float#

Root mean squared error when using the full set of input data

property rmse_train_splits: ndarray#

Root mean squared training errors obtained during cross-validation

property rmse_validation: float#

Average root mean squared cross-validation error

property rmse_validation_splits: ndarray#

Root mean squared validation errors obtained during cross-validation

property seed: int#

Seed used to initialize pseudo random number generator

property standardize: bool#

If True, the fit matrix and target values are standardized before fitting

property summary: Dict[str, Any]#

Comprehensive information about the optimizer

train()[source]#

Constructs the final model using all input data available.

Return type:

None

validate()[source]#

Runs validation.

Return type:

None

property validation_method: str#

Validation method name
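The two validation methods differ in how validation rows are drawn. In k-fold, every row is used for validation exactly once; in shuffle-split, each split draws a fresh random validation set, so rows can repeat or be left out. The NumPy sketch below illustrates the difference; it is not the library's code, and the `test_fraction` argument is a hypothetical name.

```python
import numpy as np

def k_fold(n_rows, n_splits, rng):
    """Yield (train, validation) index arrays; folds partition the rows."""
    folds = np.array_split(rng.permutation(n_rows), n_splits)
    for k, val in enumerate(folds):
        train = np.concatenate([f for i, f in enumerate(folds) if i != k])
        yield train, val

def shuffle_split(n_rows, n_splits, rng, test_fraction=0.1):
    """Yield (train, validation) index arrays from independent shuffles."""
    n_val = max(1, int(round(test_fraction * n_rows)))
    for _ in range(n_splits):
        perm = rng.permutation(n_rows)
        yield perm[n_val:], perm[:n_val]

rng = np.random.default_rng(42)
kf_splits = list(k_fold(20, 5, rng))
ss_splits = list(shuffle_split(20, 5, rng, test_fraction=0.2))

# Collect all validation indices produced by each scheme
kf_val = np.concatenate([val for _, val in kf_splits])
ss_val = np.concatenate([val for _, val in ss_splits])
print(len(np.unique(kf_val)), len(np.unique(ss_val)))
```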

write_summary(fname)#

Writes summary dict to file.

Return type:

None