xref: /petsc/src/ml/regressor/tests/ex3.c (revision 6bfab51239a1d021a2781a42e04752bb50d6082e)
134b254c5SRichard Tran Mills #include <petscregressor.h>
234b254c5SRichard Tran Mills 
/* Help text for this test program; main() passes it to PetscInitialize(). */
static char help[] = "Tests some linear PetscRegressor types "
                     "with different regularizers.\n\n";
434b254c5SRichard Tran Mills 
534b254c5SRichard Tran Mills typedef struct _AppCtx {
634b254c5SRichard Tran Mills   Mat       X;           /* Training data */
734b254c5SRichard Tran Mills   Vec       y;           /* Target data   */
834b254c5SRichard Tran Mills   Vec       y_predicted; /* Target data   */
934b254c5SRichard Tran Mills   Vec       coefficients;
1034b254c5SRichard Tran Mills   PetscInt  N; /* Data size     */
1134b254c5SRichard Tran Mills   PetscBool flg_string;
1234b254c5SRichard Tran Mills   PetscBool flg_ascii;
1334b254c5SRichard Tran Mills   PetscBool flg_view_sol;
1434b254c5SRichard Tran Mills   PetscBool test_prefix;
1534b254c5SRichard Tran Mills } *AppCtx;
1634b254c5SRichard Tran Mills 
DestroyCtx(AppCtx * ctx)1734b254c5SRichard Tran Mills static PetscErrorCode DestroyCtx(AppCtx *ctx)
1834b254c5SRichard Tran Mills {
1934b254c5SRichard Tran Mills   PetscFunctionBegin;
2034b254c5SRichard Tran Mills   PetscCall(MatDestroy(&(*ctx)->X));
2134b254c5SRichard Tran Mills   PetscCall(VecDestroy(&(*ctx)->y));
2234b254c5SRichard Tran Mills   PetscCall(VecDestroy(&(*ctx)->y_predicted));
2334b254c5SRichard Tran Mills   PetscCall(PetscFree(*ctx));
2434b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
2534b254c5SRichard Tran Mills }
2634b254c5SRichard Tran Mills 
TestRegressorViews(PetscRegressor regressor,AppCtx ctx)2734b254c5SRichard Tran Mills static PetscErrorCode TestRegressorViews(PetscRegressor regressor, AppCtx ctx)
2834b254c5SRichard Tran Mills {
2934b254c5SRichard Tran Mills   PetscRegressorType check_type;
3034b254c5SRichard Tran Mills   PetscBool          match;
3134b254c5SRichard Tran Mills 
3234b254c5SRichard Tran Mills   PetscFunctionBegin;
3334b254c5SRichard Tran Mills   if (ctx->flg_view_sol) {
3434b254c5SRichard Tran Mills     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Training target vector is\n"));
3534b254c5SRichard Tran Mills     PetscCall(VecView(ctx->y, PETSC_VIEWER_STDOUT_WORLD));
3634b254c5SRichard Tran Mills     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Predicted values are\n"));
3734b254c5SRichard Tran Mills     PetscCall(VecView(ctx->y_predicted, PETSC_VIEWER_STDOUT_WORLD));
3834b254c5SRichard Tran Mills     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Coefficients are\n"));
3934b254c5SRichard Tran Mills     PetscCall(VecView(ctx->coefficients, PETSC_VIEWER_STDOUT_WORLD));
4034b254c5SRichard Tran Mills   }
4134b254c5SRichard Tran Mills 
4234b254c5SRichard Tran Mills   if (ctx->flg_string) {
4334b254c5SRichard Tran Mills     PetscViewer stringviewer;
4434b254c5SRichard Tran Mills     char        string[512];
4534b254c5SRichard Tran Mills     const char *outstring;
4634b254c5SRichard Tran Mills 
4734b254c5SRichard Tran Mills     PetscCall(PetscViewerStringOpen(PETSC_COMM_WORLD, string, sizeof(string), &stringviewer));
4834b254c5SRichard Tran Mills     PetscCall(PetscRegressorView(regressor, stringviewer));
4934b254c5SRichard Tran Mills     PetscCall(PetscViewerStringGetStringRead(stringviewer, &outstring, NULL));
5034b254c5SRichard Tran Mills     PetscCheck((char *)outstring == (char *)string, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "String returned from viewer does not equal original string");
5134b254c5SRichard Tran Mills     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Output from string viewer:%s\n", outstring));
5234b254c5SRichard Tran Mills     PetscCall(PetscViewerDestroy(&stringviewer));
5334b254c5SRichard Tran Mills   } else if (ctx->flg_ascii) PetscCall(PetscRegressorView(regressor, PETSC_VIEWER_STDOUT_WORLD));
5434b254c5SRichard Tran Mills 
5534b254c5SRichard Tran Mills   PetscCall(PetscRegressorGetType(regressor, &check_type));
5634b254c5SRichard Tran Mills   PetscCall(PetscStrcmp(check_type, PETSCREGRESSORLINEAR, &match));
5734b254c5SRichard Tran Mills   PetscCheck(match, PETSC_COMM_WORLD, PETSC_ERR_ARG_NOTSAMETYPE, "Regressor type is not Linear");
5834b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
5934b254c5SRichard Tran Mills }
6034b254c5SRichard Tran Mills 
TestPrefixRegressor(PetscRegressor regressor,AppCtx ctx)6134b254c5SRichard Tran Mills static PetscErrorCode TestPrefixRegressor(PetscRegressor regressor, AppCtx ctx)
6234b254c5SRichard Tran Mills {
6334b254c5SRichard Tran Mills   PetscFunctionBegin;
6434b254c5SRichard Tran Mills   if (ctx->test_prefix) {
6534b254c5SRichard Tran Mills     PetscCall(PetscRegressorSetOptionsPrefix(regressor, "sys1_"));
6634b254c5SRichard Tran Mills     PetscCall(PetscRegressorAppendOptionsPrefix(regressor, "sys2_"));
6734b254c5SRichard Tran Mills   }
6834b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
6934b254c5SRichard Tran Mills }
7034b254c5SRichard Tran Mills 
CreateData(AppCtx ctx)7134b254c5SRichard Tran Mills static PetscErrorCode CreateData(AppCtx ctx)
7234b254c5SRichard Tran Mills {
7334b254c5SRichard Tran Mills   PetscMPIInt rank;
7434b254c5SRichard Tran Mills   PetscInt    i;
7534b254c5SRichard Tran Mills   PetscScalar mean;
7634b254c5SRichard Tran Mills 
7734b254c5SRichard Tran Mills   PetscFunctionBegin;
7834b254c5SRichard Tran Mills   PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank));
7934b254c5SRichard Tran Mills   PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->y));
8034b254c5SRichard Tran Mills   PetscCall(VecSetSizes(ctx->y, PETSC_DECIDE, ctx->N));
8134b254c5SRichard Tran Mills   PetscCall(VecSetFromOptions(ctx->y));
8234b254c5SRichard Tran Mills   PetscCall(VecDuplicate(ctx->y, &ctx->y_predicted));
8334b254c5SRichard Tran Mills   PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->X));
8434b254c5SRichard Tran Mills   PetscCall(MatSetSizes(ctx->X, PETSC_DECIDE, PETSC_DECIDE, ctx->N, ctx->N));
8534b254c5SRichard Tran Mills   PetscCall(MatSetFromOptions(ctx->X));
8634b254c5SRichard Tran Mills   PetscCall(MatSetUp(ctx->X));
8734b254c5SRichard Tran Mills 
8834b254c5SRichard Tran Mills   if (!rank) {
8934b254c5SRichard Tran Mills     for (i = 0; i < ctx->N; i++) {
9034b254c5SRichard Tran Mills       PetscCall(VecSetValue(ctx->y, i, (PetscScalar)i, INSERT_VALUES));
9134b254c5SRichard Tran Mills       PetscCall(MatSetValue(ctx->X, i, i, 1.0, INSERT_VALUES));
9234b254c5SRichard Tran Mills     }
9334b254c5SRichard Tran Mills   }
9434b254c5SRichard Tran Mills   /* Set up a training data matrix that is the identity.
9534b254c5SRichard Tran Mills    * We do this because this gives us a special case in which we can analytically determine what the regression
9634b254c5SRichard Tran Mills    * coefficients should be for ordinary least squares, LASSO (L1 regularized), and ridge (L2 regularized) regression.
9734b254c5SRichard Tran Mills    * See details in section 6.2 of James et al.'s An Introduction to Statistical Learning (ISLR), in the subsection
9834b254c5SRichard Tran Mills    * titled "A Simple Special Case for Ridge Regression and the Lasso".
9934b254c5SRichard Tran Mills    * Note that the coefficients we generate with ridge regression (-regressor_linear_type ridge -regressor_regularizer_weight <lambda>, or, equivalently,
10034b254c5SRichard Tran Mills    * -tao_brgn_regularization_type l2pure -tao_brgn_regularizer_weight <lambda>) match those of the ISLR formula exactly.
10134b254c5SRichard Tran Mills    * For LASSO it does not match the ISLR formula: where they use lambda/2, we need to use lambda.
10234b254c5SRichard Tran Mills    * It also doesn't match what Scikit-learn does; in that case their lambda is 1/n_samples of our lambda. Apparently everyone is scaling
10334b254c5SRichard Tran Mills    * their loss function by a different value, hence the need to change what "lambda" is. But it's clear that ISLR, Scikit-learn, and we
10434b254c5SRichard Tran Mills    * are basically doing the same thing otherwise. */
10534b254c5SRichard Tran Mills   PetscCall(VecAssemblyBegin(ctx->y));
10634b254c5SRichard Tran Mills   PetscCall(VecAssemblyEnd(ctx->y));
10734b254c5SRichard Tran Mills   PetscCall(MatAssemblyBegin(ctx->X, MAT_FINAL_ASSEMBLY));
10834b254c5SRichard Tran Mills   PetscCall(MatAssemblyEnd(ctx->X, MAT_FINAL_ASSEMBLY));
10934b254c5SRichard Tran Mills   /* Center the target vector we will train with. */
11034b254c5SRichard Tran Mills   PetscCall(VecMean(ctx->y, &mean));
11134b254c5SRichard Tran Mills   PetscCall(VecShift(ctx->y, -1.0 * mean));
11234b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
11334b254c5SRichard Tran Mills }
11434b254c5SRichard Tran Mills 
ConfigureContext(AppCtx ctx)11534b254c5SRichard Tran Mills static PetscErrorCode ConfigureContext(AppCtx ctx)
11634b254c5SRichard Tran Mills {
11734b254c5SRichard Tran Mills   PetscFunctionBegin;
11834b254c5SRichard Tran Mills   ctx->flg_string   = PETSC_FALSE;
11934b254c5SRichard Tran Mills   ctx->flg_ascii    = PETSC_FALSE;
12034b254c5SRichard Tran Mills   ctx->flg_view_sol = PETSC_FALSE;
12134b254c5SRichard Tran Mills   ctx->test_prefix  = PETSC_FALSE;
12234b254c5SRichard Tran Mills   ctx->N            = 10;
12334b254c5SRichard Tran Mills   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Options for PetscRegressor ex3:", "");
12434b254c5SRichard Tran Mills   PetscCall(PetscOptionsInt("-N", "Dimension of the N x N data matrix", "ex3.c", ctx->N, &ctx->N, NULL));
12534b254c5SRichard Tran Mills   PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_string_viewer", &ctx->flg_string, NULL));
12634b254c5SRichard Tran Mills   PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_ascii_viewer", &ctx->flg_ascii, NULL));
12734b254c5SRichard Tran Mills   PetscCall(PetscOptionsGetBool(NULL, NULL, "-view_sols", &ctx->flg_view_sol, NULL));
12834b254c5SRichard Tran Mills   PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_prefix", &ctx->test_prefix, NULL));
12934b254c5SRichard Tran Mills   PetscOptionsEnd();
13034b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
13134b254c5SRichard Tran Mills }
13234b254c5SRichard Tran Mills 
main(int argc,char ** args)13334b254c5SRichard Tran Mills int main(int argc, char **args)
13434b254c5SRichard Tran Mills {
13534b254c5SRichard Tran Mills   AppCtx         ctx;
13634b254c5SRichard Tran Mills   PetscRegressor regressor;
13734b254c5SRichard Tran Mills   PetscScalar    intercept;
13834b254c5SRichard Tran Mills 
13934b254c5SRichard Tran Mills   /* Initialize PETSc */
14034b254c5SRichard Tran Mills   PetscCall(PetscInitialize(&argc, &args, (char *)0, help));
14134b254c5SRichard Tran Mills 
14234b254c5SRichard Tran Mills   /* Initialize problem parameters and data */
14334b254c5SRichard Tran Mills   PetscCall(PetscNew(&ctx));
14434b254c5SRichard Tran Mills   PetscCall(ConfigureContext(ctx));
14534b254c5SRichard Tran Mills   PetscCall(CreateData(ctx));
14634b254c5SRichard Tran Mills 
14734b254c5SRichard Tran Mills   /* Create Regressor solver with desired type and options */
14834b254c5SRichard Tran Mills   PetscCall(PetscRegressorCreate(PETSC_COMM_WORLD, &regressor));
14934b254c5SRichard Tran Mills   PetscCall(PetscRegressorSetType(regressor, PETSCREGRESSORLINEAR));
15034b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearSetType(regressor, REGRESSOR_LINEAR_OLS));
15134b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearSetFitIntercept(regressor, PETSC_FALSE));
15234b254c5SRichard Tran Mills   /* Testing prefix functions for Regressor */
15334b254c5SRichard Tran Mills   PetscCall(TestPrefixRegressor(regressor, ctx));
15434b254c5SRichard Tran Mills   /* Check for command line options */
15534b254c5SRichard Tran Mills   PetscCall(PetscRegressorSetFromOptions(regressor));
15634b254c5SRichard Tran Mills   /* Fit the regressor */
15734b254c5SRichard Tran Mills   PetscCall(PetscRegressorFit(regressor, ctx->X, ctx->y));
15834b254c5SRichard Tran Mills   /* Predict data with fitted regressor */
15934b254c5SRichard Tran Mills   PetscCall(PetscRegressorPredict(regressor, ctx->X, ctx->y_predicted));
16034b254c5SRichard Tran Mills   /* Get other desired output data */
16134b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearGetIntercept(regressor, &intercept));
16234b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearGetCoefficients(regressor, &ctx->coefficients));
16334b254c5SRichard Tran Mills 
16434b254c5SRichard Tran Mills   /* Testing Views, and GetTypes */
16534b254c5SRichard Tran Mills   PetscCall(TestRegressorViews(regressor, ctx));
16634b254c5SRichard Tran Mills   PetscCall(PetscRegressorDestroy(&regressor));
16734b254c5SRichard Tran Mills   PetscCall(DestroyCtx(&ctx));
16834b254c5SRichard Tran Mills   PetscCall(PetscFinalize());
16934b254c5SRichard Tran Mills   return 0;
17034b254c5SRichard Tran Mills }
17134b254c5SRichard Tran Mills 
17234b254c5SRichard Tran Mills /*TEST
17334b254c5SRichard Tran Mills 
17434b254c5SRichard Tran Mills    build:
17534b254c5SRichard Tran Mills       requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)
17634b254c5SRichard Tran Mills 
17734b254c5SRichard Tran Mills    test:
17834b254c5SRichard Tran Mills       suffix: prefix_tao
17934b254c5SRichard Tran Mills       args: -sys1_sys2_regressor_view -test_prefix
18034b254c5SRichard Tran Mills 
18134b254c5SRichard Tran Mills    test:
18234b254c5SRichard Tran Mills       suffix: prefix_ksp
183*789736e1SBarry Smith       args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_ksp_monitor
184*789736e1SBarry Smith 
185*789736e1SBarry Smith    test:
186*789736e1SBarry Smith       suffix: prefix_ksp_cholesky
187*789736e1SBarry Smith       args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_pc_type cholesky
188*789736e1SBarry Smith       TODO: Could not locate a solver type for factorization type CHOLESKY and matrix type normal
189*789736e1SBarry Smith 
190*789736e1SBarry Smith    test:
191*789736e1SBarry Smith       suffix: prefix_ksp_suitesparse
192*789736e1SBarry Smith       requires: suitesparse
193*789736e1SBarry Smith       args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_pc_type qr -sys1_sys2_regressor_linear_pc_factor_mat_solver_type spqr -sys1_sys2_regressor_linear_ksp_monitor
19434b254c5SRichard Tran Mills 
19534b254c5SRichard Tran Mills    test:
19634b254c5SRichard Tran Mills       suffix: asciiview
19734b254c5SRichard Tran Mills       args: -test_ascii_viewer
19834b254c5SRichard Tran Mills 
19934b254c5SRichard Tran Mills    test:
20034b254c5SRichard Tran Mills        suffix: stringview
20134b254c5SRichard Tran Mills        args: -test_string_viewer
20234b254c5SRichard Tran Mills 
20334b254c5SRichard Tran Mills    test:
20434b254c5SRichard Tran Mills       suffix: ksp_intercept
20534b254c5SRichard Tran Mills       args: -regressor_linear_use_ksp -regressor_linear_fit_intercept -regressor_view
20634b254c5SRichard Tran Mills 
20734b254c5SRichard Tran Mills    test:
20834b254c5SRichard Tran Mills       suffix: ksp_no_intercept
20934b254c5SRichard Tran Mills       args: -regressor_linear_use_ksp -regressor_view
21034b254c5SRichard Tran Mills 
21134b254c5SRichard Tran Mills    test:
21234b254c5SRichard Tran Mills       suffix: lasso_1
21334b254c5SRichard Tran Mills       nsize: 1
21434b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
21534b254c5SRichard Tran Mills 
21634b254c5SRichard Tran Mills    test:
21734b254c5SRichard Tran Mills       suffix: lasso_2
21834b254c5SRichard Tran Mills       nsize: 2
21934b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
22034b254c5SRichard Tran Mills 
22134b254c5SRichard Tran Mills    test:
22234b254c5SRichard Tran Mills       suffix: ridge_1
22334b254c5SRichard Tran Mills       nsize: 1
22434b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
22534b254c5SRichard Tran Mills 
22634b254c5SRichard Tran Mills    test:
22734b254c5SRichard Tran Mills       suffix: ridge_2
22834b254c5SRichard Tran Mills       nsize: 2
22934b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
23034b254c5SRichard Tran Mills 
23134b254c5SRichard Tran Mills    test:
23234b254c5SRichard Tran Mills       suffix: ols_1
23334b254c5SRichard Tran Mills       nsize: 1
23434b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols
23534b254c5SRichard Tran Mills 
23634b254c5SRichard Tran Mills    test:
23734b254c5SRichard Tran Mills       suffix: ols_2
23834b254c5SRichard Tran Mills       nsize: 2
23934b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols
24034b254c5SRichard Tran Mills 
24134b254c5SRichard Tran Mills TEST*/
242