CrossValidationEstimator#
- class trainstation.CrossValidationEstimator(fit_data, fit_method='least-squares', standardize=True, validation_method='k-fold', n_splits=10, check_condition=True, seed=42, **kwargs)#
This class provides an optimizer with cross-validation for solving the linear problem \(\boldsymbol{A}\boldsymbol{x} = \boldsymbol{y}\). Cross-validation (CV) scores are calculated by splitting the available reference data into training and validation sets in multiple different ways. The class also produces the final model, trained on the full input data, whose performance the CV score estimates.
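The idea can be illustrated with a minimal NumPy sketch of k-fold cross-validation for ordinary least squares. This is not the library's implementation; the helper `kfold_cv_least_squares` and the synthetic data are assumptions made for illustration only.

```python
import numpy as np


def kfold_cv_least_squares(A, y, n_splits=10, seed=42):
    """Hypothetical helper: per-split validation RMSEs and a final fit."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(y))          # shuffle the row indices once
    folds = np.array_split(indices, n_splits)  # n_splits nearly equal folds

    rmse_validation_splits = []
    for k in range(n_splits):
        valid = folds[k]
        train = np.concatenate([folds[j] for j in range(n_splits) if j != k])
        # ordinary least-squares fit on the training rows only
        x, *_ = np.linalg.lstsq(A[train], y[train], rcond=None)
        residual = A[valid] @ x - y[valid]
        rmse_validation_splits.append(np.sqrt(np.mean(residual ** 2)))

    # final model: fit once more using every row
    x_final, *_ = np.linalg.lstsq(A, y, rcond=None)
    return np.array(rmse_validation_splits), x_final


# synthetic problem (assumed for illustration): 200 target values, 5 parameters
rng = np.random.default_rng(0)
A = rng.normal(size=(200, 5))
x_true = np.array([1.0, -2.0, 0.5, 0.0, 3.0])
y = A @ x_true + 0.01 * rng.normal(size=200)

rmse_splits, x_final = kfold_cv_least_squares(A, y, n_splits=10, seed=42)
cv_score = rmse_splits.mean()  # average validation RMSE over the splits
```

The parameters in `x_final` come from all rows, while `cv_score` comes only from rows held out during each fit, which is why the latter serves as an estimate of how the former performs on unseen data.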
Warning
Repeatedly setting up a CrossValidationEstimator and training without changing the seed for the random number generator will yield identical or correlated results. To avoid this, specify a different seed when setting up multiple CrossValidationEstimator instances.
- Parameters:
fit_data (tuple(numpy.ndarray, numpy.ndarray)) – the first element of the tuple represents the fit matrix A (array of shape (N, M)) while the second element represents the vector of target values y (array of shape (N,)); here N (=rows of A, elements of y) equals the number of target values and M (=columns of A) equals the number of parameters
fit_method (str) – method to be used for training; possible choices are “ardr”, “bayesian-ridge”, “elasticnet”, “lasso”, “least-squares”, “omp”, “rfe”, “ridge”, “split-bregman”
standardize (bool) – if True, the fit matrix and target values are standardized before fitting, meaning that the columns of the fit matrix and the target values are rescaled to have a standard deviation of 1.0
validation_method (str) – method to use for cross-validation; possible choices are “shuffle-split”, “k-fold”
n_splits (int) – number of times the fit data set will be split for the cross-validation
check_condition (bool) – if True, the condition number will be checked (this can be slightly more time-consuming for larger matrices)
seed (int) – seed for pseudo random number generator
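As a rough illustration of what the standardize option describes, the sketch below rescales the columns of a fit matrix and the target vector to unit standard deviation, fits in the rescaled space, and maps the parameters back. It is a minimal NumPy sketch, not the library's code; the data and variable names are assumptions, and how the library itself transforms parameters back is not stated here.

```python
import numpy as np

# badly scaled columns (synthetic data, assumed for illustration)
rng = np.random.default_rng(1)
A = rng.normal(size=(50, 3)) * np.array([1.0, 100.0, 0.01])
x_true = np.array([2.0, -0.03, 50.0])
y = A @ x_true

# rescale each column of A and the target vector to unit standard deviation
col_std = A.std(axis=0)
y_std = y.std()
A_scaled = A / col_std
y_scaled = y / y_std

# fit in the rescaled space
x_scaled, *_ = np.linalg.lstsq(A_scaled, y_scaled, rcond=None)

# map the parameters back to the original units:
# (A / col_std) @ x_scaled = y / y_std  implies  x = x_scaled * y_std / col_std
x = x_scaled * y_std / col_std
```

Rescaling puts all columns on a comparable footing, which matters most for regularized fit methods where the penalty depends on the magnitude of each parameter.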
- scatter_data_train#
contains target and predicted values from each individual training set in the cross-validation split
- Type:
ScatterData
- scatter_data_validation#
contains target and predicted values from each individual validation set in the cross-validation split
- Type:
ScatterData
- property AIC: float#
Akaike information criterion (AIC) for the model
- property BIC: float#
Bayesian information criterion (BIC) for the model
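For orientation, a commonly used form of the two criteria for a linear model with Gaussian errors is sketched below. Whether the library uses exactly this convention (for example, counting only non-zero parameters) is an assumption; the function `information_criteria` is hypothetical.

```python
import numpy as np


def information_criteria(A, y, x):
    """Hypothetical helper: AIC and BIC assuming Gaussian errors."""
    n_samples = len(y)
    n_parameters = np.count_nonzero(x)  # assumption: count non-zero parameters
    mse = np.mean((A @ x - y) ** 2)     # mean squared error of the model
    aic = n_samples * np.log(mse) + 2 * n_parameters
    bic = n_samples * np.log(mse) + n_parameters * np.log(n_samples)
    return aic, bic


# synthetic example (assumed): 100 target values, 4 parameters, small noise
rng = np.random.default_rng(2)
A = rng.normal(size=(100, 4))
y = A @ np.array([1.0, 2.0, -1.0, 0.5]) + 0.1 * rng.normal(size=100)
x, *_ = np.linalg.lstsq(A, y, rcond=None)

aic, bic = information_criteria(A, y, x)
```

Both criteria reward a low error and penalize the number of parameters; in this form BIC penalizes each parameter by log(N) rather than 2, so it favors sparser models once N exceeds about 7.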
- property R2_validation: float#
Average R2 score for validation sets
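The R2 score (coefficient of determination) for a single validation set can be sketched as follows, using the standard definition; averaging it over the splits is how the description above is read here. The function name `r2_score` and the numbers are illustrative assumptions.

```python
import numpy as np


def r2_score(y_target, y_predicted):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_target - y_predicted) ** 2)
    ss_tot = np.sum((y_target - np.mean(y_target)) ** 2)
    return 1.0 - ss_res / ss_tot


y_target = np.array([1.0, 2.0, 3.0, 4.0])

# a perfect prediction gives 1, predicting the mean everywhere gives 0
r2_perfect = r2_score(y_target, y_target)
r2_mean = r2_score(y_target, np.full(4, y_target.mean()))

# averaging over (here two made-up) validation splits
r2_average = np.mean([r2_perfect, r2_mean])
```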
- property fit_method: str#
Fit method
- get_contributions(A)#
Returns the average contribution for each row of A to the predicted values from each element of the parameter vector.
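One plausible reading of this description is sketched below: each prediction is a sum of terms A[i, j] * x[j], and the magnitude of these terms is averaged over the rows for every parameter. Taking the absolute value and averaging over rows are assumptions, as is the helper name `average_contributions`.

```python
import numpy as np


def average_contributions(A, parameters):
    """Hypothetical helper: mean |A[i, j] * x[j]| over rows i, per parameter j."""
    # element (i, j) is the contribution of parameter j to prediction i
    contributions = A * parameters
    return np.mean(np.abs(contributions), axis=0)


A = np.array([[1.0, 2.0],
              [3.0, -4.0]])
parameters = np.array([0.5, -1.0])

contrib = average_contributions(A, parameters)
```

For this small input the element-wise products are [[0.5, -2.0], [1.5, 4.0]], so the averaged magnitudes are 1.0 for the first parameter and 3.0 for the second.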
- property n_nonzero_parameters: int#
Number of non-zero parameters
- property n_parameters: int#
Number of parameters (=columns in A matrix)
- property n_splits: int#
Number of splits (folds) used for cross-validation
- property n_target_values: int#
Number of target values (=rows in A matrix)
- property parameters_norm: float#
Norm of the parameter vector
- property rmse_train: float#
Average root mean squared training error obtained during cross-validation
- property rmse_train_final: float#
Root mean squared error when using the full set of input data
- property rmse_train_splits: ndarray#
Root mean squared training errors obtained during cross-validation
- property rmse_validation: float#
Average root mean squared cross-validation error
- property rmse_validation_splits: ndarray#
Root mean squared validation errors obtained during cross-validation
- property seed: int#
Seed used to initialize pseudo random number generator
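Why the seed matters can be seen in a small sketch: shuffling row indices with the same seed reproduces the same order, and hence the same splits. The use of NumPy's `default_rng` here is an assumption made for illustration; which random number generator the library uses internally is not specified in this section.

```python
import numpy as np


def shuffled_indices(n_rows, seed):
    """Hypothetical helper: row order used to build the splits."""
    return np.random.default_rng(seed).permutation(n_rows)


first = shuffled_indices(20, seed=42)
second = shuffled_indices(20, seed=42)  # same seed: identical order
third = shuffled_indices(20, seed=43)   # different seed: different order
```

Two estimators built with the same seed would therefore see identical training and validation sets, so their CV scores would not be independent samples.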
- property standardize: bool#
If True, the fit matrix and target values are standardized before fitting
- property summary: Dict[str, Any]#
Comprehensive information about the optimizer
- property validation_method: str#
Validation method name
- write_summary(fname)#
Writes summary dict to file.
- Return type:
None
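The file format used by write_summary is not stated in this section. As a sketch of how a summary dictionary could be persisted, the example below uses JSON; the helper `write_summary_json`, the choice of JSON, and the dictionary keys are all assumptions and do not describe the library's behavior.

```python
import json
import os
import tempfile

import numpy as np


def write_summary_json(summary, fname):
    """Hypothetical helper: dump a summary dict to a JSON file."""
    def to_builtin(value):
        # NumPy arrays and scalars are not JSON serializable as-is
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        return value

    with open(fname, "w") as fobj:
        json.dump({key: to_builtin(val) for key, val in summary.items()},
                  fobj, indent=2)


# made-up summary content for illustration
summary = {
    "fit_method": "least-squares",
    "n_splits": 10,
    "rmse_validation_splits": np.array([0.1, 0.2]),
}

fname = os.path.join(tempfile.mkdtemp(), "summary.json")
write_summary_json(summary, fname)

with open(fname) as fobj:
    loaded = json.load(fobj)
```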