xref: /petsc/src/ml/regressor/tests/ex3.c (revision 6bfab51239a1d021a2781a42e04752bb50d6082e)
1 #include <petscregressor.h>
2 
/* Usage text handed to PetscInitialize() in main(); it is never modified, hence const. */
static const char help[] = "Tests some linear PetscRegressor types with different regularizers.\n\n";
4 
5 typedef struct _AppCtx {
6   Mat       X;           /* Training data */
7   Vec       y;           /* Target data   */
8   Vec       y_predicted; /* Target data   */
9   Vec       coefficients;
10   PetscInt  N; /* Data size     */
11   PetscBool flg_string;
12   PetscBool flg_ascii;
13   PetscBool flg_view_sol;
14   PetscBool test_prefix;
15 } *AppCtx;
16 
DestroyCtx(AppCtx * ctx)17 static PetscErrorCode DestroyCtx(AppCtx *ctx)
18 {
19   PetscFunctionBegin;
20   PetscCall(MatDestroy(&(*ctx)->X));
21   PetscCall(VecDestroy(&(*ctx)->y));
22   PetscCall(VecDestroy(&(*ctx)->y_predicted));
23   PetscCall(PetscFree(*ctx));
24   PetscFunctionReturn(PETSC_SUCCESS);
25 }
26 
TestRegressorViews(PetscRegressor regressor,AppCtx ctx)27 static PetscErrorCode TestRegressorViews(PetscRegressor regressor, AppCtx ctx)
28 {
29   PetscRegressorType check_type;
30   PetscBool          match;
31 
32   PetscFunctionBegin;
33   if (ctx->flg_view_sol) {
34     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Training target vector is\n"));
35     PetscCall(VecView(ctx->y, PETSC_VIEWER_STDOUT_WORLD));
36     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Predicted values are\n"));
37     PetscCall(VecView(ctx->y_predicted, PETSC_VIEWER_STDOUT_WORLD));
38     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Coefficients are\n"));
39     PetscCall(VecView(ctx->coefficients, PETSC_VIEWER_STDOUT_WORLD));
40   }
41 
42   if (ctx->flg_string) {
43     PetscViewer stringviewer;
44     char        string[512];
45     const char *outstring;
46 
47     PetscCall(PetscViewerStringOpen(PETSC_COMM_WORLD, string, sizeof(string), &stringviewer));
48     PetscCall(PetscRegressorView(regressor, stringviewer));
49     PetscCall(PetscViewerStringGetStringRead(stringviewer, &outstring, NULL));
50     PetscCheck((char *)outstring == (char *)string, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "String returned from viewer does not equal original string");
51     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Output from string viewer:%s\n", outstring));
52     PetscCall(PetscViewerDestroy(&stringviewer));
53   } else if (ctx->flg_ascii) PetscCall(PetscRegressorView(regressor, PETSC_VIEWER_STDOUT_WORLD));
54 
55   PetscCall(PetscRegressorGetType(regressor, &check_type));
56   PetscCall(PetscStrcmp(check_type, PETSCREGRESSORLINEAR, &match));
57   PetscCheck(match, PETSC_COMM_WORLD, PETSC_ERR_ARG_NOTSAMETYPE, "Regressor type is not Linear");
58   PetscFunctionReturn(PETSC_SUCCESS);
59 }
60 
TestPrefixRegressor(PetscRegressor regressor,AppCtx ctx)61 static PetscErrorCode TestPrefixRegressor(PetscRegressor regressor, AppCtx ctx)
62 {
63   PetscFunctionBegin;
64   if (ctx->test_prefix) {
65     PetscCall(PetscRegressorSetOptionsPrefix(regressor, "sys1_"));
66     PetscCall(PetscRegressorAppendOptionsPrefix(regressor, "sys2_"));
67   }
68   PetscFunctionReturn(PETSC_SUCCESS);
69 }
70 
CreateData(AppCtx ctx)71 static PetscErrorCode CreateData(AppCtx ctx)
72 {
73   PetscMPIInt rank;
74   PetscInt    i;
75   PetscScalar mean;
76 
77   PetscFunctionBegin;
78   PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank));
79   PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->y));
80   PetscCall(VecSetSizes(ctx->y, PETSC_DECIDE, ctx->N));
81   PetscCall(VecSetFromOptions(ctx->y));
82   PetscCall(VecDuplicate(ctx->y, &ctx->y_predicted));
83   PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->X));
84   PetscCall(MatSetSizes(ctx->X, PETSC_DECIDE, PETSC_DECIDE, ctx->N, ctx->N));
85   PetscCall(MatSetFromOptions(ctx->X));
86   PetscCall(MatSetUp(ctx->X));
87 
88   if (!rank) {
89     for (i = 0; i < ctx->N; i++) {
90       PetscCall(VecSetValue(ctx->y, i, (PetscScalar)i, INSERT_VALUES));
91       PetscCall(MatSetValue(ctx->X, i, i, 1.0, INSERT_VALUES));
92     }
93   }
94   /* Set up a training data matrix that is the identity.
95    * We do this because this gives us a special case in which we can analytically determine what the regression
96    * coefficients should be for ordinary least squares, LASSO (L1 regularized), and ridge (L2 regularized) regression.
97    * See details in section 6.2 of James et al.'s An Introduction to Statistical Learning (ISLR), in the subsection
98    * titled "A Simple Special Case for Ridge Regression and the Lasso".
99    * Note that the coefficients we generate with ridge regression (-regressor_linear_type ridge -regressor_regularizer_weight <lambda>, or, equivalently,
100    * -tao_brgn_regularization_type l2pure -tao_brgn_regularizer_weight <lambda>) match those of the ISLR formula exactly.
101    * For LASSO it does not match the ISLR formula: where they use lambda/2, we need to use lambda.
102    * It also doesn't match what Scikit-learn does; in that case their lambda is 1/n_samples of our lambda. Apparently everyone is scaling
103    * their loss function by a different value, hence the need to change what "lambda" is. But it's clear that ISLR, Scikit-learn, and we
104    * are basically doing the same thing otherwise. */
105   PetscCall(VecAssemblyBegin(ctx->y));
106   PetscCall(VecAssemblyEnd(ctx->y));
107   PetscCall(MatAssemblyBegin(ctx->X, MAT_FINAL_ASSEMBLY));
108   PetscCall(MatAssemblyEnd(ctx->X, MAT_FINAL_ASSEMBLY));
109   /* Center the target vector we will train with. */
110   PetscCall(VecMean(ctx->y, &mean));
111   PetscCall(VecShift(ctx->y, -1.0 * mean));
112   PetscFunctionReturn(PETSC_SUCCESS);
113 }
114 
ConfigureContext(AppCtx ctx)115 static PetscErrorCode ConfigureContext(AppCtx ctx)
116 {
117   PetscFunctionBegin;
118   ctx->flg_string   = PETSC_FALSE;
119   ctx->flg_ascii    = PETSC_FALSE;
120   ctx->flg_view_sol = PETSC_FALSE;
121   ctx->test_prefix  = PETSC_FALSE;
122   ctx->N            = 10;
123   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Options for PetscRegressor ex3:", "");
124   PetscCall(PetscOptionsInt("-N", "Dimension of the N x N data matrix", "ex3.c", ctx->N, &ctx->N, NULL));
125   PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_string_viewer", &ctx->flg_string, NULL));
126   PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_ascii_viewer", &ctx->flg_ascii, NULL));
127   PetscCall(PetscOptionsGetBool(NULL, NULL, "-view_sols", &ctx->flg_view_sol, NULL));
128   PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_prefix", &ctx->test_prefix, NULL));
129   PetscOptionsEnd();
130   PetscFunctionReturn(PETSC_SUCCESS);
131 }
132 
main(int argc,char ** args)133 int main(int argc, char **args)
134 {
135   AppCtx         ctx;
136   PetscRegressor regressor;
137   PetscScalar    intercept;
138 
139   /* Initialize PETSc */
140   PetscCall(PetscInitialize(&argc, &args, (char *)0, help));
141 
142   /* Initialize problem parameters and data */
143   PetscCall(PetscNew(&ctx));
144   PetscCall(ConfigureContext(ctx));
145   PetscCall(CreateData(ctx));
146 
147   /* Create Regressor solver with desired type and options */
148   PetscCall(PetscRegressorCreate(PETSC_COMM_WORLD, &regressor));
149   PetscCall(PetscRegressorSetType(regressor, PETSCREGRESSORLINEAR));
150   PetscCall(PetscRegressorLinearSetType(regressor, REGRESSOR_LINEAR_OLS));
151   PetscCall(PetscRegressorLinearSetFitIntercept(regressor, PETSC_FALSE));
152   /* Testing prefix functions for Regressor */
153   PetscCall(TestPrefixRegressor(regressor, ctx));
154   /* Check for command line options */
155   PetscCall(PetscRegressorSetFromOptions(regressor));
156   /* Fit the regressor */
157   PetscCall(PetscRegressorFit(regressor, ctx->X, ctx->y));
158   /* Predict data with fitted regressor */
159   PetscCall(PetscRegressorPredict(regressor, ctx->X, ctx->y_predicted));
160   /* Get other desired output data */
161   PetscCall(PetscRegressorLinearGetIntercept(regressor, &intercept));
162   PetscCall(PetscRegressorLinearGetCoefficients(regressor, &ctx->coefficients));
163 
164   /* Testing Views, and GetTypes */
165   PetscCall(TestRegressorViews(regressor, ctx));
166   PetscCall(PetscRegressorDestroy(&regressor));
167   PetscCall(DestroyCtx(&ctx));
168   PetscCall(PetscFinalize());
169   return 0;
170 }
171 
172 /*TEST
173 
174    build:
175       requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)
176 
177    test:
178       suffix: prefix_tao
179       args: -sys1_sys2_regressor_view -test_prefix
180 
181    test:
182       suffix: prefix_ksp
183       args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_ksp_monitor
184 
185    test:
186       suffix: prefix_ksp_cholesky
187       args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_pc_type cholesky
188       TODO: Could not locate a solver type for factorization type CHOLESKY and matrix type normal
189 
190    test:
191       suffix: prefix_ksp_suitesparse
192       requires: suitesparse
193       args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_pc_type qr -sys1_sys2_regressor_linear_pc_factor_mat_solver_type spqr -sys1_sys2_regressor_linear_ksp_monitor
194 
195    test:
196       suffix: asciiview
197       args: -test_ascii_viewer
198 
199    test:
200        suffix: stringview
201        args: -test_string_viewer
202 
203    test:
204       suffix: ksp_intercept
205       args: -regressor_linear_use_ksp -regressor_linear_fit_intercept -regressor_view
206 
207    test:
208       suffix: ksp_no_intercept
209       args: -regressor_linear_use_ksp -regressor_view
210 
211    test:
212       suffix: lasso_1
213       nsize: 1
214       args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
215 
216    test:
217       suffix: lasso_2
218       nsize: 2
219       args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
220 
221    test:
222       suffix: ridge_1
223       nsize: 1
224       args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
225 
226    test:
227       suffix: ridge_2
228       nsize: 2
229       args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
230 
231    test:
232       suffix: ols_1
233       nsize: 1
234       args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols
235 
236    test:
237       suffix: ols_2
238       nsize: 2
239       args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols
240 
241 TEST*/
242