134b254c5SRichard Tran Mills #include <petscregressor.h>
234b254c5SRichard Tran Mills
/* One-line description printed by PetscInitialize() when the program is run with -help. */
static const char help[] = "Tests some linear PetscRegressor types with different regularizers.\n\n";
434b254c5SRichard Tran Mills
534b254c5SRichard Tran Mills typedef struct _AppCtx {
634b254c5SRichard Tran Mills Mat X; /* Training data */
734b254c5SRichard Tran Mills Vec y; /* Target data */
834b254c5SRichard Tran Mills Vec y_predicted; /* Target data */
934b254c5SRichard Tran Mills Vec coefficients;
1034b254c5SRichard Tran Mills PetscInt N; /* Data size */
1134b254c5SRichard Tran Mills PetscBool flg_string;
1234b254c5SRichard Tran Mills PetscBool flg_ascii;
1334b254c5SRichard Tran Mills PetscBool flg_view_sol;
1434b254c5SRichard Tran Mills PetscBool test_prefix;
1534b254c5SRichard Tran Mills } *AppCtx;
1634b254c5SRichard Tran Mills
DestroyCtx(AppCtx * ctx)1734b254c5SRichard Tran Mills static PetscErrorCode DestroyCtx(AppCtx *ctx)
1834b254c5SRichard Tran Mills {
1934b254c5SRichard Tran Mills PetscFunctionBegin;
2034b254c5SRichard Tran Mills PetscCall(MatDestroy(&(*ctx)->X));
2134b254c5SRichard Tran Mills PetscCall(VecDestroy(&(*ctx)->y));
2234b254c5SRichard Tran Mills PetscCall(VecDestroy(&(*ctx)->y_predicted));
2334b254c5SRichard Tran Mills PetscCall(PetscFree(*ctx));
2434b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS);
2534b254c5SRichard Tran Mills }
2634b254c5SRichard Tran Mills
TestRegressorViews(PetscRegressor regressor,AppCtx ctx)2734b254c5SRichard Tran Mills static PetscErrorCode TestRegressorViews(PetscRegressor regressor, AppCtx ctx)
2834b254c5SRichard Tran Mills {
2934b254c5SRichard Tran Mills PetscRegressorType check_type;
3034b254c5SRichard Tran Mills PetscBool match;
3134b254c5SRichard Tran Mills
3234b254c5SRichard Tran Mills PetscFunctionBegin;
3334b254c5SRichard Tran Mills if (ctx->flg_view_sol) {
3434b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Training target vector is\n"));
3534b254c5SRichard Tran Mills PetscCall(VecView(ctx->y, PETSC_VIEWER_STDOUT_WORLD));
3634b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Predicted values are\n"));
3734b254c5SRichard Tran Mills PetscCall(VecView(ctx->y_predicted, PETSC_VIEWER_STDOUT_WORLD));
3834b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Coefficients are\n"));
3934b254c5SRichard Tran Mills PetscCall(VecView(ctx->coefficients, PETSC_VIEWER_STDOUT_WORLD));
4034b254c5SRichard Tran Mills }
4134b254c5SRichard Tran Mills
4234b254c5SRichard Tran Mills if (ctx->flg_string) {
4334b254c5SRichard Tran Mills PetscViewer stringviewer;
4434b254c5SRichard Tran Mills char string[512];
4534b254c5SRichard Tran Mills const char *outstring;
4634b254c5SRichard Tran Mills
4734b254c5SRichard Tran Mills PetscCall(PetscViewerStringOpen(PETSC_COMM_WORLD, string, sizeof(string), &stringviewer));
4834b254c5SRichard Tran Mills PetscCall(PetscRegressorView(regressor, stringviewer));
4934b254c5SRichard Tran Mills PetscCall(PetscViewerStringGetStringRead(stringviewer, &outstring, NULL));
5034b254c5SRichard Tran Mills PetscCheck((char *)outstring == (char *)string, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "String returned from viewer does not equal original string");
5134b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Output from string viewer:%s\n", outstring));
5234b254c5SRichard Tran Mills PetscCall(PetscViewerDestroy(&stringviewer));
5334b254c5SRichard Tran Mills } else if (ctx->flg_ascii) PetscCall(PetscRegressorView(regressor, PETSC_VIEWER_STDOUT_WORLD));
5434b254c5SRichard Tran Mills
5534b254c5SRichard Tran Mills PetscCall(PetscRegressorGetType(regressor, &check_type));
5634b254c5SRichard Tran Mills PetscCall(PetscStrcmp(check_type, PETSCREGRESSORLINEAR, &match));
5734b254c5SRichard Tran Mills PetscCheck(match, PETSC_COMM_WORLD, PETSC_ERR_ARG_NOTSAMETYPE, "Regressor type is not Linear");
5834b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS);
5934b254c5SRichard Tran Mills }
6034b254c5SRichard Tran Mills
TestPrefixRegressor(PetscRegressor regressor,AppCtx ctx)6134b254c5SRichard Tran Mills static PetscErrorCode TestPrefixRegressor(PetscRegressor regressor, AppCtx ctx)
6234b254c5SRichard Tran Mills {
6334b254c5SRichard Tran Mills PetscFunctionBegin;
6434b254c5SRichard Tran Mills if (ctx->test_prefix) {
6534b254c5SRichard Tran Mills PetscCall(PetscRegressorSetOptionsPrefix(regressor, "sys1_"));
6634b254c5SRichard Tran Mills PetscCall(PetscRegressorAppendOptionsPrefix(regressor, "sys2_"));
6734b254c5SRichard Tran Mills }
6834b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS);
6934b254c5SRichard Tran Mills }
7034b254c5SRichard Tran Mills
CreateData(AppCtx ctx)7134b254c5SRichard Tran Mills static PetscErrorCode CreateData(AppCtx ctx)
7234b254c5SRichard Tran Mills {
7334b254c5SRichard Tran Mills PetscMPIInt rank;
7434b254c5SRichard Tran Mills PetscInt i;
7534b254c5SRichard Tran Mills PetscScalar mean;
7634b254c5SRichard Tran Mills
7734b254c5SRichard Tran Mills PetscFunctionBegin;
7834b254c5SRichard Tran Mills PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank));
7934b254c5SRichard Tran Mills PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->y));
8034b254c5SRichard Tran Mills PetscCall(VecSetSizes(ctx->y, PETSC_DECIDE, ctx->N));
8134b254c5SRichard Tran Mills PetscCall(VecSetFromOptions(ctx->y));
8234b254c5SRichard Tran Mills PetscCall(VecDuplicate(ctx->y, &ctx->y_predicted));
8334b254c5SRichard Tran Mills PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->X));
8434b254c5SRichard Tran Mills PetscCall(MatSetSizes(ctx->X, PETSC_DECIDE, PETSC_DECIDE, ctx->N, ctx->N));
8534b254c5SRichard Tran Mills PetscCall(MatSetFromOptions(ctx->X));
8634b254c5SRichard Tran Mills PetscCall(MatSetUp(ctx->X));
8734b254c5SRichard Tran Mills
8834b254c5SRichard Tran Mills if (!rank) {
8934b254c5SRichard Tran Mills for (i = 0; i < ctx->N; i++) {
9034b254c5SRichard Tran Mills PetscCall(VecSetValue(ctx->y, i, (PetscScalar)i, INSERT_VALUES));
9134b254c5SRichard Tran Mills PetscCall(MatSetValue(ctx->X, i, i, 1.0, INSERT_VALUES));
9234b254c5SRichard Tran Mills }
9334b254c5SRichard Tran Mills }
9434b254c5SRichard Tran Mills /* Set up a training data matrix that is the identity.
9534b254c5SRichard Tran Mills * We do this because this gives us a special case in which we can analytically determine what the regression
9634b254c5SRichard Tran Mills * coefficients should be for ordinary least squares, LASSO (L1 regularized), and ridge (L2 regularized) regression.
9734b254c5SRichard Tran Mills * See details in section 6.2 of James et al.'s An Introduction to Statistical Learning (ISLR), in the subsection
9834b254c5SRichard Tran Mills * titled "A Simple Special Case for Ridge Regression and the Lasso".
9934b254c5SRichard Tran Mills * Note that the coefficients we generate with ridge regression (-regressor_linear_type ridge -regressor_regularizer_weight <lambda>, or, equivalently,
10034b254c5SRichard Tran Mills * -tao_brgn_regularization_type l2pure -tao_brgn_regularizer_weight <lambda>) match those of the ISLR formula exactly.
10134b254c5SRichard Tran Mills * For LASSO it does not match the ISLR formula: where they use lambda/2, we need to use lambda.
10234b254c5SRichard Tran Mills * It also doesn't match what Scikit-learn does; in that case their lambda is 1/n_samples of our lambda. Apparently everyone is scaling
10334b254c5SRichard Tran Mills * their loss function by a different value, hence the need to change what "lambda" is. But it's clear that ISLR, Scikit-learn, and we
10434b254c5SRichard Tran Mills * are basically doing the same thing otherwise. */
10534b254c5SRichard Tran Mills PetscCall(VecAssemblyBegin(ctx->y));
10634b254c5SRichard Tran Mills PetscCall(VecAssemblyEnd(ctx->y));
10734b254c5SRichard Tran Mills PetscCall(MatAssemblyBegin(ctx->X, MAT_FINAL_ASSEMBLY));
10834b254c5SRichard Tran Mills PetscCall(MatAssemblyEnd(ctx->X, MAT_FINAL_ASSEMBLY));
10934b254c5SRichard Tran Mills /* Center the target vector we will train with. */
11034b254c5SRichard Tran Mills PetscCall(VecMean(ctx->y, &mean));
11134b254c5SRichard Tran Mills PetscCall(VecShift(ctx->y, -1.0 * mean));
11234b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS);
11334b254c5SRichard Tran Mills }
11434b254c5SRichard Tran Mills
ConfigureContext(AppCtx ctx)11534b254c5SRichard Tran Mills static PetscErrorCode ConfigureContext(AppCtx ctx)
11634b254c5SRichard Tran Mills {
11734b254c5SRichard Tran Mills PetscFunctionBegin;
11834b254c5SRichard Tran Mills ctx->flg_string = PETSC_FALSE;
11934b254c5SRichard Tran Mills ctx->flg_ascii = PETSC_FALSE;
12034b254c5SRichard Tran Mills ctx->flg_view_sol = PETSC_FALSE;
12134b254c5SRichard Tran Mills ctx->test_prefix = PETSC_FALSE;
12234b254c5SRichard Tran Mills ctx->N = 10;
12334b254c5SRichard Tran Mills PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Options for PetscRegressor ex3:", "");
12434b254c5SRichard Tran Mills PetscCall(PetscOptionsInt("-N", "Dimension of the N x N data matrix", "ex3.c", ctx->N, &ctx->N, NULL));
12534b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_string_viewer", &ctx->flg_string, NULL));
12634b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_ascii_viewer", &ctx->flg_ascii, NULL));
12734b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-view_sols", &ctx->flg_view_sol, NULL));
12834b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_prefix", &ctx->test_prefix, NULL));
12934b254c5SRichard Tran Mills PetscOptionsEnd();
13034b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS);
13134b254c5SRichard Tran Mills }
13234b254c5SRichard Tran Mills
main(int argc,char ** args)13334b254c5SRichard Tran Mills int main(int argc, char **args)
13434b254c5SRichard Tran Mills {
13534b254c5SRichard Tran Mills AppCtx ctx;
13634b254c5SRichard Tran Mills PetscRegressor regressor;
13734b254c5SRichard Tran Mills PetscScalar intercept;
13834b254c5SRichard Tran Mills
13934b254c5SRichard Tran Mills /* Initialize PETSc */
14034b254c5SRichard Tran Mills PetscCall(PetscInitialize(&argc, &args, (char *)0, help));
14134b254c5SRichard Tran Mills
14234b254c5SRichard Tran Mills /* Initialize problem parameters and data */
14334b254c5SRichard Tran Mills PetscCall(PetscNew(&ctx));
14434b254c5SRichard Tran Mills PetscCall(ConfigureContext(ctx));
14534b254c5SRichard Tran Mills PetscCall(CreateData(ctx));
14634b254c5SRichard Tran Mills
14734b254c5SRichard Tran Mills /* Create Regressor solver with desired type and options */
14834b254c5SRichard Tran Mills PetscCall(PetscRegressorCreate(PETSC_COMM_WORLD, ®ressor));
14934b254c5SRichard Tran Mills PetscCall(PetscRegressorSetType(regressor, PETSCREGRESSORLINEAR));
15034b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearSetType(regressor, REGRESSOR_LINEAR_OLS));
15134b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearSetFitIntercept(regressor, PETSC_FALSE));
15234b254c5SRichard Tran Mills /* Testing prefix functions for Regressor */
15334b254c5SRichard Tran Mills PetscCall(TestPrefixRegressor(regressor, ctx));
15434b254c5SRichard Tran Mills /* Check for command line options */
15534b254c5SRichard Tran Mills PetscCall(PetscRegressorSetFromOptions(regressor));
15634b254c5SRichard Tran Mills /* Fit the regressor */
15734b254c5SRichard Tran Mills PetscCall(PetscRegressorFit(regressor, ctx->X, ctx->y));
15834b254c5SRichard Tran Mills /* Predict data with fitted regressor */
15934b254c5SRichard Tran Mills PetscCall(PetscRegressorPredict(regressor, ctx->X, ctx->y_predicted));
16034b254c5SRichard Tran Mills /* Get other desired output data */
16134b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearGetIntercept(regressor, &intercept));
16234b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearGetCoefficients(regressor, &ctx->coefficients));
16334b254c5SRichard Tran Mills
16434b254c5SRichard Tran Mills /* Testing Views, and GetTypes */
16534b254c5SRichard Tran Mills PetscCall(TestRegressorViews(regressor, ctx));
16634b254c5SRichard Tran Mills PetscCall(PetscRegressorDestroy(®ressor));
16734b254c5SRichard Tran Mills PetscCall(DestroyCtx(&ctx));
16834b254c5SRichard Tran Mills PetscCall(PetscFinalize());
16934b254c5SRichard Tran Mills return 0;
17034b254c5SRichard Tran Mills }
17134b254c5SRichard Tran Mills
17234b254c5SRichard Tran Mills /*TEST
17334b254c5SRichard Tran Mills
17434b254c5SRichard Tran Mills build:
17534b254c5SRichard Tran Mills requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)
17634b254c5SRichard Tran Mills
17734b254c5SRichard Tran Mills test:
17834b254c5SRichard Tran Mills suffix: prefix_tao
17934b254c5SRichard Tran Mills args: -sys1_sys2_regressor_view -test_prefix
18034b254c5SRichard Tran Mills
18134b254c5SRichard Tran Mills test:
18234b254c5SRichard Tran Mills suffix: prefix_ksp
183*789736e1SBarry Smith args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_ksp_monitor
184*789736e1SBarry Smith
185*789736e1SBarry Smith test:
186*789736e1SBarry Smith suffix: prefix_ksp_cholesky
187*789736e1SBarry Smith args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_pc_type cholesky
188*789736e1SBarry Smith TODO: Could not locate a solver type for factorization type CHOLESKY and matrix type normal
189*789736e1SBarry Smith
190*789736e1SBarry Smith test:
191*789736e1SBarry Smith suffix: prefix_ksp_suitesparse
192*789736e1SBarry Smith requires: suitesparse
193*789736e1SBarry Smith args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_pc_type qr -sys1_sys2_regressor_linear_pc_factor_mat_solver_type spqr -sys1_sys2_regressor_linear_ksp_monitor
19434b254c5SRichard Tran Mills
19534b254c5SRichard Tran Mills test:
19634b254c5SRichard Tran Mills suffix: asciiview
19734b254c5SRichard Tran Mills args: -test_ascii_viewer
19834b254c5SRichard Tran Mills
19934b254c5SRichard Tran Mills test:
20034b254c5SRichard Tran Mills suffix: stringview
20134b254c5SRichard Tran Mills args: -test_string_viewer
20234b254c5SRichard Tran Mills
20334b254c5SRichard Tran Mills test:
20434b254c5SRichard Tran Mills suffix: ksp_intercept
20534b254c5SRichard Tran Mills args: -regressor_linear_use_ksp -regressor_linear_fit_intercept -regressor_view
20634b254c5SRichard Tran Mills
20734b254c5SRichard Tran Mills test:
20834b254c5SRichard Tran Mills suffix: ksp_no_intercept
20934b254c5SRichard Tran Mills args: -regressor_linear_use_ksp -regressor_view
21034b254c5SRichard Tran Mills
21134b254c5SRichard Tran Mills test:
21234b254c5SRichard Tran Mills suffix: lasso_1
21334b254c5SRichard Tran Mills nsize: 1
21434b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
21534b254c5SRichard Tran Mills
21634b254c5SRichard Tran Mills test:
21734b254c5SRichard Tran Mills suffix: lasso_2
21834b254c5SRichard Tran Mills nsize: 2
21934b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
22034b254c5SRichard Tran Mills
22134b254c5SRichard Tran Mills test:
22234b254c5SRichard Tran Mills suffix: ridge_1
22334b254c5SRichard Tran Mills nsize: 1
22434b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
22534b254c5SRichard Tran Mills
22634b254c5SRichard Tran Mills test:
22734b254c5SRichard Tran Mills suffix: ridge_2
22834b254c5SRichard Tran Mills nsize: 2
22934b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
23034b254c5SRichard Tran Mills
23134b254c5SRichard Tran Mills test:
23234b254c5SRichard Tran Mills suffix: ols_1
23334b254c5SRichard Tran Mills nsize: 1
23434b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols
23534b254c5SRichard Tran Mills
23634b254c5SRichard Tran Mills test:
23734b254c5SRichard Tran Mills suffix: ols_2
23834b254c5SRichard Tran Mills nsize: 2
23934b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols
24034b254c5SRichard Tran Mills
24134b254c5SRichard Tran Mills TEST*/
242