xref: /petsc/src/ml/regressor/impls/linear/linear.c (revision c0b7dd1945badc50da999927ed61eaf839ea6a8c)
1 #include <../src/ml/regressor/impls/linear/linearimpl.h> /*I "petscregressor.h" I*/
2 
3 const char *const PetscRegressorLinearTypes[] = {"ols", "lasso", "ridge", "RegressorLinearType", "REGRESSOR_LINEAR_", NULL};
4 
5 static PetscErrorCode PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor, PetscBool flg)
6 {
7   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
8 
9   PetscFunctionBegin;
10   linear->fit_intercept = flg;
11   PetscFunctionReturn(PETSC_SUCCESS);
12 }
13 
14 static PetscErrorCode PetscRegressorLinearSetType_Linear(PetscRegressor regressor, PetscRegressorLinearType type)
15 {
16   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
17 
18   PetscFunctionBegin;
19   linear->type = type;
20   PetscFunctionReturn(PETSC_SUCCESS);
21 }
22 
23 static PetscErrorCode PetscRegressorLinearGetType_Linear(PetscRegressor regressor, PetscRegressorLinearType *type)
24 {
25   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
26 
27   PetscFunctionBegin;
28   *type = linear->type;
29   PetscFunctionReturn(PETSC_SUCCESS);
30 }
31 
32 static PetscErrorCode PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor, PetscScalar *intercept)
33 {
34   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
35 
36   PetscFunctionBegin;
37   *intercept = linear->intercept;
38   PetscFunctionReturn(PETSC_SUCCESS);
39 }
40 
41 static PetscErrorCode PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor, Vec *coefficients)
42 {
43   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
44 
45   PetscFunctionBegin;
46   *coefficients = linear->coefficients;
47   PetscFunctionReturn(PETSC_SUCCESS);
48 }
49 
50 static PetscErrorCode PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor, KSP *ksp)
51 {
52   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
53 
54   PetscFunctionBegin;
55   if (!linear->ksp) {
56     PetscCall(KSPCreate(PetscObjectComm((PetscObject)regressor), &linear->ksp));
57     PetscCall(PetscObjectIncrementTabLevel((PetscObject)linear->ksp, (PetscObject)regressor, 1));
58     PetscCall(PetscObjectSetOptions((PetscObject)linear->ksp, ((PetscObject)regressor)->options));
59   }
60   *ksp = linear->ksp;
61   PetscFunctionReturn(PETSC_SUCCESS);
62 }
63 
64 static PetscErrorCode PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor, PetscBool flg)
65 {
66   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
67 
68   PetscFunctionBegin;
69   linear->use_ksp = flg;
70   PetscFunctionReturn(PETSC_SUCCESS);
71 }
72 
73 static PetscErrorCode EvaluateResidual(Tao tao, Vec x, Vec f, void *ptr)
74 {
75   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr;
76 
77   PetscFunctionBegin;
78   /* Evaluate f = A * x - b */
79   PetscCall(MatMult(linear->X, x, f));
80   PetscCall(VecAXPY(f, -1.0, linear->rhs));
81   PetscFunctionReturn(PETSC_SUCCESS);
82 }
83 
84 static PetscErrorCode EvaluateJacobian(Tao tao, Vec x, Mat J, Mat Jpre, void *ptr)
85 {
86   /* The TAOBRGN API expects us to pass an EvaluateJacobian() routine to it, but in this case it is a dummy function.
87      Denoting our data matrix as X, for linear least squares J[m][n] = df[m]/dx[n] = X[m][n]. */
88   PetscFunctionBegin;
89   PetscFunctionReturn(PETSC_SUCCESS);
90 }
91 
92 static PetscErrorCode PetscRegressorSetUp_Linear(PetscRegressor regressor)
93 {
94   PetscInt               M, N;
95   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
96   KSP                    ksp;
97   Tao                    tao;
98 
99   PetscFunctionBegin;
100   PetscCall(MatGetSize(regressor->training, &M, &N));
101 
102   if (linear->fit_intercept) {
103     /* If we are fitting the intercept, we need to make A a composite matrix using MATCENTERING to preserve sparsity.
104      * Though there might be some cases we don't want to do this for, depending on what kind of matrix is passed in. (Probably bad idea for dense?)
105      * We will also need to ensure that the right-hand side passed to the KSP is also mean-centered, since we
106      * intend to compute the intercept separately from regression coefficients (that is, we will not be adding a
107      * column of all 1s to our design matrix). */
108     PetscCall(MatCreateCentering(PetscObjectComm((PetscObject)regressor), PETSC_DECIDE, M, &linear->C));
109     PetscCall(MatCreate(PetscObjectComm((PetscObject)regressor), &linear->X));
110     PetscCall(MatSetSizes(linear->X, PETSC_DECIDE, PETSC_DECIDE, M, N));
111     PetscCall(MatSetType(linear->X, MATCOMPOSITE));
112     PetscCall(MatCompositeSetType(linear->X, MAT_COMPOSITE_MULTIPLICATIVE));
113     PetscCall(MatCompositeAddMat(linear->X, regressor->training));
114     PetscCall(MatCompositeAddMat(linear->X, linear->C));
115     PetscCall(VecDuplicate(regressor->target, &linear->rhs));
116     PetscCall(MatMult(linear->C, regressor->target, linear->rhs));
117   } else {
118     // When not fitting intercept, we assume that the input data are already centered.
119     linear->X   = regressor->training;
120     linear->rhs = regressor->target;
121 
122     PetscCall(PetscObjectReference((PetscObject)linear->X));
123     PetscCall(PetscObjectReference((PetscObject)linear->rhs));
124   }
125 
126   if (linear->coefficients) PetscCall(VecDestroy(&linear->coefficients));
127 
128   if (linear->use_ksp) {
129     PetscCheck(linear->type == REGRESSOR_LINEAR_OLS, PetscObjectComm((PetscObject)regressor), PETSC_ERR_ARG_WRONGSTATE, "KSP can be used to fit a linear regressor only when its type is OLS");
130 
131     if (!linear->ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
132     ksp = linear->ksp;
133 
134     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, NULL));
135     /* Set up the KSP to solve the least squares problem (without solving for intercept, as this is done separately) using KSPLSQR. */
136     PetscCall(MatCreateNormal(linear->X, &linear->XtX));
137     PetscCall(KSPSetType(ksp, KSPLSQR));
138     PetscCall(KSPSetOperators(ksp, linear->X, linear->XtX));
139     PetscCall(KSPSetOptionsPrefix(ksp, ((PetscObject)regressor)->prefix));
140     PetscCall(KSPAppendOptionsPrefix(ksp, "regressor_linear_"));
141     PetscCall(KSPSetFromOptions(ksp));
142   } else {
143     /* Note: Currently implementation creates TAO inside of implementations.
144       * Thus, all the prefix jobs are done inside implementations, not in interface */
145     const char *prefix;
146 
147     if (!regressor->tao) PetscCall(PetscRegressorGetTao(regressor, &tao));
148 
149     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, &linear->residual));
150     /* Set up the TAO object to solve the (regularized) least squares problem (without solving for intercept, which is done separately) using TAOBRGN. */
151     PetscCall(TaoSetType(tao, TAOBRGN));
152     PetscCall(TaoSetSolution(tao, linear->coefficients));
153     PetscCall(TaoSetResidualRoutine(tao, linear->residual, EvaluateResidual, linear));
154     PetscCall(TaoSetJacobianResidualRoutine(tao, linear->X, linear->X, EvaluateJacobian, linear));
155     if (!linear->use_ksp) PetscCall(TaoBRGNSetRegularizerWeight(tao, regressor->regularizer_weight));
156     // Set the regularization type and weight for the BRGN as linear->type dictates:
157     // TODO BRGN needs to be BRGNSetRegularizationType
158     // PetscOptionsSetValue no longer works due to functioning prefix system
159     PetscCall(PetscRegressorGetOptionsPrefix(regressor, &prefix));
160     PetscCall(TaoSetOptionsPrefix(regressor->tao, prefix));
161     PetscCall(TaoAppendOptionsPrefix(tao, "regressor_linear_"));
162     switch (linear->type) {
163     case REGRESSOR_LINEAR_OLS:
164       regressor->regularizer_weight = 0.0; // OLS, by definition, uses a regularizer weight of 0
165       break;
166     case REGRESSOR_LINEAR_LASSO:
167       PetscCall(TaoBRGNSetRegularizationType(regressor->tao, TAOBRGN_REGULARIZATION_L1DICT));
168       break;
169     case REGRESSOR_LINEAR_RIDGE:
170       PetscCall(TaoBRGNSetRegularizationType(regressor->tao, TAOBRGN_REGULARIZATION_L2PURE));
171       break;
172     default:
173       break;
174     }
175     PetscCall(TaoSetFromOptions(tao));
176   }
177   PetscFunctionReturn(PETSC_SUCCESS);
178 }
179 
180 static PetscErrorCode PetscRegressorReset_Linear(PetscRegressor regressor)
181 {
182   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
183 
184   PetscFunctionBegin;
185   /* Destroy the PETSc objects associated with the linear regressor implementation. */
186   linear->ksp_its     = 0;
187   linear->ksp_tot_its = 0;
188 
189   PetscCall(MatDestroy(&linear->X));
190   PetscCall(MatDestroy(&linear->XtX));
191   PetscCall(MatDestroy(&linear->C));
192   PetscCall(KSPDestroy(&linear->ksp));
193   PetscCall(VecDestroy(&linear->coefficients));
194   PetscCall(VecDestroy(&linear->rhs));
195   PetscCall(VecDestroy(&linear->residual));
196   PetscFunctionReturn(PETSC_SUCCESS);
197 }
198 
199 static PetscErrorCode PetscRegressorDestroy_Linear(PetscRegressor regressor)
200 {
201   PetscFunctionBegin;
202   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", NULL));
203   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", NULL));
204   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", NULL));
205   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", NULL));
206   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", NULL));
207   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", NULL));
208   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", NULL));
209   PetscCall(PetscRegressorReset_Linear(regressor));
210   PetscCall(PetscFree(regressor->data));
211   PetscFunctionReturn(PETSC_SUCCESS);
212 }
213 
214 /*@
215   PetscRegressorLinearSetFitIntercept - Set a flag to indicate that the intercept (also known as the "bias" or "offset") should
216   be calculated; data are assumed to be mean-centered if false.
217 
218   Logically Collective
219 
220   Input Parameters:
221 + regressor - the `PetscRegressor` context
222 - flg       - `PETSC_TRUE` to calculate the intercept, `PETSC_FALSE` to assume mean-centered data (default is `PETSC_TRUE`)
223 
224   Level: intermediate
225 
226   Options Database Key:
227 . regressor_linear_fit_intercept <true,false> - fit the intercept
228 
229   Note:
230   If the user indicates that the intercept should not be calculated, the intercept will be set to zero.
231 
232 .seealso: `PetscRegressor`, `PetscRegressorFit()`
233 @*/
PetscErrorCode PetscRegressorLinearSetFitIntercept(PetscRegressor regressor, PetscBool flg)
{
  PetscFunctionBegin;
  /* TODO: Add companion PetscRegressorLinearGetFitIntercept(), and put it in the .seealso: */
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscValidLogicalCollectiveBool(regressor, flg, 2);
  /* PetscTryMethod: a no-op (not an error) when the regressor is not of a type that composes this method. */
  PetscTryMethod(regressor, "PetscRegressorLinearSetFitIntercept_C", (PetscRegressor, PetscBool), (regressor, flg));
  PetscFunctionReturn(PETSC_SUCCESS);
}
243 
244 /*@
245   PetscRegressorLinearSetUseKSP - Set a flag to indicate that a `KSP` object, instead of a `Tao` one, should be used
246   to fit the regressor
247 
248   Logically Collective
249 
250   Input Parameters:
251 + regressor - the `PetscRegressor` context
252 - flg       - `PETSC_TRUE` to use a `KSP`, `PETSC_FALSE` to use a `Tao` object (default is false)
253 
254   Options Database Key:
255 . regressor_linear_use_ksp <true,false> - use `KSP`
256 
257   Level: intermediate
258 
259 .seealso: `PetscRegressor`, `PetscRegressorLinearGetKSP()`
260 @*/
PetscErrorCode PetscRegressorLinearSetUseKSP(PetscRegressor regressor, PetscBool flg)
{
  PetscFunctionBegin;
  /* TODO: Add companion PetscRegressorLinearGetUseKSP(), and put it in the .seealso: */
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscValidLogicalCollectiveBool(regressor, flg, 2);
  /* PetscTryMethod: a no-op (not an error) when the regressor is not of a type that composes this method. */
  PetscTryMethod(regressor, "PetscRegressorLinearSetUseKSP_C", (PetscRegressor, PetscBool), (regressor, flg));
  PetscFunctionReturn(PETSC_SUCCESS);
}
270 
271 static PetscErrorCode PetscRegressorSetFromOptions_Linear(PetscRegressor regressor, PetscOptionItems PetscOptionsObject)
272 {
273   PetscBool              set, flg = PETSC_FALSE;
274   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
275 
276   PetscFunctionBegin;
277   PetscOptionsHeadBegin(PetscOptionsObject, "PetscRegressor options for linear regressors");
278   PetscCall(PetscOptionsBool("-regressor_linear_fit_intercept", "Calculate intercept for linear model", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
279   if (set) PetscCall(PetscRegressorLinearSetFitIntercept(regressor, flg));
280   PetscCall(PetscOptionsBool("-regressor_linear_use_ksp", "Use KSP instead of TAO for linear model fitting problem", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
281   if (set) PetscCall(PetscRegressorLinearSetUseKSP(regressor, flg));
282   PetscCall(PetscOptionsEnum("-regressor_linear_type", "Linear regression method", "PetscRegressorLinearTypes", PetscRegressorLinearTypes, (PetscEnum)linear->type, (PetscEnum *)&linear->type, NULL));
283   PetscOptionsHeadEnd();
284   PetscFunctionReturn(PETSC_SUCCESS);
285 }
286 
287 static PetscErrorCode PetscRegressorView_Linear(PetscRegressor regressor, PetscViewer viewer)
288 {
289   PetscBool              isascii;
290   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
291 
292   PetscFunctionBegin;
293   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
294   if (isascii) {
295     PetscCall(PetscViewerASCIIPushTab(viewer));
296     PetscCall(PetscViewerASCIIPrintf(viewer, "PetscRegressor Linear Type: %s\n", PetscRegressorLinearTypes[linear->type]));
297     if (linear->ksp) {
298       PetscCall(KSPView(linear->ksp, viewer));
299       PetscCall(PetscViewerASCIIPrintf(viewer, "total KSP iterations: %" PetscInt_FMT "\n", linear->ksp_tot_its));
300     }
301     if (linear->fit_intercept) PetscCall(PetscViewerASCIIPrintf(viewer, "Intercept=%g\n", (double)linear->intercept));
302     PetscCall(PetscViewerASCIIPopTab(viewer));
303   }
304   PetscFunctionReturn(PETSC_SUCCESS);
305 }
306 
307 /*@
308   PetscRegressorLinearGetKSP - Returns the `KSP` context for a `PETSCREGRESSORLINEAR` object.
309 
310   Not Collective, but if the `PetscRegressor` is parallel, then the `KSP` object is parallel
311 
312   Input Parameter:
313 . regressor - the `PetscRegressor` context
314 
315   Output Parameter:
316 . ksp - the `KSP` context
317 
318   Level: beginner
319 
320   Note:
321   This routine will always return a `KSP`, but, depending on the type of the linear regressor and the options that are set, the regressor may actually use a `Tao` object instead of this `KSP`.
322 
323 .seealso: `PetscRegressorGetTao()`
324 @*/
PetscErrorCode PetscRegressorLinearGetKSP(PetscRegressor regressor, KSP *ksp)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscAssertPointer(ksp, 2);
  /* PetscUseMethod: errors if the regressor type does not compose this method (i.e. is not PETSCREGRESSORLINEAR). */
  PetscUseMethod(regressor, "PetscRegressorLinearGetKSP_C", (PetscRegressor, KSP *), (regressor, ksp));
  PetscFunctionReturn(PETSC_SUCCESS);
}
333 
334 /*@
335   PetscRegressorLinearGetCoefficients - Get a vector of the fitted coefficients from a linear regression model
336 
337   Not Collective but the vector is parallel
338 
339   Input Parameter:
340 . regressor - the `PetscRegressor` context
341 
342   Output Parameter:
343 . coefficients - the vector of the coefficients
344 
345   Level: beginner
346 
347 .seealso: `PetscRegressor`, `PetscRegressorLinearGetIntercept()`, `PETSCREGRESSORLINEAR`, `Vec`
348 @*/
PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetCoefficients(PetscRegressor regressor, Vec *coefficients)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscAssertPointer(coefficients, 2);
  /* Returns a borrowed reference; the regressor retains ownership of the vector.
     PetscUseMethod: errors if the regressor type does not compose this method. */
  PetscUseMethod(regressor, "PetscRegressorLinearGetCoefficients_C", (PetscRegressor, Vec *), (regressor, coefficients));
  PetscFunctionReturn(PETSC_SUCCESS);
}
357 
358 /*@
359   PetscRegressorLinearGetIntercept - Get the intercept from a linear regression model
360 
361   Not Collective
362 
363   Input Parameter:
364 . regressor - the `PetscRegressor` context
365 
366   Output Parameter:
367 . intercept - the intercept
368 
369   Level: beginner
370 
371 .seealso: `PetscRegressor`, `PetscRegressorLinearSetFitIntercept()`, `PetscRegressorLinearGetCoefficients()`, `PETSCREGRESSORLINEAR`
372 @*/
PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetIntercept(PetscRegressor regressor, PetscScalar *intercept)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscAssertPointer(intercept, 2);
  /* PetscUseMethod: errors if the regressor type does not compose this method. */
  PetscUseMethod(regressor, "PetscRegressorLinearGetIntercept_C", (PetscRegressor, PetscScalar *), (regressor, intercept));
  PetscFunctionReturn(PETSC_SUCCESS);
}
381 
382 /*@C
383   PetscRegressorLinearSetType - Sets the type of linear regression to be performed
384 
385   Logically Collective
386 
387   Input Parameters:
388 + regressor - the `PetscRegressor` context (should be of type `PETSCREGRESSORLINEAR`)
389 - type      - a known linear regression method
390 
391   Options Database Key:
392 . -regressor_linear_type - Sets the linear regression method; use -help for a list of available methods
393    (for instance "-regressor_linear_type ols" or "-regressor_linear_type lasso")
394 
395   Level: intermediate
396 
397 .seealso: `PetscRegressorLinearGetType()`, `PetscRegressorLinearType`, `PetscRegressorSetType()`
398 @*/
PetscErrorCode PetscRegressorLinearSetType(PetscRegressor regressor, PetscRegressorLinearType type)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscValidLogicalCollectiveEnum(regressor, type, 2);
  /* PetscTryMethod: a no-op (not an error) when the regressor is not of type PETSCREGRESSORLINEAR. */
  PetscTryMethod(regressor, "PetscRegressorLinearSetType_C", (PetscRegressor, PetscRegressorLinearType), (regressor, type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
407 
408 /*@
409   PetscRegressorLinearGetType - Return the type for the `PETSCREGRESSORLINEAR` solver
410 
411   Input Parameter:
412 . regressor - the `PetscRegressor` solver context
413 
414   Output Parameter:
415 . type - `PETSCREGRESSORLINEAR` type
416 
417   Level: advanced
418 
419 .seealso: `PetscRegressor`, `PETSCREGRESSORLINEAR`, `PetscRegressorLinearSetType()`, `PetscRegressorLinearType`
420 @*/
PetscErrorCode PetscRegressorLinearGetType(PetscRegressor regressor, PetscRegressorLinearType *type)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscAssertPointer(type, 2);
  /* PetscUseMethod: errors if the regressor type does not compose this method. */
  PetscUseMethod(regressor, "PetscRegressorLinearGetType_C", (PetscRegressor, PetscRegressorLinearType *), (regressor, type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
429 
430 static PetscErrorCode PetscRegressorFit_Linear(PetscRegressor regressor)
431 {
432   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
433   KSP                    ksp;
434   PetscScalar            target_mean, *column_means_global, *column_means_local, column_means_dot_coefficients;
435   Vec                    column_means;
436   PetscInt               m, N, istart, i, kspits;
437 
438   PetscFunctionBegin;
439   if (linear->use_ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
440   ksp = linear->ksp;
441 
442   /* Solve the least-squares problem (previously set up in PetscRegressorSetUp_Linear()) without finding the intercept. */
443   if (linear->use_ksp) {
444     PetscCall(KSPSolve(ksp, linear->rhs, linear->coefficients));
445     PetscCall(KSPGetIterationNumber(ksp, &kspits));
446     linear->ksp_its += kspits;
447     linear->ksp_tot_its += kspits;
448   } else {
449     PetscCall(TaoSolve(regressor->tao));
450   }
451 
452   /* Calculate the intercept. */
453   if (linear->fit_intercept) {
454     PetscCall(MatGetSize(regressor->training, NULL, &N));
455     PetscCall(PetscMalloc1(N, &column_means_global));
456     PetscCall(VecMean(regressor->target, &target_mean));
457     /* We need the means of all columns of regressor->training, placed into a Vec compatible with linear->coefficients.
458      * Note the potential scalability issue: MatGetColumnMeans() computes means of ALL colummns. */
459     PetscCall(MatGetColumnMeans(regressor->training, column_means_global));
460     /* TODO: Calculation of the Vec and matrix column means should probably go into the SetUp phase, and also be placed
461      *       into a routine that is callable from outside of PetscRegressorFit_Linear(), because we'll want to do the same
462      *       thing for other models, such as ridge and LASSO regression, and should avoid code duplication.
463      *       What we are calling 'target_mean' and 'column_means' should be stashed in the base linear regressor struct,
464      *       and perhaps renamed to make it clear they are offsets that should be applied (though the current naming
465      *       makes sense since it makes it clear where these come from.) */
466     PetscCall(VecDuplicate(linear->coefficients, &column_means));
467     PetscCall(VecGetLocalSize(column_means, &m));
468     PetscCall(VecGetOwnershipRange(column_means, &istart, NULL));
469     PetscCall(VecGetArrayWrite(column_means, &column_means_local));
470     for (i = 0; i < m; i++) column_means_local[i] = column_means_global[istart + i];
471     PetscCall(VecRestoreArrayWrite(column_means, &column_means_local));
472     PetscCall(VecDot(column_means, linear->coefficients, &column_means_dot_coefficients));
473     PetscCall(VecDestroy(&column_means));
474     PetscCall(PetscFree(column_means_global));
475     linear->intercept = target_mean - column_means_dot_coefficients;
476   } else {
477     linear->intercept = 0.0;
478   }
479   PetscFunctionReturn(PETSC_SUCCESS);
480 }
481 
482 static PetscErrorCode PetscRegressorPredict_Linear(PetscRegressor regressor, Mat X, Vec y)
483 {
484   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
485 
486   PetscFunctionBegin;
487   PetscCall(MatMult(X, linear->coefficients, y));
488   PetscCall(VecShift(y, linear->intercept));
489   PetscFunctionReturn(PETSC_SUCCESS);
490 }
491 
492 /*MC
493      PETSCREGRESSORLINEAR - Linear regression model (ordinary least squares or regularized variants)
494 
495    Options Database:
496 +  -regressor_linear_fit_intercept - Calculate the intercept for the linear model
497 -  -regressor_linear_use_ksp       - Use `KSP` instead of `Tao` for linear model fitting (non-regularized variants only)
498 
499    Level: beginner
500 
501    Note:
502    This is the default regressor in `PetscRegressor`.
503 
504 .seealso: `PetscRegressorCreate()`, `PetscRegressor`, `PetscRegressorSetType()`
505 M*/
506 PETSC_EXTERN PetscErrorCode PetscRegressorCreate_Linear(PetscRegressor regressor)
507 {
508   PetscRegressor_Linear *linear;
509 
510   PetscFunctionBegin;
511   PetscCall(PetscNew(&linear));
512   regressor->data = (void *)linear;
513 
514   regressor->ops->setup          = PetscRegressorSetUp_Linear;
515   regressor->ops->reset          = PetscRegressorReset_Linear;
516   regressor->ops->destroy        = PetscRegressorDestroy_Linear;
517   regressor->ops->setfromoptions = PetscRegressorSetFromOptions_Linear;
518   regressor->ops->view           = PetscRegressorView_Linear;
519   regressor->ops->fit            = PetscRegressorFit_Linear;
520   regressor->ops->predict        = PetscRegressorPredict_Linear;
521 
522   linear->intercept     = 0.0;
523   linear->fit_intercept = PETSC_TRUE;  /* Default to calculating the intercept. */
524   linear->use_ksp       = PETSC_FALSE; /* Do not default to using KSP for solving the model-fitting problem (use TAO instead). */
525   linear->type          = REGRESSOR_LINEAR_OLS;
526   /* Above, manually set the default linear regressor type.
527        We don't use PetscRegressorLinearSetType() here, because that expects the SetUp event to already have happened. */
528 
529   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", PetscRegressorLinearSetFitIntercept_Linear));
530   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", PetscRegressorLinearSetUseKSP_Linear));
531   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", PetscRegressorLinearGetKSP_Linear));
532   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", PetscRegressorLinearGetCoefficients_Linear));
533   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", PetscRegressorLinearGetIntercept_Linear));
534   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", PetscRegressorLinearSetType_Linear));
535   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", PetscRegressorLinearGetType_Linear));
536   PetscFunctionReturn(PETSC_SUCCESS);
537 }
538