xref: /petsc/src/ml/regressor/impls/linear/linear.c (revision 8577b683712d1cca1e9b8fdaa9ae028364224dad) !
1 #include <../src/ml/regressor/impls/linear/linearimpl.h> /*I "petscregressor.h" I*/
2 
/* Names of the PetscRegressorLinearType values, in enum order, followed (per the PETSc convention used by
   PetscOptionsEnum()) by the enum type name, the common prefix of its values, and a NULL terminator. */
const char *const PetscRegressorLinearTypes[] = {
  "ols",
  "lasso",
  "ridge",
  "RegressorLinearType",
  "REGRESSOR_LINEAR_",
  NULL,
};
4 
PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor,PetscBool flg)5 static PetscErrorCode PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor, PetscBool flg)
6 {
7   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
8 
9   PetscFunctionBegin;
10   linear->fit_intercept = flg;
11   PetscFunctionReturn(PETSC_SUCCESS);
12 }
13 
PetscRegressorLinearSetType_Linear(PetscRegressor regressor,PetscRegressorLinearType type)14 static PetscErrorCode PetscRegressorLinearSetType_Linear(PetscRegressor regressor, PetscRegressorLinearType type)
15 {
16   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
17 
18   PetscFunctionBegin;
19   linear->type = type;
20   PetscFunctionReturn(PETSC_SUCCESS);
21 }
22 
PetscRegressorLinearGetType_Linear(PetscRegressor regressor,PetscRegressorLinearType * type)23 static PetscErrorCode PetscRegressorLinearGetType_Linear(PetscRegressor regressor, PetscRegressorLinearType *type)
24 {
25   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
26 
27   PetscFunctionBegin;
28   *type = linear->type;
29   PetscFunctionReturn(PETSC_SUCCESS);
30 }
31 
PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor,PetscScalar * intercept)32 static PetscErrorCode PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor, PetscScalar *intercept)
33 {
34   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
35 
36   PetscFunctionBegin;
37   *intercept = linear->intercept;
38   PetscFunctionReturn(PETSC_SUCCESS);
39 }
40 
PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor,Vec * coefficients)41 static PetscErrorCode PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor, Vec *coefficients)
42 {
43   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
44 
45   PetscFunctionBegin;
46   *coefficients = linear->coefficients;
47   PetscFunctionReturn(PETSC_SUCCESS);
48 }
49 
PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor,KSP * ksp)50 static PetscErrorCode PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor, KSP *ksp)
51 {
52   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
53 
54   PetscFunctionBegin;
55   if (!linear->ksp) {
56     PetscCall(KSPCreate(PetscObjectComm((PetscObject)regressor), &linear->ksp));
57     PetscCall(PetscObjectIncrementTabLevel((PetscObject)linear->ksp, (PetscObject)regressor, 1));
58     PetscCall(PetscObjectSetOptions((PetscObject)linear->ksp, ((PetscObject)regressor)->options));
59   }
60   *ksp = linear->ksp;
61   PetscFunctionReturn(PETSC_SUCCESS);
62 }
63 
PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor,PetscBool flg)64 static PetscErrorCode PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor, PetscBool flg)
65 {
66   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
67 
68   PetscFunctionBegin;
69   linear->use_ksp = flg;
70   PetscFunctionReturn(PETSC_SUCCESS);
71 }
72 
EvaluateResidual(Tao tao,Vec x,Vec f,void * ptr)73 static PetscErrorCode EvaluateResidual(Tao tao, Vec x, Vec f, void *ptr)
74 {
75   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr;
76 
77   PetscFunctionBegin;
78   /* Evaluate f = A * x - b */
79   PetscCall(MatMult(linear->X, x, f));
80   PetscCall(VecAXPY(f, -1.0, linear->rhs));
81   PetscFunctionReturn(PETSC_SUCCESS);
82 }
83 
EvaluateJacobian(Tao tao,Vec x,Mat J,Mat Jpre,void * ptr)84 static PetscErrorCode EvaluateJacobian(Tao tao, Vec x, Mat J, Mat Jpre, void *ptr)
85 {
86   /* The TAOBRGN API expects us to pass an EvaluateJacobian() routine to it, but in this case it is a dummy function.
87      Denoting our data matrix as X, for linear least squares J[m][n] = df[m]/dx[n] = X[m][n]. */
88   PetscFunctionBegin;
89   PetscFunctionReturn(PETSC_SUCCESS);
90 }
91 
PetscRegressorSetUp_Linear(PetscRegressor regressor)92 static PetscErrorCode PetscRegressorSetUp_Linear(PetscRegressor regressor)
93 {
94   PetscInt               M, N;
95   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
96   KSP                    ksp;
97   Tao                    tao;
98 
99   PetscFunctionBegin;
100   PetscCall(MatGetSize(regressor->training, &M, &N));
101 
102   if (linear->fit_intercept) {
103     /* If we are fitting the intercept, we need to make A a composite matrix using MATCENTERING to preserve sparsity.
104      * Though there might be some cases we don't want to do this for, depending on what kind of matrix is passed in. (Probably bad idea for dense?)
105      * We will also need to ensure that the right-hand side passed to the KSP is also mean-centered, since we
106      * intend to compute the intercept separately from regression coefficients (that is, we will not be adding a
107      * column of all 1s to our design matrix). */
108     PetscCall(MatCreateCentering(PetscObjectComm((PetscObject)regressor), PETSC_DECIDE, M, &linear->C));
109     PetscCall(MatCreate(PetscObjectComm((PetscObject)regressor), &linear->X));
110     PetscCall(MatSetSizes(linear->X, PETSC_DECIDE, PETSC_DECIDE, M, N));
111     PetscCall(MatSetType(linear->X, MATCOMPOSITE));
112     PetscCall(MatCompositeSetType(linear->X, MAT_COMPOSITE_MULTIPLICATIVE));
113     PetscCall(MatCompositeAddMat(linear->X, regressor->training));
114     PetscCall(MatCompositeAddMat(linear->X, linear->C));
115     PetscCall(VecDuplicate(regressor->target, &linear->rhs));
116     PetscCall(MatMult(linear->C, regressor->target, linear->rhs));
117   } else {
118     // When not fitting intercept, we assume that the input data are already centered.
119     linear->X   = regressor->training;
120     linear->rhs = regressor->target;
121 
122     PetscCall(PetscObjectReference((PetscObject)linear->X));
123     PetscCall(PetscObjectReference((PetscObject)linear->rhs));
124   }
125 
126   if (linear->coefficients) PetscCall(VecDestroy(&linear->coefficients));
127 
128   if (linear->use_ksp) {
129     PetscCheck(linear->type == REGRESSOR_LINEAR_OLS, PetscObjectComm((PetscObject)regressor), PETSC_ERR_ARG_WRONGSTATE, "KSP can be used to fit a linear regressor only when its type is OLS");
130 
131     if (!linear->ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
132     ksp = linear->ksp;
133 
134     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, NULL));
135     /* Set up the KSP to solve the least squares problem (without solving for intercept, as this is done separately) using KSPLSQR. */
136     PetscCall(MatCreateNormal(linear->X, &linear->XtX));
137     PetscCall(KSPSetType(ksp, KSPLSQR));
138     PetscCall(KSPSetOperators(ksp, linear->X, linear->XtX));
139     PetscCall(KSPSetOptionsPrefix(ksp, ((PetscObject)regressor)->prefix));
140     PetscCall(KSPAppendOptionsPrefix(ksp, "regressor_linear_"));
141     PetscCall(KSPSetFromOptions(ksp));
142   } else {
143     /* Note: Currently implementation creates TAO inside of implementations.
144       * Thus, all the prefix jobs are done inside implementations, not in interface */
145     const char *prefix;
146 
147     if (!regressor->tao) PetscCall(PetscRegressorGetTao(regressor, &tao));
148 
149     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, &linear->residual));
150     /* Set up the TAO object to solve the (regularized) least squares problem (without solving for intercept, which is done separately) using TAOBRGN. */
151     PetscCall(TaoSetType(tao, TAOBRGN));
152     PetscCall(TaoSetSolution(tao, linear->coefficients));
153     PetscCall(TaoSetResidualRoutine(tao, linear->residual, EvaluateResidual, linear));
154     PetscCall(TaoSetJacobianResidualRoutine(tao, linear->X, linear->X, EvaluateJacobian, linear));
155     // Set the regularization type and weight for the BRGN as linear->type dictates:
156     // TODO BRGN needs to be BRGNSetRegularizationType
157     // PetscOptionsSetValue no longer works due to functioning prefix system
158     PetscCall(PetscRegressorGetOptionsPrefix(regressor, &prefix));
159     PetscCall(TaoSetOptionsPrefix(regressor->tao, prefix));
160     PetscCall(TaoAppendOptionsPrefix(tao, "regressor_linear_"));
161     switch (linear->type) {
162     case REGRESSOR_LINEAR_OLS:
163       regressor->regularizer_weight = 0.0; // OLS, by definition, uses a regularizer weight of 0
164       break;
165     case REGRESSOR_LINEAR_LASSO:
166       PetscCall(TaoBRGNSetRegularizationType(regressor->tao, TAOBRGN_REGULARIZATION_L1DICT));
167       break;
168     case REGRESSOR_LINEAR_RIDGE:
169       PetscCall(TaoBRGNSetRegularizationType(regressor->tao, TAOBRGN_REGULARIZATION_L2PURE));
170       break;
171     default:
172       break;
173     }
174     if (!linear->use_ksp) PetscCall(TaoBRGNSetRegularizerWeight(tao, regressor->regularizer_weight));
175     PetscCall(TaoSetFromOptions(tao));
176   }
177   PetscFunctionReturn(PETSC_SUCCESS);
178 }
179 
PetscRegressorReset_Linear(PetscRegressor regressor)180 static PetscErrorCode PetscRegressorReset_Linear(PetscRegressor regressor)
181 {
182   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
183 
184   PetscFunctionBegin;
185   /* Destroy the PETSc objects associated with the linear regressor implementation. */
186   linear->ksp_its     = 0;
187   linear->ksp_tot_its = 0;
188 
189   PetscCall(MatDestroy(&linear->X));
190   PetscCall(MatDestroy(&linear->XtX));
191   PetscCall(MatDestroy(&linear->C));
192   PetscCall(KSPDestroy(&linear->ksp));
193   PetscCall(VecDestroy(&linear->coefficients));
194   PetscCall(VecDestroy(&linear->rhs));
195   PetscCall(VecDestroy(&linear->residual));
196   PetscFunctionReturn(PETSC_SUCCESS);
197 }
198 
PetscRegressorDestroy_Linear(PetscRegressor regressor)199 static PetscErrorCode PetscRegressorDestroy_Linear(PetscRegressor regressor)
200 {
201   PetscFunctionBegin;
202   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", NULL));
203   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", NULL));
204   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", NULL));
205   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", NULL));
206   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", NULL));
207   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", NULL));
208   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", NULL));
209   PetscCall(PetscRegressorReset_Linear(regressor));
210   PetscCall(PetscFree(regressor->data));
211   PetscFunctionReturn(PETSC_SUCCESS);
212 }
213 
214 /*@
215   PetscRegressorLinearSetFitIntercept - Set a flag to indicate that the intercept (also known as the "bias" or "offset") should
216   be calculated; data are assumed to be mean-centered if false.
217 
218   Logically Collective
219 
220   Input Parameters:
221 + regressor - the `PetscRegressor` context
222 - flg       - `PETSC_TRUE` to calculate the intercept, `PETSC_FALSE` to assume mean-centered data (default is `PETSC_TRUE`)
223 
224   Level: intermediate
225 
226   Options Database Key:
227 . regressor_linear_fit_intercept <true,false> - fit the intercept
228 
229   Note:
230   If the user indicates that the intercept should not be calculated, the intercept will be set to zero.
231 
232 .seealso: `PetscRegressor`, `PetscRegressorFit()`
233 @*/
PetscRegressorLinearSetFitIntercept(PetscRegressor regressor,PetscBool flg)234 PetscErrorCode PetscRegressorLinearSetFitIntercept(PetscRegressor regressor, PetscBool flg)
235 {
236   PetscFunctionBegin;
237   /* TODO: Add companion PetscRegressorLinearGetFitIntercept(), and put it in the .seealso: */
238   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
239   PetscValidLogicalCollectiveBool(regressor, flg, 2);
240   PetscTryMethod(regressor, "PetscRegressorLinearSetFitIntercept_C", (PetscRegressor, PetscBool), (regressor, flg));
241   PetscFunctionReturn(PETSC_SUCCESS);
242 }
243 
244 /*@
245   PetscRegressorLinearSetUseKSP - Set a flag to indicate that a `KSP` object, instead of a `Tao` one, should be used
246   to fit the linear regressor
247 
248   Logically Collective
249 
250   Input Parameters:
251 + regressor - the `PetscRegressor` context
252 - flg       - `PETSC_TRUE` to use a `KSP`, `PETSC_FALSE` to use a `Tao` object (default is false)
253 
254   Options Database Key:
255 . regressor_linear_use_ksp <true,false> - use `KSP`
256 
257   Level: intermediate
258 
259   Notes:
260   `KSPLSQR` with no preconditioner is used to solve the normal equations by default.
261 
262   For sequential `MATSEQAIJ` sparse matrices QR factorization a `PCType` of `PCQR` can be used to solve the least-squares system with a `MatSolverType` of
263   `MATSOLVERSPQR`, using, for example,
264 .vb
265   -ksp_type none -pc_type qr -pc_factor_mat_solver_type sp
266 .ve
267   if centering, `PetscRegressorLinearSetFitIntercept()`, is not used.
268 
269   Developer Notes:
270   It should be possible to use Cholesky (and any other preconditioners) to solve the normal equations.
271 
272   It should be possible to use QR if centering is used. See ml/regressor/ex1.c and ex2.c
273 
274   It should be possible to use dense SVD `PCSVD` and dense qr directly on the rectangular matrix to solve the least squares problem.
275 
276   Adding the above support seems to require a refactorization of how least squares problems are solved with PETSc in `KSPLSQR`
277 
278 .seealso: `PetscRegressor`, `PetscRegressorLinearGetKSP()`, `KSPLSQR`, `PCQR`, `MATSOLVERSPQR`, `MatSolverType`, `MATSEQDENSE`, `PCSVD`
279 @*/
PetscRegressorLinearSetUseKSP(PetscRegressor regressor,PetscBool flg)280 PetscErrorCode PetscRegressorLinearSetUseKSP(PetscRegressor regressor, PetscBool flg)
281 {
282   PetscFunctionBegin;
283   /* TODO: Add companion PetscRegressorLinearGetUseKSP(), and put it in the .seealso: */
284   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
285   PetscValidLogicalCollectiveBool(regressor, flg, 2);
286   PetscTryMethod(regressor, "PetscRegressorLinearSetUseKSP_C", (PetscRegressor, PetscBool), (regressor, flg));
287   PetscFunctionReturn(PETSC_SUCCESS);
288 }
289 
PetscRegressorSetFromOptions_Linear(PetscRegressor regressor,PetscOptionItems PetscOptionsObject)290 static PetscErrorCode PetscRegressorSetFromOptions_Linear(PetscRegressor regressor, PetscOptionItems PetscOptionsObject)
291 {
292   PetscBool              set, flg = PETSC_FALSE;
293   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
294 
295   PetscFunctionBegin;
296   PetscOptionsHeadBegin(PetscOptionsObject, "PetscRegressor options for linear regressors");
297   PetscCall(PetscOptionsBool("-regressor_linear_fit_intercept", "Calculate intercept for linear model", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
298   if (set) PetscCall(PetscRegressorLinearSetFitIntercept(regressor, flg));
299   PetscCall(PetscOptionsBool("-regressor_linear_use_ksp", "Use KSP instead of TAO for linear model fitting problem", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
300   if (set) PetscCall(PetscRegressorLinearSetUseKSP(regressor, flg));
301   PetscCall(PetscOptionsEnum("-regressor_linear_type", "Linear regression method", "PetscRegressorLinearTypes", PetscRegressorLinearTypes, (PetscEnum)linear->type, (PetscEnum *)&linear->type, NULL));
302   PetscOptionsHeadEnd();
303   PetscFunctionReturn(PETSC_SUCCESS);
304 }
305 
PetscRegressorView_Linear(PetscRegressor regressor,PetscViewer viewer)306 static PetscErrorCode PetscRegressorView_Linear(PetscRegressor regressor, PetscViewer viewer)
307 {
308   PetscBool              isascii;
309   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
310 
311   PetscFunctionBegin;
312   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
313   if (isascii) {
314     PetscCall(PetscViewerASCIIPushTab(viewer));
315     PetscCall(PetscViewerASCIIPrintf(viewer, "PetscRegressor Linear Type: %s\n", PetscRegressorLinearTypes[linear->type]));
316     if (linear->ksp) {
317       PetscCall(KSPView(linear->ksp, viewer));
318       PetscCall(PetscViewerASCIIPrintf(viewer, "total KSP iterations: %" PetscInt_FMT "\n", linear->ksp_tot_its));
319     }
320     if (linear->fit_intercept) PetscCall(PetscViewerASCIIPrintf(viewer, "Intercept=%g\n", (double)linear->intercept));
321     PetscCall(PetscViewerASCIIPopTab(viewer));
322   }
323   PetscFunctionReturn(PETSC_SUCCESS);
324 }
325 
326 /*@
327   PetscRegressorLinearGetKSP - Returns the `KSP` context for a `PETSCREGRESSORLINEAR` object.
328 
329   Not Collective, but if the `PetscRegressor` is parallel, then the `KSP` object is parallel
330 
331   Input Parameter:
332 . regressor - the `PetscRegressor` context
333 
334   Output Parameter:
335 . ksp - the `KSP` context
336 
337   Level: beginner
338 
339   Note:
340   This routine will always return a `KSP`, but, depending on the type of the linear regressor and the options that are set, the regressor may actually use a `Tao` object instead of this `KSP`.
341 
342 .seealso: `PetscRegressorGetTao()`
343 @*/
PetscRegressorLinearGetKSP(PetscRegressor regressor,KSP * ksp)344 PetscErrorCode PetscRegressorLinearGetKSP(PetscRegressor regressor, KSP *ksp)
345 {
346   PetscFunctionBegin;
347   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
348   PetscAssertPointer(ksp, 2);
349   PetscUseMethod(regressor, "PetscRegressorLinearGetKSP_C", (PetscRegressor, KSP *), (regressor, ksp));
350   PetscFunctionReturn(PETSC_SUCCESS);
351 }
352 
353 /*@
354   PetscRegressorLinearGetCoefficients - Get a vector of the fitted coefficients from a linear regression model
355 
356   Not Collective but the vector is parallel
357 
358   Input Parameter:
359 . regressor - the `PetscRegressor` context
360 
361   Output Parameter:
362 . coefficients - the vector of the coefficients
363 
364   Level: beginner
365 
366 .seealso: `PetscRegressor`, `PetscRegressorLinearGetIntercept()`, `PETSCREGRESSORLINEAR`, `Vec`
367 @*/
PetscRegressorLinearGetCoefficients(PetscRegressor regressor,Vec * coefficients)368 PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetCoefficients(PetscRegressor regressor, Vec *coefficients)
369 {
370   PetscFunctionBegin;
371   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
372   PetscAssertPointer(coefficients, 2);
373   PetscUseMethod(regressor, "PetscRegressorLinearGetCoefficients_C", (PetscRegressor, Vec *), (regressor, coefficients));
374   PetscFunctionReturn(PETSC_SUCCESS);
375 }
376 
377 /*@
378   PetscRegressorLinearGetIntercept - Get the intercept from a linear regression model
379 
380   Not Collective
381 
382   Input Parameter:
383 . regressor - the `PetscRegressor` context
384 
385   Output Parameter:
386 . intercept - the intercept
387 
388   Level: beginner
389 
390 .seealso: `PetscRegressor`, `PetscRegressorLinearSetFitIntercept()`, `PetscRegressorLinearGetCoefficients()`, `PETSCREGRESSORLINEAR`
391 @*/
PetscRegressorLinearGetIntercept(PetscRegressor regressor,PetscScalar * intercept)392 PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetIntercept(PetscRegressor regressor, PetscScalar *intercept)
393 {
394   PetscFunctionBegin;
395   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
396   PetscAssertPointer(intercept, 2);
397   PetscUseMethod(regressor, "PetscRegressorLinearGetIntercept_C", (PetscRegressor, PetscScalar *), (regressor, intercept));
398   PetscFunctionReturn(PETSC_SUCCESS);
399 }
400 
401 /*@C
402   PetscRegressorLinearSetType - Sets the type of linear regression to be performed
403 
404   Logically Collective
405 
406   Input Parameters:
407 + regressor - the `PetscRegressor` context (should be of type `PETSCREGRESSORLINEAR`)
408 - type      - a known linear regression method
409 
410   Options Database Key:
411 . -regressor_linear_type - Sets the linear regression method; use -help for a list of available methods
412    (for instance "-regressor_linear_type ols" or "-regressor_linear_type lasso")
413 
414   Level: intermediate
415 
416 .seealso: `PetscRegressorLinearGetType()`, `PetscRegressorLinearType`, `PetscRegressorSetType()`, `REGRESSOR_LINEAR_OLS`,
417           `REGRESSOR_LINEAR_LASSO`, `REGRESSOR_LINEAR_RIDGE`
418 @*/
PetscRegressorLinearSetType(PetscRegressor regressor,PetscRegressorLinearType type)419 PetscErrorCode PetscRegressorLinearSetType(PetscRegressor regressor, PetscRegressorLinearType type)
420 {
421   PetscFunctionBegin;
422   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
423   PetscValidLogicalCollectiveEnum(regressor, type, 2);
424   PetscTryMethod(regressor, "PetscRegressorLinearSetType_C", (PetscRegressor, PetscRegressorLinearType), (regressor, type));
425   PetscFunctionReturn(PETSC_SUCCESS);
426 }
427 
428 /*@
429   PetscRegressorLinearGetType - Return the type for the `PETSCREGRESSORLINEAR` solver
430 
431   Input Parameter:
432 . regressor - the `PetscRegressor` solver context
433 
434   Output Parameter:
435 . type - `PETSCREGRESSORLINEAR` type
436 
437   Level: advanced
438 
439 .seealso: `PetscRegressor`, `PETSCREGRESSORLINEAR`, `PetscRegressorLinearSetType()`, `PetscRegressorLinearType`
440 @*/
PetscRegressorLinearGetType(PetscRegressor regressor,PetscRegressorLinearType * type)441 PetscErrorCode PetscRegressorLinearGetType(PetscRegressor regressor, PetscRegressorLinearType *type)
442 {
443   PetscFunctionBegin;
444   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
445   PetscAssertPointer(type, 2);
446   PetscUseMethod(regressor, "PetscRegressorLinearGetType_C", (PetscRegressor, PetscRegressorLinearType *), (regressor, type));
447   PetscFunctionReturn(PETSC_SUCCESS);
448 }
449 
PetscRegressorFit_Linear(PetscRegressor regressor)450 static PetscErrorCode PetscRegressorFit_Linear(PetscRegressor regressor)
451 {
452   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
453   KSP                    ksp;
454   PetscScalar            target_mean, *column_means_global, *column_means_local, column_means_dot_coefficients;
455   Vec                    column_means;
456   PetscInt               m, N, istart, i, kspits;
457 
458   PetscFunctionBegin;
459   if (linear->use_ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
460   ksp = linear->ksp;
461 
462   /* Solve the least-squares problem (previously set up in PetscRegressorSetUp_Linear()) without finding the intercept. */
463   if (linear->use_ksp) {
464     PetscCall(KSPSolve(ksp, linear->rhs, linear->coefficients));
465     PetscCall(KSPGetIterationNumber(ksp, &kspits));
466     linear->ksp_its += kspits;
467     linear->ksp_tot_its += kspits;
468   } else {
469     PetscCall(TaoSolve(regressor->tao));
470   }
471 
472   /* Calculate the intercept. */
473   if (linear->fit_intercept) {
474     PetscCall(MatGetSize(regressor->training, NULL, &N));
475     PetscCall(PetscMalloc1(N, &column_means_global));
476     PetscCall(VecMean(regressor->target, &target_mean));
477     /* We need the means of all columns of regressor->training, placed into a Vec compatible with linear->coefficients.
478      * Note the potential scalability issue: MatGetColumnMeans() computes means of ALL columns. */
479     PetscCall(MatGetColumnMeans(regressor->training, column_means_global));
480     /* TODO: Calculation of the Vec and matrix column means should probably go into the SetUp phase, and also be placed
481      *       into a routine that is callable from outside of PetscRegressorFit_Linear(), because we'll want to do the same
482      *       thing for other models, such as ridge and LASSO regression, and should avoid code duplication.
483      *       What we are calling 'target_mean' and 'column_means' should be stashed in the base linear regressor struct,
484      *       and perhaps renamed to make it clear they are offsets that should be applied (though the current naming
485      *       makes sense since it makes it clear where these come from.) */
486     PetscCall(VecDuplicate(linear->coefficients, &column_means));
487     PetscCall(VecGetLocalSize(column_means, &m));
488     PetscCall(VecGetOwnershipRange(column_means, &istart, NULL));
489     PetscCall(VecGetArrayWrite(column_means, &column_means_local));
490     for (i = 0; i < m; i++) column_means_local[i] = column_means_global[istart + i];
491     PetscCall(VecRestoreArrayWrite(column_means, &column_means_local));
492     PetscCall(VecDot(column_means, linear->coefficients, &column_means_dot_coefficients));
493     PetscCall(VecDestroy(&column_means));
494     PetscCall(PetscFree(column_means_global));
495     linear->intercept = target_mean - column_means_dot_coefficients;
496   } else {
497     linear->intercept = 0.0;
498   }
499   PetscFunctionReturn(PETSC_SUCCESS);
500 }
501 
PetscRegressorPredict_Linear(PetscRegressor regressor,Mat X,Vec y)502 static PetscErrorCode PetscRegressorPredict_Linear(PetscRegressor regressor, Mat X, Vec y)
503 {
504   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
505 
506   PetscFunctionBegin;
507   PetscCall(MatMult(X, linear->coefficients, y));
508   PetscCall(VecShift(y, linear->intercept));
509   PetscFunctionReturn(PETSC_SUCCESS);
510 }
511 
512 /*MC
513      PETSCREGRESSORLINEAR - Linear regression model (ordinary least squares or regularized variants)
514 
515    Options Database:
516 +  -regressor_linear_fit_intercept - Calculate the intercept for the linear model
517 -  -regressor_linear_use_ksp       - Use `KSP` instead of `Tao` for linear model fitting (non-regularized variants only)
518 
519    Level: beginner
520 
521    Notes:
522    By "linear" we mean that the model is linear in its coefficients, but not necessarily in its input features.
523    One can use the linear regressor to fit polynomial functions by training the model with a design matrix that
524    is a nonlinear function of the input data.
525 
526    This is the default regressor in `PetscRegressor`.
527 
528 .seealso: `PetscRegressorCreate()`, `PetscRegressor`, `PetscRegressorSetType()`
529 M*/
PetscRegressorCreate_Linear(PetscRegressor regressor)530 PETSC_EXTERN PetscErrorCode PetscRegressorCreate_Linear(PetscRegressor regressor)
531 {
532   PetscRegressor_Linear *linear;
533 
534   PetscFunctionBegin;
535   PetscCall(PetscNew(&linear));
536   regressor->data = (void *)linear;
537 
538   regressor->ops->setup          = PetscRegressorSetUp_Linear;
539   regressor->ops->reset          = PetscRegressorReset_Linear;
540   regressor->ops->destroy        = PetscRegressorDestroy_Linear;
541   regressor->ops->setfromoptions = PetscRegressorSetFromOptions_Linear;
542   regressor->ops->view           = PetscRegressorView_Linear;
543   regressor->ops->fit            = PetscRegressorFit_Linear;
544   regressor->ops->predict        = PetscRegressorPredict_Linear;
545 
546   linear->intercept     = 0.0;
547   linear->fit_intercept = PETSC_TRUE;  /* Default to calculating the intercept. */
548   linear->use_ksp       = PETSC_FALSE; /* Do not default to using KSP for solving the model-fitting problem (use TAO instead). */
549   linear->type          = REGRESSOR_LINEAR_OLS;
550   /* Above, manually set the default linear regressor type.
551        We don't use PetscRegressorLinearSetType() here, because that expects the SetUp event to already have happened. */
552 
553   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", PetscRegressorLinearSetFitIntercept_Linear));
554   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", PetscRegressorLinearSetUseKSP_Linear));
555   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", PetscRegressorLinearGetKSP_Linear));
556   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", PetscRegressorLinearGetCoefficients_Linear));
557   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", PetscRegressorLinearGetIntercept_Linear));
558   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", PetscRegressorLinearSetType_Linear));
559   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", PetscRegressorLinearGetType_Linear));
560   PetscFunctionReturn(PETSC_SUCCESS);
561 }
562