Optimizer

class trainstation.Optimizer(fit_data, fit_method='least-squares', standardize=True, train_size=0.9, test_size=None, train_set=None, test_set=None, check_condition=True, seed=42, **kwargs)

This optimizer finds a solution to the linear \(\boldsymbol{A}\boldsymbol{x}=\boldsymbol{y}\) problem.
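For the default fit_method='least-squares', solving \(\boldsymbol{A}\boldsymbol{x}=\boldsymbol{y}\) means finding the \(\boldsymbol{x}\) that minimizes the squared residual \(\|\boldsymbol{A}\boldsymbol{x}-\boldsymbol{y}\|_2^2\). The following NumPy sketch illustrates this with numpy.linalg.lstsq on made-up, noise-free data; it is not the code trainstation runs internally.

```python
import numpy as np

# Toy problem: N = 6 target values, M = 2 parameters (made-up data).
A = np.array([[1.0, 0.0],
              [1.0, 1.0],
              [1.0, 2.0],
              [1.0, 3.0],
              [1.0, 4.0],
              [1.0, 5.0]])
x_true = np.array([0.5, 2.0])
y = A @ x_true  # noise-free targets, so the fit should recover x_true

# Least-squares solution: the x that minimizes ||A x - y||_2^2.
x, residuals, rank, singular_values = np.linalg.lstsq(A, y, rcond=None)
```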

One has to specify either train_size/test_size or train_set/test_set. If train_set or test_set (or both) is specified, train_size and test_size are ignored.
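The split rules above can be sketched with a hypothetical helper, split_rows, which is not part of trainstation. It assumes that an explicit index set overrides the sizes, that a missing set is taken as the complement of the given one, that test_size (when given) takes precedence over train_size, and that rows are shuffled by a NumPy generator seeded with seed; the real Optimizer may draw its split differently.

```python
import numpy as np


def split_rows(n_rows, train_size=0.9, test_size=None,
               train_set=None, test_set=None, seed=42):
    """Hypothetical helper mimicking the documented train/test split rules."""
    all_rows = np.arange(n_rows)

    # Explicit index sets take precedence; a missing set becomes the complement (assumption).
    if train_set is not None or test_set is not None:
        if train_set is None:
            train_set = np.setdiff1d(all_rows, test_set)
        if test_set is None:
            test_set = np.setdiff1d(all_rows, train_set)
        return np.asarray(train_set), np.asarray(test_set)

    # test_size, if given, decides the split; otherwise train_size does (assumption).
    size = test_size if test_size is not None else train_size
    # A float is a fraction of the rows, an int an absolute number of rows.
    n = int(round(size * n_rows)) if isinstance(size, float) else int(size)
    n_test = n if test_size is not None else n_rows - n

    # Shuffle the rows reproducibly, then take the first n_test rows as the test set.
    perm = np.random.default_rng(seed).permutation(n_rows)
    return np.sort(perm[n_test:]), np.sort(perm[:n_test])


train, test = split_rows(10)  # 9 training rows, 1 test row
```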

Warning

Repeatedly setting up an Optimizer object and training without changing the seed for the random number generator will yield identical or correlated results. To avoid this, specify a different seed when setting up multiple Optimizer instances.
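The warning follows from how seeded pseudo-random generators behave: the same seed always reproduces the same sequence, and therefore the same shuffle of the rows. The sketch below uses NumPy's default_rng as a stand-in for whatever generator trainstation uses internally; shuffled_rows is a hypothetical helper.

```python
import numpy as np


def shuffled_rows(n_rows, seed):
    """Row order produced by a generator seeded with `seed` (stand-in for the split)."""
    return np.random.default_rng(seed).permutation(n_rows)


# Same seed twice: identical shuffles, hence identical train/test splits.
first = shuffled_rows(20, seed=42)
second = shuffled_rows(20, seed=42)
print(np.array_equal(first, second))  # True

# Distinct seeds (e.g. seed=i for the i-th instance) give different shuffles.
shuffles = [shuffled_rows(20, seed=i) for i in range(3)]
```

When setting up several Optimizer instances, passing seed=i for the i-th instance is one straightforward way to obtain different splits.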

Parameters:
  • fit_data (tuple(numpy.ndarray, numpy.ndarray)) – the first element of the tuple is the fit matrix A (an array of shape (N, M)) and the second element is the vector of target values y (an array of shape (N,)); here N (= rows of A = elements of y) is the number of target values and M (= columns of A) is the number of parameters

  • fit_method (str) – method to be used for training; possible choices are “ardr”, “bayesian-ridge”, “elasticnet”, “lasso”, “least-squares”, “omp”, “rfe”, “ridge”, “split-bregman”

  • standardize (bool) – if True, the fit matrix and target values are standardized before fitting, meaning that the columns of the fit matrix and the target values are rescaled to have a standard deviation of 1.0.

  • train_size (float or int) – If float, represents the fraction of the rows of fit_data to be used for training. If int, represents the absolute number of rows to be used for training.

  • test_size (float or int) – If float, represents the fraction of the rows of fit_data to be used for testing. If int, represents the absolute number of rows to be used for testing.

  • train_set (tuple or list(int)) – indices of rows of A/y to be used for training

  • test_set (tuple or list(int)) – indices of rows of A/y to be used for testing

  • check_condition (bool) – if True, the condition number will be checked (this can be slightly more time-consuming for larger matrices)

  • seed (int) – seed for the pseudo-random number generator
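A NumPy sketch of the two preprocessing steps controlled by check_condition and standardize, on made-up data whose columns have very different scales. It assumes standardization is a pure rescaling (no centering), as the description of standardize suggests, and that the parameters are mapped back to the original scale after fitting; the exact procedure inside trainstation, including any threshold applied to the condition number, is not specified here.

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up fit matrix whose three columns live on very different scales.
A = rng.normal(size=(50, 3)) * np.array([1.0, 10.0, 100.0])
x_true = np.array([1.0, -0.5, 0.02])
y = A @ x_true

# check_condition: a large condition number signals an ill-conditioned fit matrix.
print(f"cond(A) = {np.linalg.cond(A):.1f}")

# standardize: rescale each column of A and the targets y to unit standard deviation.
col_std = A.std(axis=0)
y_std = y.std()
A_s = A / col_std
y_s = y / y_std

# Fit in the rescaled space.
x_s, *_ = np.linalg.lstsq(A_s, y_s, rcond=None)

# Undo the scaling: (A / col_std) x_s = y / y_std  =>  x = x_s * y_std / col_std.
x = x_s * y_std / col_std
```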

scatter_data_train

target and predicted value for each row in the training set

Type:

ScatterData

scatter_data_test

target and predicted value for each row in the test set

Type:

ScatterData
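A minimal sketch of how the two scatter-data attributes relate to a fitted parameter vector, assuming ScatterData behaves like a named tuple with target and predicted fields. The stand-in class, the data, and the train/test indices below are all hypothetical. Plotting predicted against target for each set gives the usual parity plot.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in; the field names of trainstation's ScatterData are assumed.
ScatterData = namedtuple("ScatterData", ["target", "predicted"])

A = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y = np.array([0.1, 1.9, 4.2, 5.8])
x = np.array([0.0, 2.0])  # some fitted parameter vector
train_set = [0, 1, 2]
test_set = [3]

# Pair each row's target value with the model prediction A @ x for that row.
scatter_train = ScatterData(target=y[train_set], predicted=A[train_set] @ x)
scatter_test = ScatterData(target=y[test_set], predicted=A[test_set] @ x)
```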

property AIC: float

Akaike information criterion (AIC) for the model

property BIC: float

Bayesian information criterion (BIC) for the model
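For a linear model with Gaussian errors, both criteria are commonly computed, up to an additive constant, from the mean squared residual (MSE) over n data points and k parameters as AIC = n ln(MSE) + 2k and BIC = n ln(MSE) + k ln(n); lower values indicate a better trade-off between fit quality and model size. The sketch below implements these textbook forms. Which rows trainstation uses for n, and whether k counts all or only non-zero parameters, are assumptions here (all rows passed in, non-zero parameters only).

```python
import numpy as np


def aic_bic(A, y, x):
    """AIC and BIC for a linear model with Gaussian errors (additive constants dropped)."""
    n = len(y)                       # number of data points
    k = np.count_nonzero(x)          # assumption: count only non-zero parameters
    mse = np.mean((A @ x - y) ** 2)  # mean squared residual
    aic = n * np.log(mse) + 2 * k
    bic = n * np.log(mse) + k * np.log(n)
    return aic, bic
```

Since ln(n) > 2 once n > e² ≈ 7.4, BIC penalizes additional parameters more strongly than AIC for all but the smallest data sets.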

property fit_method: str

Fit method

get_contributions(A)

Returns, for each element of the parameter vector, its average contribution to the predicted values, where the average is taken over the rows of A.

Parameters:

A (ndarray) – fit matrix of shape (N, M), where N (= rows of A) equals the number of target values and M (= columns of A) equals the number of parameters

Return type:

ndarray
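One plausible reading of "average contribution" is the mean, over the rows of A, of the absolute contribution \(|A_{ij} x_j|\) of parameter \(j\), which yields one value per parameter. The free function mean_contributions below is a hypothetical sketch of that reading, not the method itself; in particular, taking absolute values is an assumption.

```python
import numpy as np


def mean_contributions(A, x):
    """Mean absolute contribution of each parameter to the predictions (assumed definition)."""
    # Broadcasting multiplies column j of A by x[j]: element (i, j) is A[i, j] * x[j].
    per_row = np.abs(A * x)
    # Average over the rows -> one value per parameter (length M).
    return per_row.mean(axis=0)
```

Parameters with a small entry in the result barely affect any prediction, which can help when pruning a model.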

property n_nonzero_parameters: int

Number of non-zero parameters

property n_parameters: int

Number of parameters (=columns in A matrix)

property n_target_values: int

Number of target values (=rows in A matrix)

property parameters: ndarray

Copy of the parameter vector

property parameters_norm: float

Norm of the parameter vector
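The size and parameter properties above correspond to simple NumPy expressions. In this sketch the fit matrix and the sparse parameter vector are made up, and the Euclidean (2-)norm is assumed for parameters_norm.

```python
import numpy as np

A = np.zeros((8, 5))                      # N = 8 target values, M = 5 parameters
x = np.array([0.0, 1.5, 0.0, -2.0, 0.0])  # e.g. a sparse solution with exact zeros

n_target_values, n_parameters = A.shape   # rows and columns of A
n_nonzero_parameters = np.count_nonzero(x)
parameters_norm = np.linalg.norm(x)       # assumption: Euclidean norm
```

Because parameters returns a copy, modifying the returned array in place does not change the optimizer's stored solution.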

property rmse_test: float

Root mean squared error for the test set

property rmse_train: float

Root mean squared error for the training set
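Both properties apply the same formula, \(\sqrt{\mathrm{mean}((\boldsymbol{A}\boldsymbol{x}-\boldsymbol{y})^2)}\), restricted to the rows of either the training or the test set. The helper rmse below is a hypothetical sketch, shown with made-up data.

```python
import numpy as np


def rmse(A, y, x, rows):
    """Root mean squared error of the predictions A @ x on the selected rows."""
    residuals = A[rows] @ x - y[rows]
    return np.sqrt(np.mean(residuals ** 2))
```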

property seed: int

Seed used to initialize the pseudo-random number generator

property standardize: bool

If True, the fit matrix and target values are standardized before fitting

property summary: Dict[str, Any]

Comprehensive information about the optimizer

property test_fraction: float

Fraction of rows included in the test set

property test_set: List[int]

Indices of rows included in the test set

property test_size: int

Number of rows included in the test set

train()

Carries out training.

Return type:

None

property train_fraction: float

Fraction of rows included in the training set

property train_set: List[int]

Indices of rows included in the training set

property train_size: int

Number of rows included in the training set

write_summary(fname)

Writes the summary dict to the file fname.

Return type:

None
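The file format written by write_summary is not stated here. The sketch below assumes pickle, which, unlike JSON, can also serialize NumPy arrays; the function write_summary_sketch and the dict contents are hypothetical.

```python
import pickle
import tempfile
from pathlib import Path


def write_summary_sketch(summary, fname):
    """Hypothetical stand-in for write_summary: serialize a dict with pickle."""
    with open(fname, "wb") as f:
        pickle.dump(summary, f)


# Made-up summary contents, for illustration only.
summary = {"fit_method": "least-squares", "rmse_train": 0.12, "seed": 42}
fname = Path(tempfile.mkdtemp()) / "summary.pickle"
write_summary_sketch(summary, fname)

# Reading the file back restores an equal dict.
with open(fname, "rb") as f:
    restored = pickle.load(f)
```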