xref: /petsc/src/ml/regressor/impls/linear/linear.c (revision 1a962a18669b8a13320147188b51c0d759e22c01)
1 #include <../src/ml/regressor/impls/linear/linearimpl.h> /*I "petscregressor.h" I*/
2 #include <../src/tao/leastsquares/impls/brgn/brgn.h>     /*I "petsctao.h" I*/
3 
/* String table for PetscRegressorLinearType, in the layout PetscOptionsEnum() expects:
   the enum value names, then the enum type name, then the option-value prefix, then a NULL sentinel. */
const char *const PetscRegressorLinearTypes[] = {"ols", "lasso", "ridge", "RegressorLinearType", "REGRESSOR_LINEAR_", NULL};
5 
6 static PetscErrorCode PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor, PetscBool flg)
7 {
8   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
9 
10   PetscFunctionBegin;
11   linear->fit_intercept = flg;
12   PetscFunctionReturn(PETSC_SUCCESS);
13 }
14 
15 static PetscErrorCode PetscRegressorLinearSetType_Linear(PetscRegressor regressor, PetscRegressorLinearType type)
16 {
17   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
18 
19   PetscFunctionBegin;
20   linear->type = type;
21   PetscFunctionReturn(PETSC_SUCCESS);
22 }
23 
24 static PetscErrorCode PetscRegressorLinearGetType_Linear(PetscRegressor regressor, PetscRegressorLinearType *type)
25 {
26   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
27 
28   PetscFunctionBegin;
29   *type = linear->type;
30   PetscFunctionReturn(PETSC_SUCCESS);
31 }
32 
33 static PetscErrorCode PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor, PetscScalar *intercept)
34 {
35   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
36 
37   PetscFunctionBegin;
38   *intercept = linear->intercept;
39   PetscFunctionReturn(PETSC_SUCCESS);
40 }
41 
42 static PetscErrorCode PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor, Vec *coefficients)
43 {
44   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
45 
46   PetscFunctionBegin;
47   *coefficients = linear->coefficients;
48   PetscFunctionReturn(PETSC_SUCCESS);
49 }
50 
51 static PetscErrorCode PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor, KSP *ksp)
52 {
53   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
54 
55   PetscFunctionBegin;
56   if (!linear->ksp) {
57     PetscCall(KSPCreate(PetscObjectComm((PetscObject)regressor), &linear->ksp));
58     PetscCall(PetscObjectIncrementTabLevel((PetscObject)linear->ksp, (PetscObject)regressor, 1));
59     PetscCall(PetscObjectSetOptions((PetscObject)linear->ksp, ((PetscObject)regressor)->options));
60   }
61   *ksp = linear->ksp;
62   PetscFunctionReturn(PETSC_SUCCESS);
63 }
64 
65 static PetscErrorCode PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor, PetscBool flg)
66 {
67   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
68 
69   PetscFunctionBegin;
70   linear->use_ksp = flg;
71   PetscFunctionReturn(PETSC_SUCCESS);
72 }
73 
74 static PetscErrorCode EvaluateResidual(Tao tao, Vec x, Vec f, void *ptr)
75 {
76   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr;
77 
78   PetscFunctionBegin;
79   /* Evaluate f = A * x - b */
80   PetscCall(MatMult(linear->X, x, f));
81   PetscCall(VecAXPY(f, -1.0, linear->rhs));
82   PetscFunctionReturn(PETSC_SUCCESS);
83 }
84 
static PetscErrorCode EvaluateJacobian(Tao tao, Vec x, Mat J, Mat Jpre, void *ptr)
{
  /* The TAOBRGN API expects us to pass an EvaluateJacobian() routine to it, but in this case it is a dummy function.
     Denoting our data matrix as X, for linear least squares J[m][n] = df[m]/dx[n] = X[m][n], i.e. the Jacobian is
     the constant matrix X, which is already installed via TaoSetJacobianResidualRoutine() in
     PetscRegressorSetUp_Linear(); nothing needs recomputing here. */
  PetscFunctionBegin;
  PetscFunctionReturn(PETSC_SUCCESS);
}
92 
93 static PetscErrorCode PetscRegressorSetUp_Linear(PetscRegressor regressor)
94 {
95   PetscInt               M, N;
96   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
97   KSP                    ksp;
98   Tao                    tao;
99 
100   PetscFunctionBegin;
101   PetscCall(MatGetSize(regressor->training, &M, &N));
102 
103   if (linear->fit_intercept) {
104     /* If we are fitting the intercept, we need to make A a composite matrix using MATCENTERING to preserve sparsity.
105      * Though there might be some cases we don't want to do this for, depending on what kind of matrix is passed in. (Probably bad idea for dense?)
106      * We will also need to ensure that the right-hand side passed to the KSP is also mean-centered, since we
107      * intend to compute the intercept separately from regression coefficients (that is, we will not be adding a
108      * column of all 1s to our design matrix). */
109     PetscCall(MatCreateCentering(PetscObjectComm((PetscObject)regressor), PETSC_DECIDE, M, &linear->C));
110     PetscCall(MatCreate(PetscObjectComm((PetscObject)regressor), &linear->X));
111     PetscCall(MatSetSizes(linear->X, PETSC_DECIDE, PETSC_DECIDE, M, N));
112     PetscCall(MatSetType(linear->X, MATCOMPOSITE));
113     PetscCall(MatCompositeSetType(linear->X, MAT_COMPOSITE_MULTIPLICATIVE));
114     PetscCall(MatCompositeAddMat(linear->X, regressor->training));
115     PetscCall(MatCompositeAddMat(linear->X, linear->C));
116     PetscCall(VecDuplicate(regressor->target, &linear->rhs));
117     PetscCall(MatMult(linear->C, regressor->target, linear->rhs));
118   } else {
119     // When not fitting intercept, we assume that the input data are already centered.
120     linear->X   = regressor->training;
121     linear->rhs = regressor->target;
122 
123     PetscCall(PetscObjectReference((PetscObject)linear->X));
124     PetscCall(PetscObjectReference((PetscObject)linear->rhs));
125   }
126 
127   if (linear->coefficients) PetscCall(VecDestroy(&linear->coefficients));
128 
129   if (linear->use_ksp) {
130     PetscCheck(linear->type == REGRESSOR_LINEAR_OLS, PetscObjectComm((PetscObject)regressor), PETSC_ERR_ARG_WRONGSTATE, "KSP can be used to fit a linear regressor only when its type is OLS");
131 
132     if (!linear->ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
133     ksp = linear->ksp;
134 
135     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, NULL));
136     /* Set up the KSP to solve the least squares problem (without solving for intercept, as this is done separately) using KSPLSQR. */
137     PetscCall(MatCreateNormal(linear->X, &linear->XtX));
138     PetscCall(KSPSetType(ksp, KSPLSQR));
139     PetscCall(KSPSetOperators(ksp, linear->X, linear->XtX));
140     PetscCall(KSPSetOptionsPrefix(ksp, ((PetscObject)regressor)->prefix));
141     PetscCall(KSPAppendOptionsPrefix(ksp, "regressor_linear_"));
142     PetscCall(KSPSetFromOptions(ksp));
143   } else {
144     /* Note: Currently implementation creates TAO inside of implementations.
145       * Thus, all the prefix jobs are done inside implementations, not in interface */
146     const char *prefix;
147 
148     if (!regressor->tao) PetscCall(PetscRegressorGetTao(regressor, &tao));
149 
150     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, &linear->residual));
151     /* Set up the TAO object to solve the (regularized) least squares problem (without solving for intercept, which is done separately) using TAOBRGN. */
152     PetscCall(TaoSetType(tao, TAOBRGN));
153     PetscCall(TaoSetSolution(tao, linear->coefficients));
154     PetscCall(TaoSetResidualRoutine(tao, linear->residual, EvaluateResidual, linear));
155     PetscCall(TaoSetJacobianResidualRoutine(tao, linear->X, linear->X, EvaluateJacobian, linear));
156     if (!linear->use_ksp) PetscCall(TaoBRGNSetRegularizerWeight(tao, regressor->regularizer_weight));
157     // Set the regularization type and weight for the BRGN as linear->type dictates:
158     // TODO BRGN needs to be BRGNSetRegularizationType
159     // PetscOptionsSetValue no longer works due to functioning prefix system
160     PetscCall(PetscRegressorGetOptionsPrefix(regressor, &prefix));
161     PetscCall(TaoSetOptionsPrefix(regressor->tao, prefix));
162     PetscCall(TaoAppendOptionsPrefix(tao, "regressor_linear_"));
163     {
164       TAO_BRGN *gn = (TAO_BRGN *)regressor->tao->data;
165 
166       switch (linear->type) {
167       case REGRESSOR_LINEAR_OLS:
168         regressor->regularizer_weight = 0.0; // OLS, by definition, uses a regularizer weight of 0
169         break;
170       case REGRESSOR_LINEAR_LASSO:
171         gn->reg_type = BRGN_REGULARIZATION_L1DICT;
172         break;
173       case REGRESSOR_LINEAR_RIDGE:
174         gn->reg_type = BRGN_REGULARIZATION_L2PURE;
175         break;
176       default:
177         break;
178       }
179     }
180     PetscCall(TaoSetFromOptions(tao));
181   }
182   PetscFunctionReturn(PETSC_SUCCESS);
183 }
184 
185 static PetscErrorCode PetscRegressorReset_Linear(PetscRegressor regressor)
186 {
187   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
188 
189   PetscFunctionBegin;
190   /* Destroy the PETSc objects associated with the linear regressor implementation. */
191   linear->ksp_its     = 0;
192   linear->ksp_tot_its = 0;
193 
194   PetscCall(MatDestroy(&linear->X));
195   PetscCall(MatDestroy(&linear->XtX));
196   PetscCall(MatDestroy(&linear->C));
197   PetscCall(KSPDestroy(&linear->ksp));
198   PetscCall(VecDestroy(&linear->coefficients));
199   PetscCall(VecDestroy(&linear->rhs));
200   PetscCall(VecDestroy(&linear->residual));
201   PetscFunctionReturn(PETSC_SUCCESS);
202 }
203 
204 static PetscErrorCode PetscRegressorDestroy_Linear(PetscRegressor regressor)
205 {
206   PetscFunctionBegin;
207   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", NULL));
208   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", NULL));
209   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", NULL));
210   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", NULL));
211   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", NULL));
212   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", NULL));
213   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", NULL));
214   PetscCall(PetscRegressorReset_Linear(regressor));
215   PetscCall(PetscFree(regressor->data));
216   PetscFunctionReturn(PETSC_SUCCESS);
217 }
218 
219 /*@
220   PetscRegressorLinearSetFitIntercept - Set a flag to indicate that the intercept (also known as the "bias" or "offset") should
221   be calculated; data are assumed to be mean-centered if false.
222 
223   Logically Collective
224 
225   Input Parameters:
226 + regressor - the `PetscRegressor` context
227 - flg       - `PETSC_TRUE` to calculate the intercept, `PETSC_FALSE` to assume mean-centered data (default is `PETSC_TRUE`)
228 
229   Level: intermediate
230 
231   Options Database Key:
. -regressor_linear_fit_intercept <true,false> - fit the intercept
233 
234   Note:
235   If the user indicates that the intercept should not be calculated, the intercept will be set to zero.
236 
237 .seealso: `PetscRegressor`, `PetscRegressorFit()`
238 @*/
PetscErrorCode PetscRegressorLinearSetFitIntercept(PetscRegressor regressor, PetscBool flg)
{
  PetscFunctionBegin;
  /* TODO: Add companion PetscRegressorLinearGetFitIntercept(), and put it in the .seealso: */
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscValidLogicalCollectiveBool(regressor, flg, 2);
  /* PetscTryMethod: a no-op if the regressor type does not implement this method. */
  PetscTryMethod(regressor, "PetscRegressorLinearSetFitIntercept_C", (PetscRegressor, PetscBool), (regressor, flg));
  PetscFunctionReturn(PETSC_SUCCESS);
}
248 
249 /*@
250   PetscRegressorLinearSetUseKSP - Set a flag to indicate that a `KSP` object, instead of a `Tao` one, should be used
251   to fit the regressor
252 
253   Logically Collective
254 
255   Input Parameters:
256 + regressor - the `PetscRegressor` context
257 - flg       - `PETSC_TRUE` to use a `KSP`, `PETSC_FALSE` to use a `Tao` object (default is false)
258 
259   Options Database Key:
. -regressor_linear_use_ksp <true,false> - use `KSP`
261 
262   Level: intermediate
263 
264 .seealso: `PetscRegressor`, `PetscRegressorLinearGetKSP()`
265 @*/
PetscErrorCode PetscRegressorLinearSetUseKSP(PetscRegressor regressor, PetscBool flg)
{
  PetscFunctionBegin;
  /* TODO: Add companion PetscRegressorLinearGetUseKSP(), and put it in the .seealso: */
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscValidLogicalCollectiveBool(regressor, flg, 2);
  /* PetscTryMethod: a no-op if the regressor type does not implement this method. */
  PetscTryMethod(regressor, "PetscRegressorLinearSetUseKSP_C", (PetscRegressor, PetscBool), (regressor, flg));
  PetscFunctionReturn(PETSC_SUCCESS);
}
275 
276 static PetscErrorCode PetscRegressorSetFromOptions_Linear(PetscRegressor regressor, PetscOptionItems PetscOptionsObject)
277 {
278   PetscBool              set, flg = PETSC_FALSE;
279   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
280 
281   PetscFunctionBegin;
282   PetscOptionsHeadBegin(PetscOptionsObject, "PetscRegressor options for linear regressors");
283   PetscCall(PetscOptionsBool("-regressor_linear_fit_intercept", "Calculate intercept for linear model", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
284   if (set) PetscCall(PetscRegressorLinearSetFitIntercept(regressor, flg));
285   PetscCall(PetscOptionsBool("-regressor_linear_use_ksp", "Use KSP instead of TAO for linear model fitting problem", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
286   if (set) PetscCall(PetscRegressorLinearSetUseKSP(regressor, flg));
287   PetscCall(PetscOptionsEnum("-regressor_linear_type", "Linear regression method", "PetscRegressorLinearTypes", PetscRegressorLinearTypes, (PetscEnum)linear->type, (PetscEnum *)&linear->type, NULL));
288   PetscOptionsHeadEnd();
289   PetscFunctionReturn(PETSC_SUCCESS);
290 }
291 
/* View the linear regressor: print the variant, the inner KSP (if one was created),
   its cumulative iteration count, and the intercept when one is being fit.
   Only ASCII viewers are supported; other viewer types are silently ignored. */
static PetscErrorCode PetscRegressorView_Linear(PetscRegressor regressor, PetscViewer viewer)
{
  PetscBool              isascii;
  PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
  if (isascii) {
    PetscCall(PetscViewerASCIIPushTab(viewer));
    PetscCall(PetscViewerASCIIPrintf(viewer, "PetscRegressor Linear Type: %s\n", PetscRegressorLinearTypes[linear->type]));
    if (linear->ksp) {
      PetscCall(KSPView(linear->ksp, viewer));
      PetscCall(PetscViewerASCIIPrintf(viewer, "total KSP iterations: %" PetscInt_FMT "\n", linear->ksp_tot_its));
    }
    /* NOTE(review): the %g/(double) cast assumes a real PetscScalar — confirm for complex builds. */
    if (linear->fit_intercept) PetscCall(PetscViewerASCIIPrintf(viewer, "Intercept=%g\n", (double)linear->intercept));
    PetscCall(PetscViewerASCIIPopTab(viewer));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
311 
312 /*@
313   PetscRegressorLinearGetKSP - Returns the `KSP` context for a `PETSCREGRESSORLINEAR` object.
314 
315   Not Collective, but if the `PetscRegressor` is parallel, then the `KSP` object is parallel
316 
317   Input Parameter:
318 . regressor - the `PetscRegressor` context
319 
320   Output Parameter:
321 . ksp - the `KSP` context
322 
323   Level: beginner
324 
325   Note:
326   This routine will always return a `KSP`, but, depending on the type of the linear regressor and the options that are set, the regressor may actually use a `Tao` object instead of this `KSP`.
327 
328 .seealso: `PetscRegressorGetTao()`
329 @*/
PetscErrorCode PetscRegressorLinearGetKSP(PetscRegressor regressor, KSP *ksp)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscAssertPointer(ksp, 2);
  /* PetscUseMethod: errors if the regressor type does not implement this method. */
  PetscUseMethod(regressor, "PetscRegressorLinearGetKSP_C", (PetscRegressor, KSP *), (regressor, ksp));
  PetscFunctionReturn(PETSC_SUCCESS);
}
338 
339 /*@
340   PetscRegressorLinearGetCoefficients - Get a vector of the fitted coefficients from a linear regression model
341 
342   Not Collective but the vector is parallel
343 
344   Input Parameter:
345 . regressor - the `PetscRegressor` context
346 
347   Output Parameter:
348 . coefficients - the vector of the coefficients
349 
350   Level: beginner
351 
352 .seealso: `PetscRegressor`, `PetscRegressorLinearGetIntercept()`, `PETSCREGRESSORLINEAR`, `Vec`
353 @*/
PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetCoefficients(PetscRegressor regressor, Vec *coefficients)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscAssertPointer(coefficients, 2);
  /* Returns a borrowed reference; the regressor retains ownership of the vector. */
  PetscUseMethod(regressor, "PetscRegressorLinearGetCoefficients_C", (PetscRegressor, Vec *), (regressor, coefficients));
  PetscFunctionReturn(PETSC_SUCCESS);
}
362 
363 /*@
364   PetscRegressorLinearGetIntercept - Get the intercept from a linear regression model
365 
366   Not Collective
367 
368   Input Parameter:
369 . regressor - the `PetscRegressor` context
370 
371   Output Parameter:
372 . intercept - the intercept
373 
374   Level: beginner
375 
376 .seealso: `PetscRegressor`, `PetscRegressorLinearSetFitIntercept()`, `PetscRegressorLinearGetCoefficients()`, `PETSCREGRESSORLINEAR`
377 @*/
PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetIntercept(PetscRegressor regressor, PetscScalar *intercept)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscAssertPointer(intercept, 2);
  /* PetscUseMethod: errors if the regressor type does not implement this method. */
  PetscUseMethod(regressor, "PetscRegressorLinearGetIntercept_C", (PetscRegressor, PetscScalar *), (regressor, intercept));
  PetscFunctionReturn(PETSC_SUCCESS);
}
386 
387 /*@C
388   PetscRegressorLinearSetType - Sets the type of linear regression to be performed
389 
390   Logically Collective
391 
392   Input Parameters:
393 + regressor - the `PetscRegressor` context (should be of type `PETSCREGRESSORLINEAR`)
394 - type      - a known linear regression method
395 
396   Options Database Key:
397 . -regressor_linear_type - Sets the linear regression method; use -help for a list of available methods
398    (for instance "-regressor_linear_type ols" or "-regressor_linear_type lasso")
399 
400   Level: intermediate
401 
402 .seealso: `PetscRegressorLinearGetType()`, `PetscRegressorLinearType`, `PetscRegressorSetType()`
403 @*/
PetscErrorCode PetscRegressorLinearSetType(PetscRegressor regressor, PetscRegressorLinearType type)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscValidLogicalCollectiveEnum(regressor, type, 2);
  /* PetscTryMethod: a no-op if the regressor type does not implement this method. */
  PetscTryMethod(regressor, "PetscRegressorLinearSetType_C", (PetscRegressor, PetscRegressorLinearType), (regressor, type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
412 
413 /*@
414   PetscRegressorLinearGetType - Return the type for the `PETSCREGRESSORLINEAR` solver
415 
416   Input Parameter:
417 . regressor - the `PetscRegressor` solver context
418 
419   Output Parameter:
420 . type - `PETSCREGRESSORLINEAR` type
421 
422   Level: advanced
423 
424 .seealso: `PetscRegressor`, `PETSCREGRESSORLINEAR`, `PetscRegressorLinearSetType()`, `PetscRegressorLinearType`
425 @*/
PetscErrorCode PetscRegressorLinearGetType(PetscRegressor regressor, PetscRegressorLinearType *type)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
  PetscAssertPointer(type, 2);
  /* PetscUseMethod: errors if the regressor type does not implement this method. */
  PetscUseMethod(regressor, "PetscRegressorLinearGetType_C", (PetscRegressor, PetscRegressorLinearType *), (regressor, type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
434 
/* Fit the linear model: solve the least-squares problem set up in PetscRegressorSetUp_Linear()
   (via KSP or Tao), then, if requested, compute the intercept as
   mean(target) - dot(column_means(training), coefficients). */
static PetscErrorCode PetscRegressorFit_Linear(PetscRegressor regressor)
{
  PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
  KSP                    ksp;
  PetscScalar            target_mean, *column_means_global, *column_means_local, column_means_dot_coefficients;
  Vec                    column_means;
  PetscInt               m, N, istart, i, kspits;

  PetscFunctionBegin;
  if (linear->use_ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
  ksp = linear->ksp; /* may be NULL in the Tao path; unused there */

  /* Solve the least-squares problem (previously set up in PetscRegressorSetUp_Linear()) without finding the intercept. */
  if (linear->use_ksp) {
    PetscCall(KSPSolve(ksp, linear->rhs, linear->coefficients));
    PetscCall(KSPGetIterationNumber(ksp, &kspits));
    /* ksp_its is reset by PetscRegressorReset_Linear(); ksp_tot_its accumulates across fits. */
    linear->ksp_its += kspits;
    linear->ksp_tot_its += kspits;
  } else {
    PetscCall(TaoSolve(regressor->tao));
  }

  /* Calculate the intercept. */
  if (linear->fit_intercept) {
    PetscCall(MatGetSize(regressor->training, NULL, &N));
    PetscCall(PetscMalloc1(N, &column_means_global));
    PetscCall(VecMean(regressor->target, &target_mean));
    /* We need the means of all columns of regressor->training, placed into a Vec compatible with linear->coefficients.
     * Note the potential scalability issue: MatGetColumnMeans() computes means of ALL columns. */
    PetscCall(MatGetColumnMeans(regressor->training, column_means_global));
    /* TODO: Calculation of the Vec and matrix column means should probably go into the SetUp phase, and also be placed
     *       into a routine that is callable from outside of PetscRegressorFit_Linear(), because we'll want to do the same
     *       thing for other models, such as ridge and LASSO regression, and should avoid code duplication.
     *       What we are calling 'target_mean' and 'column_means' should be stashed in the base linear regressor struct,
     *       and perhaps renamed to make it clear they are offsets that should be applied (though the current naming
     *       makes sense since it makes it clear where these come from.) */
    /* Scatter the locally owned slice of the global column-means array into a Vec laid out like the coefficients. */
    PetscCall(VecDuplicate(linear->coefficients, &column_means));
    PetscCall(VecGetLocalSize(column_means, &m));
    PetscCall(VecGetOwnershipRange(column_means, &istart, NULL));
    PetscCall(VecGetArrayWrite(column_means, &column_means_local));
    for (i = 0; i < m; i++) column_means_local[i] = column_means_global[istart + i];
    PetscCall(VecRestoreArrayWrite(column_means, &column_means_local));
    PetscCall(VecDot(column_means, linear->coefficients, &column_means_dot_coefficients));
    PetscCall(VecDestroy(&column_means));
    PetscCall(PetscFree(column_means_global));
    linear->intercept = target_mean - column_means_dot_coefficients;
  } else {
    /* No intercept requested: data are assumed mean-centered, so the intercept is zero. */
    linear->intercept = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
486 
487 static PetscErrorCode PetscRegressorPredict_Linear(PetscRegressor regressor, Mat X, Vec y)
488 {
489   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
490 
491   PetscFunctionBegin;
492   PetscCall(MatMult(X, linear->coefficients, y));
493   PetscCall(VecShift(y, linear->intercept));
494   PetscFunctionReturn(PETSC_SUCCESS);
495 }
496 
497 /*MC
498      PETSCREGRESSORLINEAR - Linear regression model (ordinary least squares or regularized variants)
499 
500    Options Database:
501 +  -regressor_linear_fit_intercept - Calculate the intercept for the linear model
502 -  -regressor_linear_use_ksp       - Use `KSP` instead of `Tao` for linear model fitting (non-regularized variants only)
503 
504    Level: beginner
505 
506    Note:
507    This is the default regressor in `PetscRegressor`.
508 
509 .seealso: `PetscRegressorCreate()`, `PetscRegressor`, `PetscRegressorSetType()`
510 M*/
/* Constructor for PETSCREGRESSORLINEAR: allocate the implementation context, install the
   ops table, set defaults, and compose the type-specific "_C" methods onto the object. */
PETSC_EXTERN PetscErrorCode PetscRegressorCreate_Linear(PetscRegressor regressor)
{
  PetscRegressor_Linear *linear;

  PetscFunctionBegin;
  PetscCall(PetscNew(&linear));
  regressor->data = (void *)linear;

  regressor->ops->setup          = PetscRegressorSetUp_Linear;
  regressor->ops->reset          = PetscRegressorReset_Linear;
  regressor->ops->destroy        = PetscRegressorDestroy_Linear;
  regressor->ops->setfromoptions = PetscRegressorSetFromOptions_Linear;
  regressor->ops->view           = PetscRegressorView_Linear;
  regressor->ops->fit            = PetscRegressorFit_Linear;
  regressor->ops->predict        = PetscRegressorPredict_Linear;

  linear->intercept     = 0.0;
  linear->fit_intercept = PETSC_TRUE;  /* Default to calculating the intercept. */
  linear->use_ksp       = PETSC_FALSE; /* Do not default to using KSP for solving the model-fitting problem (use TAO instead). */
  linear->type          = REGRESSOR_LINEAR_OLS;
  /* Above, manually set the default linear regressor type.
       We don't use PetscRegressorLinearSetType() here, because that expects the SetUp event to already have happened. */

  /* Compose the type-specific methods; PetscRegressorDestroy_Linear() removes them again. */
  PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", PetscRegressorLinearSetFitIntercept_Linear));
  PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", PetscRegressorLinearSetUseKSP_Linear));
  PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", PetscRegressorLinearGetKSP_Linear));
  PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", PetscRegressorLinearGetCoefficients_Linear));
  PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", PetscRegressorLinearGetIntercept_Linear));
  PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", PetscRegressorLinearSetType_Linear));
  PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", PetscRegressorLinearGetType_Linear));
  PetscFunctionReturn(PETSC_SUCCESS);
}
543