xref: /petsc/src/ml/regressor/impls/linear/linear.c (revision d016bdde269de9549a736fe23cc3868ea52c341b)
1 #include <../src/ml/regressor/impls/linear/linearimpl.h> /*I "petscregressor.h" I*/
2 #include <../src/tao/leastsquares/impls/brgn/brgn.h>     /*I "petsctao.h" I*/
3 
/* Names of the linear regression variants, indexed by the PetscRegressorLinearType value, followed by the
   type-name entry, the enum-constant prefix entry, and a NULL terminator (the layout passed to PetscOptionsEnum()). */
const char *const PetscRegressorLinearTypes[] = {
  "ols",
  "lasso",
  "ridge",
  "RegressorLinearType",
  "REGRESSOR_LINEAR_",
  NULL,
};
5 
6 static PetscErrorCode PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor, PetscBool flg)
7 {
8   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
9 
10   PetscFunctionBegin;
11   linear->fit_intercept = flg;
12   PetscFunctionReturn(PETSC_SUCCESS);
13 }
14 
15 static PetscErrorCode PetscRegressorLinearSetType_Linear(PetscRegressor regressor, PetscRegressorLinearType type)
16 {
17   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
18 
19   PetscFunctionBegin;
20   linear->type = type;
21   PetscFunctionReturn(PETSC_SUCCESS);
22 }
23 
24 static PetscErrorCode PetscRegressorLinearGetType_Linear(PetscRegressor regressor, PetscRegressorLinearType *type)
25 {
26   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
27 
28   PetscFunctionBegin;
29   *type = linear->type;
30   PetscFunctionReturn(PETSC_SUCCESS);
31 }
32 
33 static PetscErrorCode PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor, PetscScalar *intercept)
34 {
35   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
36 
37   PetscFunctionBegin;
38   *intercept = linear->intercept;
39   PetscFunctionReturn(PETSC_SUCCESS);
40 }
41 
42 static PetscErrorCode PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor, Vec *coefficients)
43 {
44   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
45 
46   PetscFunctionBegin;
47   *coefficients = linear->coefficients;
48   PetscFunctionReturn(PETSC_SUCCESS);
49 }
50 
51 static PetscErrorCode PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor, KSP *ksp)
52 {
53   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
54 
55   PetscFunctionBegin;
56   if (!linear->ksp) {
57     PetscCall(KSPCreate(PetscObjectComm((PetscObject)regressor), &linear->ksp));
58     PetscCall(PetscObjectIncrementTabLevel((PetscObject)linear->ksp, (PetscObject)regressor, 1));
59     PetscCall(PetscObjectSetOptions((PetscObject)linear->ksp, ((PetscObject)regressor)->options));
60   }
61   *ksp = linear->ksp;
62   PetscFunctionReturn(PETSC_SUCCESS);
63 }
64 
65 static PetscErrorCode PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor, PetscBool flg)
66 {
67   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
68 
69   PetscFunctionBegin;
70   linear->use_ksp = flg;
71   PetscFunctionReturn(PETSC_SUCCESS);
72 }
73 
74 static PetscErrorCode EvaluateResidual(Tao tao, Vec x, Vec f, void *ptr)
75 {
76   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr;
77 
78   PetscFunctionBegin;
79   /* Evaluate f = A * x - b */
80   PetscCall(MatMult(linear->X, x, f));
81   PetscCall(VecAXPY(f, -1.0, linear->rhs));
82   PetscFunctionReturn(PETSC_SUCCESS);
83 }
84 
85 static PetscErrorCode EvaluateJacobian(Tao tao, Vec x, Mat J, Mat Jpre, void *ptr)
86 {
87   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr;
88 
89   PetscFunctionBegin;
90   J    = linear->X;
91   Jpre = linear->X;
92   PetscFunctionReturn(PETSC_SUCCESS);
93 }
94 
95 static PetscErrorCode PetscRegressorSetUp_Linear(PetscRegressor regressor)
96 {
97   PetscInt               M, N;
98   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
99   KSP                    ksp;
100   Tao                    tao;
101 
102   PetscFunctionBegin;
103   PetscCall(MatGetSize(regressor->training, &M, &N));
104 
105   if (linear->fit_intercept) {
106     /* If we are fitting the intercept, we need to make A a composite matrix using MATCENTERING to preserve sparsity.
107      * Though there might be some cases we don't want to do this for, depending on what kind of matrix is passed in. (Probably bad idea for dense?)
108      * We will also need to ensure that the right-hand side passed to the KSP is also mean-centered, since we
109      * intend to compute the intercept separately from regression coefficients (that is, we will not be adding a
110      * column of all 1s to our design matrix). */
111     PetscCall(MatCreateCentering(PetscObjectComm((PetscObject)regressor), PETSC_DECIDE, M, &linear->C));
112     PetscCall(MatCreate(PetscObjectComm((PetscObject)regressor), &linear->X));
113     PetscCall(MatSetSizes(linear->X, PETSC_DECIDE, PETSC_DECIDE, M, N));
114     PetscCall(MatSetType(linear->X, MATCOMPOSITE));
115     PetscCall(MatCompositeSetType(linear->X, MAT_COMPOSITE_MULTIPLICATIVE));
116     PetscCall(MatCompositeAddMat(linear->X, regressor->training));
117     PetscCall(MatCompositeAddMat(linear->X, linear->C));
118     PetscCall(VecDuplicate(regressor->target, &linear->rhs));
119     PetscCall(MatMult(linear->C, regressor->target, linear->rhs));
120   } else {
121     // When not fitting intercept, we assume that the input data are already centered.
122     linear->X   = regressor->training;
123     linear->rhs = regressor->target;
124 
125     PetscCall(PetscObjectReference((PetscObject)linear->X));
126     PetscCall(PetscObjectReference((PetscObject)linear->rhs));
127   }
128 
129   if (linear->coefficients) PetscCall(VecDestroy(&linear->coefficients));
130 
131   if (linear->use_ksp) {
132     PetscCheck(linear->type == REGRESSOR_LINEAR_OLS, PetscObjectComm((PetscObject)regressor), PETSC_ERR_ARG_WRONGSTATE, "KSP can be used to fit a linear regressor only when its type is OLS");
133 
134     if (!linear->ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
135     ksp = linear->ksp;
136 
137     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, NULL));
138     /* Set up the KSP to solve the least squares problem (without solving for intercept, as this is done separately) using KSPLSQR. */
139     PetscCall(MatCreateNormal(linear->X, &linear->XtX));
140     PetscCall(KSPSetType(ksp, KSPLSQR));
141     PetscCall(KSPSetOperators(ksp, linear->X, linear->XtX));
142     PetscCall(KSPSetOptionsPrefix(ksp, ((PetscObject)regressor)->prefix));
143     PetscCall(KSPAppendOptionsPrefix(ksp, "regressor_linear_"));
144     PetscCall(KSPSetFromOptions(ksp));
145   } else {
146     /* Note: Currently implementation creates TAO inside of implementations.
147       * Thus, all the prefix jobs are done inside implementations, not in interface */
148     const char *prefix;
149 
150     if (!regressor->tao) PetscCall(PetscRegressorGetTao(regressor, &tao));
151 
152     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, &linear->residual));
153     /* Set up the TAO object to solve the (regularized) least squares problem (without solving for intercept, which is done separately) using TAOBRGN. */
154     PetscCall(TaoSetType(tao, TAOBRGN));
155     PetscCall(TaoSetSolution(tao, linear->coefficients));
156     PetscCall(TaoSetResidualRoutine(tao, linear->residual, EvaluateResidual, linear));
157     PetscCall(TaoSetJacobianResidualRoutine(tao, linear->X, linear->X, EvaluateJacobian, linear));
158     if (!linear->use_ksp) PetscCall(TaoBRGNSetRegularizerWeight(tao, regressor->regularizer_weight));
159     // Set the regularization type and weight for the BRGN as linear->type dictates:
160     // TODO BRGN needs to be BRGNSetRegularizationType
161     // PetscOptionsSetValue no longer works due to functioning prefix system
162     PetscCall(PetscRegressorGetOptionsPrefix(regressor, &prefix));
163     PetscCall(TaoSetOptionsPrefix(regressor->tao, prefix));
164     PetscCall(TaoAppendOptionsPrefix(tao, "regressor_linear_"));
165     {
166       TAO_BRGN *gn = (TAO_BRGN *)regressor->tao->data;
167 
168       switch (linear->type) {
169       case REGRESSOR_LINEAR_OLS:
170         regressor->regularizer_weight = 0.0; // OLS, by definition, uses a regularizer weight of 0
171         break;
172       case REGRESSOR_LINEAR_LASSO:
173         gn->reg_type = BRGN_REGULARIZATION_L1DICT;
174         break;
175       case REGRESSOR_LINEAR_RIDGE:
176         gn->reg_type = BRGN_REGULARIZATION_L2PURE;
177         break;
178       default:
179         break;
180       }
181     }
182     PetscCall(TaoSetFromOptions(tao));
183   }
184   PetscFunctionReturn(PETSC_SUCCESS);
185 }
186 
187 static PetscErrorCode PetscRegressorReset_Linear(PetscRegressor regressor)
188 {
189   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
190 
191   PetscFunctionBegin;
192   /* Destroy the PETSc objects associated with the linear regressor implementation. */
193   linear->ksp_its     = 0;
194   linear->ksp_tot_its = 0;
195 
196   PetscCall(MatDestroy(&linear->X));
197   PetscCall(MatDestroy(&linear->XtX));
198   PetscCall(MatDestroy(&linear->C));
199   PetscCall(KSPDestroy(&linear->ksp));
200   PetscCall(VecDestroy(&linear->coefficients));
201   PetscCall(VecDestroy(&linear->rhs));
202   PetscCall(VecDestroy(&linear->residual));
203   PetscFunctionReturn(PETSC_SUCCESS);
204 }
205 
206 static PetscErrorCode PetscRegressorDestroy_Linear(PetscRegressor regressor)
207 {
208   PetscFunctionBegin;
209   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", NULL));
210   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", NULL));
211   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", NULL));
212   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", NULL));
213   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", NULL));
214   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", NULL));
215   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", NULL));
216   PetscCall(PetscRegressorReset_Linear(regressor));
217   PetscCall(PetscFree(regressor->data));
218   PetscFunctionReturn(PETSC_SUCCESS);
219 }
220 
221 /*@
222   PetscRegressorLinearSetFitIntercept - Set a flag to indicate that the intercept (also known as the "bias" or "offset") should
223   be calculated; data are assumed to be mean-centered if false.
224 
225   Logically Collective
226 
227   Input Parameters:
228 + regressor - the `PetscRegressor` context
229 - flg       - `PETSC_TRUE` to calculate the intercept, `PETSC_FALSE` to assume mean-centered data (default is `PETSC_TRUE`)
230 
231   Level: intermediate
232 
233   Options Database Key:
234 . regressor_linear_fit_intercept <true,false> - fit the intercept
235 
236   Note:
237   If the user indicates that the intercept should not be calculated, the intercept will be set to zero.
238 
239 .seealso: `PetscRegressor`, `PetscRegressorFit()`
240 @*/
241 PetscErrorCode PetscRegressorLinearSetFitIntercept(PetscRegressor regressor, PetscBool flg)
242 {
243   PetscFunctionBegin;
244   /* TODO: Add companion PetscRegressorLinearGetFitIntercept(), and put it in the .seealso: */
245   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
246   PetscValidLogicalCollectiveBool(regressor, flg, 2);
247   PetscTryMethod(regressor, "PetscRegressorLinearSetFitIntercept_C", (PetscRegressor, PetscBool), (regressor, flg));
248   PetscFunctionReturn(PETSC_SUCCESS);
249 }
250 
251 /*@
252   PetscRegressorLinearSetUseKSP - Set a flag to indicate that a `KSP` object, instead of a `Tao` one, should be used
253   to fit the regressor
254 
255   Logically Collective
256 
257   Input Parameters:
258 + regressor - the `PetscRegressor` context
259 - flg       - `PETSC_TRUE` to use a `KSP`, `PETSC_FALSE` to use a `Tao` object (default is false)
260 
261   Options Database Key:
262 . regressor_linear_use_ksp <true,false> - use `KSP`
263 
264   Level: intermediate
265 
266 .seealso: `PetscRegressor`, `PetscRegressorLinearGetKSP()`
267 @*/
268 PetscErrorCode PetscRegressorLinearSetUseKSP(PetscRegressor regressor, PetscBool flg)
269 {
270   PetscFunctionBegin;
271   /* TODO: Add companion PetscRegressorLinearGetUseKSP(), and put it in the .seealso: */
272   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
273   PetscValidLogicalCollectiveBool(regressor, flg, 2);
274   PetscTryMethod(regressor, "PetscRegressorLinearSetUseKSP_C", (PetscRegressor, PetscBool), (regressor, flg));
275   PetscFunctionReturn(PETSC_SUCCESS);
276 }
277 
278 static PetscErrorCode PetscRegressorSetFromOptions_Linear(PetscRegressor regressor, PetscOptionItems PetscOptionsObject)
279 {
280   PetscBool              set, flg = PETSC_FALSE;
281   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
282 
283   PetscFunctionBegin;
284   PetscOptionsHeadBegin(PetscOptionsObject, "PetscRegressor options for linear regressors");
285   PetscCall(PetscOptionsBool("-regressor_linear_fit_intercept", "Calculate intercept for linear model", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
286   if (set) PetscCall(PetscRegressorLinearSetFitIntercept(regressor, flg));
287   PetscCall(PetscOptionsBool("-regressor_linear_use_ksp", "Use KSP instead of TAO for linear model fitting problem", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
288   if (set) PetscCall(PetscRegressorLinearSetUseKSP(regressor, flg));
289   PetscCall(PetscOptionsEnum("-regressor_linear_type", "Linear regression method", "PetscRegressorLinearTypes", PetscRegressorLinearTypes, (PetscEnum)linear->type, (PetscEnum *)&linear->type, NULL));
290   PetscOptionsHeadEnd();
291   PetscFunctionReturn(PETSC_SUCCESS);
292 }
293 
294 static PetscErrorCode PetscRegressorView_Linear(PetscRegressor regressor, PetscViewer viewer)
295 {
296   PetscBool              isascii;
297   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
298 
299   PetscFunctionBegin;
300   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
301   if (isascii) {
302     PetscCall(PetscViewerASCIIPushTab(viewer));
303     PetscCall(PetscViewerASCIIPrintf(viewer, "PetscRegressor Linear Type: %s\n", PetscRegressorLinearTypes[linear->type]));
304     if (linear->ksp) {
305       PetscCall(KSPView(linear->ksp, viewer));
306       PetscCall(PetscViewerASCIIPrintf(viewer, "total KSP iterations: %" PetscInt_FMT "\n", linear->ksp_tot_its));
307     }
308     if (linear->fit_intercept) PetscCall(PetscViewerASCIIPrintf(viewer, "Intercept=%g\n", (double)linear->intercept));
309     PetscCall(PetscViewerASCIIPopTab(viewer));
310   }
311   PetscFunctionReturn(PETSC_SUCCESS);
312 }
313 
314 /*@
315   PetscRegressorLinearGetKSP - Returns the `KSP` context for a `PETSCREGRESSORLINEAR` object.
316 
317   Not Collective, but if the `PetscRegressor` is parallel, then the `KSP` object is parallel
318 
319   Input Parameter:
320 . regressor - the `PetscRegressor` context
321 
322   Output Parameter:
323 . ksp - the `KSP` context
324 
325   Level: beginner
326 
327   Note:
328   This routine will always return a `KSP`, but, depending on the type of the linear regressor and the options that are set, the regressor may actually use a `Tao` object instead of this `KSP`.
329 
330 .seealso: `PetscRegressorGetTao()`
331 @*/
332 PetscErrorCode PetscRegressorLinearGetKSP(PetscRegressor regressor, KSP *ksp)
333 {
334   PetscFunctionBegin;
335   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
336   PetscAssertPointer(ksp, 2);
337   PetscUseMethod(regressor, "PetscRegressorLinearGetKSP_C", (PetscRegressor, KSP *), (regressor, ksp));
338   PetscFunctionReturn(PETSC_SUCCESS);
339 }
340 
341 /*@
342   PetscRegressorLinearGetCoefficients - Get a vector of the fitted coefficients from a linear regression model
343 
344   Not Collective but the vector is parallel
345 
346   Input Parameter:
347 . regressor - the `PetscRegressor` context
348 
349   Output Parameter:
350 . coefficients - the vector of the coefficients
351 
352   Level: beginner
353 
354 .seealso: `PetscRegressor`, `PetscRegressorLinearGetIntercept()`, `PETSCREGRESSORLINEAR`, `Vec`
355 @*/
356 PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetCoefficients(PetscRegressor regressor, Vec *coefficients)
357 {
358   PetscFunctionBegin;
359   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
360   PetscAssertPointer(coefficients, 2);
361   PetscUseMethod(regressor, "PetscRegressorLinearGetCoefficients_C", (PetscRegressor, Vec *), (regressor, coefficients));
362   PetscFunctionReturn(PETSC_SUCCESS);
363 }
364 
365 /*@
366   PetscRegressorLinearGetIntercept - Get the intercept from a linear regression model
367 
368   Not Collective
369 
370   Input Parameter:
371 . regressor - the `PetscRegressor` context
372 
373   Output Parameter:
374 . intercept - the intercept
375 
376   Level: beginner
377 
378 .seealso: `PetscRegressor`, `PetscRegressorLinearSetFitIntercept()`, `PetscRegressorLinearGetCoefficients()`, `PETSCREGRESSORLINEAR`
379 @*/
380 PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetIntercept(PetscRegressor regressor, PetscScalar *intercept)
381 {
382   PetscFunctionBegin;
383   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
384   PetscAssertPointer(intercept, 2);
385   PetscUseMethod(regressor, "PetscRegressorLinearGetIntercept_C", (PetscRegressor, PetscScalar *), (regressor, intercept));
386   PetscFunctionReturn(PETSC_SUCCESS);
387 }
388 
389 /*@C
390   PetscRegressorLinearSetType - Sets the type of linear regression to be performed
391 
392   Logically Collective
393 
394   Input Parameters:
395 + regressor - the `PetscRegressor` context (should be of type `PETSCREGRESSORLINEAR`)
396 - type      - a known linear regression method
397 
398   Options Database Key:
399 . -regressor_linear_type - Sets the linear regression method; use -help for a list of available methods
400    (for instance "-regressor_linear_type ols" or "-regressor_linear_type lasso")
401 
402   Level: intermediate
403 
404 .seealso: `PetscRegressorLinearGetType()`, `PetscRegressorLinearType`, `PetscRegressorSetType()`
405 @*/
406 PetscErrorCode PetscRegressorLinearSetType(PetscRegressor regressor, PetscRegressorLinearType type)
407 {
408   PetscFunctionBegin;
409   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
410   PetscValidLogicalCollectiveEnum(regressor, type, 2);
411   PetscTryMethod(regressor, "PetscRegressorLinearSetType_C", (PetscRegressor, PetscRegressorLinearType), (regressor, type));
412   PetscFunctionReturn(PETSC_SUCCESS);
413 }
414 
415 /*@
416   PetscRegressorLinearGetType - Return the type for the `PETSCREGRESSORLINEAR` solver
417 
418   Input Parameter:
419 . regressor - the `PetscRegressor` solver context
420 
421   Output Parameter:
422 . type - `PETSCREGRESSORLINEAR` type
423 
424   Level: advanced
425 
426 .seealso: `PetscRegressor`, `PETSCREGRESSORLINEAR`, `PetscRegressorLinearSetType()`, `PetscRegressorLinearType`
427 @*/
428 PetscErrorCode PetscRegressorLinearGetType(PetscRegressor regressor, PetscRegressorLinearType *type)
429 {
430   PetscFunctionBegin;
431   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
432   PetscAssertPointer(type, 2);
433   PetscUseMethod(regressor, "PetscRegressorLinearGetType_C", (PetscRegressor, PetscRegressorLinearType *), (regressor, type));
434   PetscFunctionReturn(PETSC_SUCCESS);
435 }
436 
437 static PetscErrorCode PetscRegressorFit_Linear(PetscRegressor regressor)
438 {
439   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
440   KSP                    ksp;
441   PetscScalar            target_mean, *column_means_global, *column_means_local, column_means_dot_coefficients;
442   Vec                    column_means;
443   PetscInt               m, N, istart, i, kspits;
444 
445   PetscFunctionBegin;
446   if (linear->use_ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
447   ksp = linear->ksp;
448 
449   /* Solve the least-squares problem (previously set up in PetscRegressorSetUp_Linear()) without finding the intercept. */
450   if (linear->use_ksp) {
451     PetscCall(KSPSolve(ksp, linear->rhs, linear->coefficients));
452     PetscCall(KSPGetIterationNumber(ksp, &kspits));
453     linear->ksp_its += kspits;
454     linear->ksp_tot_its += kspits;
455   } else {
456     PetscCall(TaoSolve(regressor->tao));
457   }
458 
459   /* Calculate the intercept. */
460   if (linear->fit_intercept) {
461     PetscCall(MatGetSize(regressor->training, NULL, &N));
462     PetscCall(PetscMalloc1(N, &column_means_global));
463     PetscCall(VecMean(regressor->target, &target_mean));
464     /* We need the means of all columns of regressor->training, placed into a Vec compatible with linear->coefficients.
465      * Note the potential scalability issue: MatGetColumnMeans() computes means of ALL colummns. */
466     PetscCall(MatGetColumnMeans(regressor->training, column_means_global));
467     /* TODO: Calculation of the Vec and matrix column means should probably go into the SetUp phase, and also be placed
468      *       into a routine that is callable from outside of PetscRegressorFit_Linear(), because we'll want to do the same
469      *       thing for other models, such as ridge and LASSO regression, and should avoid code duplication.
470      *       What we are calling 'target_mean' and 'column_means' should be stashed in the base linear regressor struct,
471      *       and perhaps renamed to make it clear they are offsets that should be applied (though the current naming
472      *       makes sense since it makes it clear where these come from.) */
473     PetscCall(VecDuplicate(linear->coefficients, &column_means));
474     PetscCall(VecGetLocalSize(column_means, &m));
475     PetscCall(VecGetOwnershipRange(column_means, &istart, NULL));
476     PetscCall(VecGetArrayWrite(column_means, &column_means_local));
477     for (i = 0; i < m; i++) column_means_local[i] = column_means_global[istart + i];
478     PetscCall(VecRestoreArrayWrite(column_means, &column_means_local));
479     PetscCall(VecDot(column_means, linear->coefficients, &column_means_dot_coefficients));
480     PetscCall(VecDestroy(&column_means));
481     PetscCall(PetscFree(column_means_global));
482     linear->intercept = target_mean - column_means_dot_coefficients;
483   } else {
484     linear->intercept = 0.0;
485   }
486   PetscFunctionReturn(PETSC_SUCCESS);
487 }
488 
489 static PetscErrorCode PetscRegressorPredict_Linear(PetscRegressor regressor, Mat X, Vec y)
490 {
491   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
492 
493   PetscFunctionBegin;
494   PetscCall(MatMult(X, linear->coefficients, y));
495   PetscCall(VecShift(y, linear->intercept));
496   PetscFunctionReturn(PETSC_SUCCESS);
497 }
498 
499 /*MC
500      PETSCREGRESSORLINEAR - Linear regression model (ordinary least squares or regularized variants)
501 
502    Options Database:
503 +  -regressor_linear_fit_intercept - Calculate the intercept for the linear model
504 -  -regressor_linear_use_ksp       - Use `KSP` instead of `Tao` for linear model fitting (non-regularized variants only)
505 
506    Level: beginner
507 
508    Note:
509    This is the default regressor in `PetscRegressor`.
510 
511 .seealso: `PetscRegressorCreate()`, `PetscRegressor`, `PetscRegressorSetType()`
512 M*/
513 PETSC_EXTERN PetscErrorCode PetscRegressorCreate_Linear(PetscRegressor regressor)
514 {
515   PetscRegressor_Linear *linear;
516 
517   PetscFunctionBegin;
518   PetscCall(PetscNew(&linear));
519   regressor->data = (void *)linear;
520 
521   regressor->ops->setup          = PetscRegressorSetUp_Linear;
522   regressor->ops->reset          = PetscRegressorReset_Linear;
523   regressor->ops->destroy        = PetscRegressorDestroy_Linear;
524   regressor->ops->setfromoptions = PetscRegressorSetFromOptions_Linear;
525   regressor->ops->view           = PetscRegressorView_Linear;
526   regressor->ops->fit            = PetscRegressorFit_Linear;
527   regressor->ops->predict        = PetscRegressorPredict_Linear;
528 
529   linear->intercept     = 0.0;
530   linear->fit_intercept = PETSC_TRUE;  /* Default to calculating the intercept. */
531   linear->use_ksp       = PETSC_FALSE; /* Do not default to using KSP for solving the model-fitting problem (use TAO instead). */
532   linear->type          = REGRESSOR_LINEAR_OLS;
533   /* Above, manually set the default linear regressor type.
534        We don't use PetscRegressorLinearSetType() here, because that expects the SetUp event to already have happened. */
535 
536   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", PetscRegressorLinearSetFitIntercept_Linear));
537   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", PetscRegressorLinearSetUseKSP_Linear));
538   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", PetscRegressorLinearGetKSP_Linear));
539   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", PetscRegressorLinearGetCoefficients_Linear));
540   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", PetscRegressorLinearGetIntercept_Linear));
541   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", PetscRegressorLinearSetType_Linear));
542   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", PetscRegressorLinearGetType_Linear));
543   PetscFunctionReturn(PETSC_SUCCESS);
544 }
545