xref: /petsc/src/ml/regressor/impls/linear/linear.c (revision 8577b683712d1cca1e9b8fdaa9ae028364224dad) !
134b254c5SRichard Tran Mills #include <../src/ml/regressor/impls/linear/linearimpl.h> /*I "petscregressor.h" I*/
234b254c5SRichard Tran Mills 
/*
  String names of the PetscRegressorLinearType values, in the layout used by PETSc enum options:
  the value names (which must stay in the same order as the PetscRegressorLinearType enum, since
  the type is used as an index into this array when viewing), then the enum type name, then the
  common prefix of the enum constants, and finally a NULL terminator.
*/
const char *const PetscRegressorLinearTypes[] = {
  "ols",
  "lasso",
  "ridge",
  "RegressorLinearType",
  "REGRESSOR_LINEAR_",
  NULL,
};
434b254c5SRichard Tran Mills 
PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor,PetscBool flg)534b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor, PetscBool flg)
634b254c5SRichard Tran Mills {
734b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
834b254c5SRichard Tran Mills 
934b254c5SRichard Tran Mills   PetscFunctionBegin;
1034b254c5SRichard Tran Mills   linear->fit_intercept = flg;
1134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
1234b254c5SRichard Tran Mills }
1334b254c5SRichard Tran Mills 
PetscRegressorLinearSetType_Linear(PetscRegressor regressor,PetscRegressorLinearType type)1434b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetType_Linear(PetscRegressor regressor, PetscRegressorLinearType type)
1534b254c5SRichard Tran Mills {
1634b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
1734b254c5SRichard Tran Mills 
1834b254c5SRichard Tran Mills   PetscFunctionBegin;
1934b254c5SRichard Tran Mills   linear->type = type;
2034b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
2134b254c5SRichard Tran Mills }
2234b254c5SRichard Tran Mills 
PetscRegressorLinearGetType_Linear(PetscRegressor regressor,PetscRegressorLinearType * type)2334b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetType_Linear(PetscRegressor regressor, PetscRegressorLinearType *type)
2434b254c5SRichard Tran Mills {
2534b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
2634b254c5SRichard Tran Mills 
2734b254c5SRichard Tran Mills   PetscFunctionBegin;
2834b254c5SRichard Tran Mills   *type = linear->type;
2934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
3034b254c5SRichard Tran Mills }
3134b254c5SRichard Tran Mills 
PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor,PetscScalar * intercept)3234b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor, PetscScalar *intercept)
3334b254c5SRichard Tran Mills {
3434b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
3534b254c5SRichard Tran Mills 
3634b254c5SRichard Tran Mills   PetscFunctionBegin;
3734b254c5SRichard Tran Mills   *intercept = linear->intercept;
3834b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
3934b254c5SRichard Tran Mills }
4034b254c5SRichard Tran Mills 
PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor,Vec * coefficients)4134b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor, Vec *coefficients)
4234b254c5SRichard Tran Mills {
4334b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
4434b254c5SRichard Tran Mills 
4534b254c5SRichard Tran Mills   PetscFunctionBegin;
4634b254c5SRichard Tran Mills   *coefficients = linear->coefficients;
4734b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
4834b254c5SRichard Tran Mills }
4934b254c5SRichard Tran Mills 
PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor,KSP * ksp)5034b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor, KSP *ksp)
5134b254c5SRichard Tran Mills {
5234b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
5334b254c5SRichard Tran Mills 
5434b254c5SRichard Tran Mills   PetscFunctionBegin;
5534b254c5SRichard Tran Mills   if (!linear->ksp) {
5634b254c5SRichard Tran Mills     PetscCall(KSPCreate(PetscObjectComm((PetscObject)regressor), &linear->ksp));
5734b254c5SRichard Tran Mills     PetscCall(PetscObjectIncrementTabLevel((PetscObject)linear->ksp, (PetscObject)regressor, 1));
5834b254c5SRichard Tran Mills     PetscCall(PetscObjectSetOptions((PetscObject)linear->ksp, ((PetscObject)regressor)->options));
5934b254c5SRichard Tran Mills   }
6034b254c5SRichard Tran Mills   *ksp = linear->ksp;
6134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
6234b254c5SRichard Tran Mills }
6334b254c5SRichard Tran Mills 
PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor,PetscBool flg)6434b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor, PetscBool flg)
6534b254c5SRichard Tran Mills {
6634b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
6734b254c5SRichard Tran Mills 
6834b254c5SRichard Tran Mills   PetscFunctionBegin;
6934b254c5SRichard Tran Mills   linear->use_ksp = flg;
7034b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
7134b254c5SRichard Tran Mills }
7234b254c5SRichard Tran Mills 
EvaluateResidual(Tao tao,Vec x,Vec f,void * ptr)7334b254c5SRichard Tran Mills static PetscErrorCode EvaluateResidual(Tao tao, Vec x, Vec f, void *ptr)
7434b254c5SRichard Tran Mills {
7534b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr;
7634b254c5SRichard Tran Mills 
7734b254c5SRichard Tran Mills   PetscFunctionBegin;
7834b254c5SRichard Tran Mills   /* Evaluate f = A * x - b */
7934b254c5SRichard Tran Mills   PetscCall(MatMult(linear->X, x, f));
8034b254c5SRichard Tran Mills   PetscCall(VecAXPY(f, -1.0, linear->rhs));
8134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
8234b254c5SRichard Tran Mills }
8334b254c5SRichard Tran Mills 
EvaluateJacobian(Tao tao,Vec x,Mat J,Mat Jpre,void * ptr)8434b254c5SRichard Tran Mills static PetscErrorCode EvaluateJacobian(Tao tao, Vec x, Mat J, Mat Jpre, void *ptr)
8534b254c5SRichard Tran Mills {
86540f39e1SHansol Suh   /* The TAOBRGN API expects us to pass an EvaluateJacobian() routine to it, but in this case it is a dummy function.
87540f39e1SHansol Suh      Denoting our data matrix as X, for linear least squares J[m][n] = df[m]/dx[n] = X[m][n]. */
8834b254c5SRichard Tran Mills   PetscFunctionBegin;
8934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
9034b254c5SRichard Tran Mills }
9134b254c5SRichard Tran Mills 
PetscRegressorSetUp_Linear(PetscRegressor regressor)9234b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorSetUp_Linear(PetscRegressor regressor)
9334b254c5SRichard Tran Mills {
9434b254c5SRichard Tran Mills   PetscInt               M, N;
9534b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
9634b254c5SRichard Tran Mills   KSP                    ksp;
9734b254c5SRichard Tran Mills   Tao                    tao;
9834b254c5SRichard Tran Mills 
9934b254c5SRichard Tran Mills   PetscFunctionBegin;
10034b254c5SRichard Tran Mills   PetscCall(MatGetSize(regressor->training, &M, &N));
10134b254c5SRichard Tran Mills 
10234b254c5SRichard Tran Mills   if (linear->fit_intercept) {
10334b254c5SRichard Tran Mills     /* If we are fitting the intercept, we need to make A a composite matrix using MATCENTERING to preserve sparsity.
10434b254c5SRichard Tran Mills      * Though there might be some cases we don't want to do this for, depending on what kind of matrix is passed in. (Probably bad idea for dense?)
10534b254c5SRichard Tran Mills      * We will also need to ensure that the right-hand side passed to the KSP is also mean-centered, since we
10634b254c5SRichard Tran Mills      * intend to compute the intercept separately from regression coefficients (that is, we will not be adding a
10734b254c5SRichard Tran Mills      * column of all 1s to our design matrix). */
10834b254c5SRichard Tran Mills     PetscCall(MatCreateCentering(PetscObjectComm((PetscObject)regressor), PETSC_DECIDE, M, &linear->C));
10934b254c5SRichard Tran Mills     PetscCall(MatCreate(PetscObjectComm((PetscObject)regressor), &linear->X));
11034b254c5SRichard Tran Mills     PetscCall(MatSetSizes(linear->X, PETSC_DECIDE, PETSC_DECIDE, M, N));
11134b254c5SRichard Tran Mills     PetscCall(MatSetType(linear->X, MATCOMPOSITE));
11234b254c5SRichard Tran Mills     PetscCall(MatCompositeSetType(linear->X, MAT_COMPOSITE_MULTIPLICATIVE));
11334b254c5SRichard Tran Mills     PetscCall(MatCompositeAddMat(linear->X, regressor->training));
11434b254c5SRichard Tran Mills     PetscCall(MatCompositeAddMat(linear->X, linear->C));
11534b254c5SRichard Tran Mills     PetscCall(VecDuplicate(regressor->target, &linear->rhs));
11634b254c5SRichard Tran Mills     PetscCall(MatMult(linear->C, regressor->target, linear->rhs));
11734b254c5SRichard Tran Mills   } else {
11834b254c5SRichard Tran Mills     // When not fitting intercept, we assume that the input data are already centered.
11934b254c5SRichard Tran Mills     linear->X   = regressor->training;
12034b254c5SRichard Tran Mills     linear->rhs = regressor->target;
12134b254c5SRichard Tran Mills 
12234b254c5SRichard Tran Mills     PetscCall(PetscObjectReference((PetscObject)linear->X));
12334b254c5SRichard Tran Mills     PetscCall(PetscObjectReference((PetscObject)linear->rhs));
12434b254c5SRichard Tran Mills   }
12534b254c5SRichard Tran Mills 
12634b254c5SRichard Tran Mills   if (linear->coefficients) PetscCall(VecDestroy(&linear->coefficients));
12734b254c5SRichard Tran Mills 
12834b254c5SRichard Tran Mills   if (linear->use_ksp) {
12934b254c5SRichard Tran Mills     PetscCheck(linear->type == REGRESSOR_LINEAR_OLS, PetscObjectComm((PetscObject)regressor), PETSC_ERR_ARG_WRONGSTATE, "KSP can be used to fit a linear regressor only when its type is OLS");
13034b254c5SRichard Tran Mills 
13134b254c5SRichard Tran Mills     if (!linear->ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
13234b254c5SRichard Tran Mills     ksp = linear->ksp;
13334b254c5SRichard Tran Mills 
13434b254c5SRichard Tran Mills     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, NULL));
13534b254c5SRichard Tran Mills     /* Set up the KSP to solve the least squares problem (without solving for intercept, as this is done separately) using KSPLSQR. */
13634b254c5SRichard Tran Mills     PetscCall(MatCreateNormal(linear->X, &linear->XtX));
13734b254c5SRichard Tran Mills     PetscCall(KSPSetType(ksp, KSPLSQR));
13834b254c5SRichard Tran Mills     PetscCall(KSPSetOperators(ksp, linear->X, linear->XtX));
13934b254c5SRichard Tran Mills     PetscCall(KSPSetOptionsPrefix(ksp, ((PetscObject)regressor)->prefix));
14034b254c5SRichard Tran Mills     PetscCall(KSPAppendOptionsPrefix(ksp, "regressor_linear_"));
14134b254c5SRichard Tran Mills     PetscCall(KSPSetFromOptions(ksp));
14234b254c5SRichard Tran Mills   } else {
14334b254c5SRichard Tran Mills     /* Note: Currently implementation creates TAO inside of implementations.
14434b254c5SRichard Tran Mills       * Thus, all the prefix jobs are done inside implementations, not in interface */
14534b254c5SRichard Tran Mills     const char *prefix;
14634b254c5SRichard Tran Mills 
14734b254c5SRichard Tran Mills     if (!regressor->tao) PetscCall(PetscRegressorGetTao(regressor, &tao));
14834b254c5SRichard Tran Mills 
14934b254c5SRichard Tran Mills     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, &linear->residual));
15034b254c5SRichard Tran Mills     /* Set up the TAO object to solve the (regularized) least squares problem (without solving for intercept, which is done separately) using TAOBRGN. */
15134b254c5SRichard Tran Mills     PetscCall(TaoSetType(tao, TAOBRGN));
15234b254c5SRichard Tran Mills     PetscCall(TaoSetSolution(tao, linear->coefficients));
15334b254c5SRichard Tran Mills     PetscCall(TaoSetResidualRoutine(tao, linear->residual, EvaluateResidual, linear));
15434b254c5SRichard Tran Mills     PetscCall(TaoSetJacobianResidualRoutine(tao, linear->X, linear->X, EvaluateJacobian, linear));
15534b254c5SRichard Tran Mills     // Set the regularization type and weight for the BRGN as linear->type dictates:
15634b254c5SRichard Tran Mills     // TODO BRGN needs to be BRGNSetRegularizationType
15734b254c5SRichard Tran Mills     // PetscOptionsSetValue no longer works due to functioning prefix system
15834b254c5SRichard Tran Mills     PetscCall(PetscRegressorGetOptionsPrefix(regressor, &prefix));
15934b254c5SRichard Tran Mills     PetscCall(TaoSetOptionsPrefix(regressor->tao, prefix));
16034b254c5SRichard Tran Mills     PetscCall(TaoAppendOptionsPrefix(tao, "regressor_linear_"));
16134b254c5SRichard Tran Mills     switch (linear->type) {
16234b254c5SRichard Tran Mills     case REGRESSOR_LINEAR_OLS:
16334b254c5SRichard Tran Mills       regressor->regularizer_weight = 0.0; // OLS, by definition, uses a regularizer weight of 0
16434b254c5SRichard Tran Mills       break;
16534b254c5SRichard Tran Mills     case REGRESSOR_LINEAR_LASSO:
166c0b7dd19SHansol Suh       PetscCall(TaoBRGNSetRegularizationType(regressor->tao, TAOBRGN_REGULARIZATION_L1DICT));
16734b254c5SRichard Tran Mills       break;
16834b254c5SRichard Tran Mills     case REGRESSOR_LINEAR_RIDGE:
169c0b7dd19SHansol Suh       PetscCall(TaoBRGNSetRegularizationType(regressor->tao, TAOBRGN_REGULARIZATION_L2PURE));
17034b254c5SRichard Tran Mills       break;
17134b254c5SRichard Tran Mills     default:
17234b254c5SRichard Tran Mills       break;
17334b254c5SRichard Tran Mills     }
174c7c036fbSHansol Suh     if (!linear->use_ksp) PetscCall(TaoBRGNSetRegularizerWeight(tao, regressor->regularizer_weight));
17534b254c5SRichard Tran Mills     PetscCall(TaoSetFromOptions(tao));
17634b254c5SRichard Tran Mills   }
17734b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
17834b254c5SRichard Tran Mills }
17934b254c5SRichard Tran Mills 
PetscRegressorReset_Linear(PetscRegressor regressor)18034b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorReset_Linear(PetscRegressor regressor)
18134b254c5SRichard Tran Mills {
18234b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
18334b254c5SRichard Tran Mills 
18434b254c5SRichard Tran Mills   PetscFunctionBegin;
18534b254c5SRichard Tran Mills   /* Destroy the PETSc objects associated with the linear regressor implementation. */
18634b254c5SRichard Tran Mills   linear->ksp_its     = 0;
18734b254c5SRichard Tran Mills   linear->ksp_tot_its = 0;
18834b254c5SRichard Tran Mills 
18934b254c5SRichard Tran Mills   PetscCall(MatDestroy(&linear->X));
19034b254c5SRichard Tran Mills   PetscCall(MatDestroy(&linear->XtX));
19134b254c5SRichard Tran Mills   PetscCall(MatDestroy(&linear->C));
19234b254c5SRichard Tran Mills   PetscCall(KSPDestroy(&linear->ksp));
19334b254c5SRichard Tran Mills   PetscCall(VecDestroy(&linear->coefficients));
19434b254c5SRichard Tran Mills   PetscCall(VecDestroy(&linear->rhs));
19534b254c5SRichard Tran Mills   PetscCall(VecDestroy(&linear->residual));
19634b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
19734b254c5SRichard Tran Mills }
19834b254c5SRichard Tran Mills 
PetscRegressorDestroy_Linear(PetscRegressor regressor)19934b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorDestroy_Linear(PetscRegressor regressor)
20034b254c5SRichard Tran Mills {
20134b254c5SRichard Tran Mills   PetscFunctionBegin;
20234b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", NULL));
20334b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", NULL));
20434b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", NULL));
20534b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", NULL));
20634b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", NULL));
20734b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", NULL));
20834b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", NULL));
20934b254c5SRichard Tran Mills   PetscCall(PetscRegressorReset_Linear(regressor));
21034b254c5SRichard Tran Mills   PetscCall(PetscFree(regressor->data));
21134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
21234b254c5SRichard Tran Mills }
21334b254c5SRichard Tran Mills 
21434b254c5SRichard Tran Mills /*@
21534b254c5SRichard Tran Mills   PetscRegressorLinearSetFitIntercept - Set a flag to indicate that the intercept (also known as the "bias" or "offset") should
21634b254c5SRichard Tran Mills   be calculated; data are assumed to be mean-centered if false.
21734b254c5SRichard Tran Mills 
21834b254c5SRichard Tran Mills   Logically Collective
21934b254c5SRichard Tran Mills 
22034b254c5SRichard Tran Mills   Input Parameters:
22134b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context
22234b254c5SRichard Tran Mills - flg       - `PETSC_TRUE` to calculate the intercept, `PETSC_FALSE` to assume mean-centered data (default is `PETSC_TRUE`)
22334b254c5SRichard Tran Mills 
22434b254c5SRichard Tran Mills   Level: intermediate
22534b254c5SRichard Tran Mills 
22634b254c5SRichard Tran Mills   Options Database Key:
22734b254c5SRichard Tran Mills . regressor_linear_fit_intercept <true,false> - fit the intercept
22834b254c5SRichard Tran Mills 
22934b254c5SRichard Tran Mills   Note:
23034b254c5SRichard Tran Mills   If the user indicates that the intercept should not be calculated, the intercept will be set to zero.
23134b254c5SRichard Tran Mills 
23234b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorFit()`
23334b254c5SRichard Tran Mills @*/
PetscRegressorLinearSetFitIntercept(PetscRegressor regressor,PetscBool flg)23434b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetFitIntercept(PetscRegressor regressor, PetscBool flg)
23534b254c5SRichard Tran Mills {
23634b254c5SRichard Tran Mills   PetscFunctionBegin;
23734b254c5SRichard Tran Mills   /* TODO: Add companion PetscRegressorLinearGetFitIntercept(), and put it in the .seealso: */
23834b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
23934b254c5SRichard Tran Mills   PetscValidLogicalCollectiveBool(regressor, flg, 2);
24034b254c5SRichard Tran Mills   PetscTryMethod(regressor, "PetscRegressorLinearSetFitIntercept_C", (PetscRegressor, PetscBool), (regressor, flg));
24134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
24234b254c5SRichard Tran Mills }
24334b254c5SRichard Tran Mills 
24434b254c5SRichard Tran Mills /*@
24534b254c5SRichard Tran Mills   PetscRegressorLinearSetUseKSP - Set a flag to indicate that a `KSP` object, instead of a `Tao` one, should be used
246789736e1SBarry Smith   to fit the linear regressor
24734b254c5SRichard Tran Mills 
24834b254c5SRichard Tran Mills   Logically Collective
24934b254c5SRichard Tran Mills 
25034b254c5SRichard Tran Mills   Input Parameters:
25134b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context
25234b254c5SRichard Tran Mills - flg       - `PETSC_TRUE` to use a `KSP`, `PETSC_FALSE` to use a `Tao` object (default is false)
25334b254c5SRichard Tran Mills 
25434b254c5SRichard Tran Mills   Options Database Key:
25534b254c5SRichard Tran Mills . regressor_linear_use_ksp <true,false> - use `KSP`
25634b254c5SRichard Tran Mills 
25734b254c5SRichard Tran Mills   Level: intermediate
25834b254c5SRichard Tran Mills 
259789736e1SBarry Smith   Notes:
260789736e1SBarry Smith   `KSPLSQR` with no preconditioner is used to solve the normal equations by default.
261789736e1SBarry Smith 
262789736e1SBarry Smith   For sequential `MATSEQAIJ` sparse matrices QR factorization a `PCType` of `PCQR` can be used to solve the least-squares system with a `MatSolverType` of
263789736e1SBarry Smith   `MATSOLVERSPQR`, using, for example,
264789736e1SBarry Smith .vb
265789736e1SBarry Smith   -ksp_type none -pc_type qr -pc_factor_mat_solver_type sp
266789736e1SBarry Smith .ve
267789736e1SBarry Smith   if centering, `PetscRegressorLinearSetFitIntercept()`, is not used.
268789736e1SBarry Smith 
269789736e1SBarry Smith   Developer Notes:
270789736e1SBarry Smith   It should be possible to use Cholesky (and any other preconditioners) to solve the normal equations.
271789736e1SBarry Smith 
272789736e1SBarry Smith   It should be possible to use QR if centering is used. See ml/regressor/ex1.c and ex2.c
273789736e1SBarry Smith 
274789736e1SBarry Smith   It should be possible to use dense SVD `PCSVD` and dense qr directly on the rectangular matrix to solve the least squares problem.
275789736e1SBarry Smith 
276789736e1SBarry Smith   Adding the above support seems to require a refactorization of how least squares problems are solved with PETSc in `KSPLSQR`
277789736e1SBarry Smith 
278789736e1SBarry Smith .seealso: `PetscRegressor`, `PetscRegressorLinearGetKSP()`, `KSPLSQR`, `PCQR`, `MATSOLVERSPQR`, `MatSolverType`, `MATSEQDENSE`, `PCSVD`
27934b254c5SRichard Tran Mills @*/
PetscRegressorLinearSetUseKSP(PetscRegressor regressor,PetscBool flg)28034b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetUseKSP(PetscRegressor regressor, PetscBool flg)
28134b254c5SRichard Tran Mills {
28234b254c5SRichard Tran Mills   PetscFunctionBegin;
28334b254c5SRichard Tran Mills   /* TODO: Add companion PetscRegressorLinearGetUseKSP(), and put it in the .seealso: */
28434b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
28534b254c5SRichard Tran Mills   PetscValidLogicalCollectiveBool(regressor, flg, 2);
28634b254c5SRichard Tran Mills   PetscTryMethod(regressor, "PetscRegressorLinearSetUseKSP_C", (PetscRegressor, PetscBool), (regressor, flg));
28734b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
28834b254c5SRichard Tran Mills }
28934b254c5SRichard Tran Mills 
PetscRegressorSetFromOptions_Linear(PetscRegressor regressor,PetscOptionItems PetscOptionsObject)29034b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorSetFromOptions_Linear(PetscRegressor regressor, PetscOptionItems PetscOptionsObject)
29134b254c5SRichard Tran Mills {
29234b254c5SRichard Tran Mills   PetscBool              set, flg = PETSC_FALSE;
29334b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
29434b254c5SRichard Tran Mills 
29534b254c5SRichard Tran Mills   PetscFunctionBegin;
29634b254c5SRichard Tran Mills   PetscOptionsHeadBegin(PetscOptionsObject, "PetscRegressor options for linear regressors");
29734b254c5SRichard Tran Mills   PetscCall(PetscOptionsBool("-regressor_linear_fit_intercept", "Calculate intercept for linear model", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
29834b254c5SRichard Tran Mills   if (set) PetscCall(PetscRegressorLinearSetFitIntercept(regressor, flg));
29934b254c5SRichard Tran Mills   PetscCall(PetscOptionsBool("-regressor_linear_use_ksp", "Use KSP instead of TAO for linear model fitting problem", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
30034b254c5SRichard Tran Mills   if (set) PetscCall(PetscRegressorLinearSetUseKSP(regressor, flg));
30134b254c5SRichard Tran Mills   PetscCall(PetscOptionsEnum("-regressor_linear_type", "Linear regression method", "PetscRegressorLinearTypes", PetscRegressorLinearTypes, (PetscEnum)linear->type, (PetscEnum *)&linear->type, NULL));
30234b254c5SRichard Tran Mills   PetscOptionsHeadEnd();
30334b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
30434b254c5SRichard Tran Mills }
30534b254c5SRichard Tran Mills 
PetscRegressorView_Linear(PetscRegressor regressor,PetscViewer viewer)30634b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorView_Linear(PetscRegressor regressor, PetscViewer viewer)
30734b254c5SRichard Tran Mills {
30834b254c5SRichard Tran Mills   PetscBool              isascii;
30934b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
31034b254c5SRichard Tran Mills 
31134b254c5SRichard Tran Mills   PetscFunctionBegin;
31234b254c5SRichard Tran Mills   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
31334b254c5SRichard Tran Mills   if (isascii) {
31434b254c5SRichard Tran Mills     PetscCall(PetscViewerASCIIPushTab(viewer));
31534b254c5SRichard Tran Mills     PetscCall(PetscViewerASCIIPrintf(viewer, "PetscRegressor Linear Type: %s\n", PetscRegressorLinearTypes[linear->type]));
31634b254c5SRichard Tran Mills     if (linear->ksp) {
31734b254c5SRichard Tran Mills       PetscCall(KSPView(linear->ksp, viewer));
31834b254c5SRichard Tran Mills       PetscCall(PetscViewerASCIIPrintf(viewer, "total KSP iterations: %" PetscInt_FMT "\n", linear->ksp_tot_its));
31934b254c5SRichard Tran Mills     }
32034b254c5SRichard Tran Mills     if (linear->fit_intercept) PetscCall(PetscViewerASCIIPrintf(viewer, "Intercept=%g\n", (double)linear->intercept));
32134b254c5SRichard Tran Mills     PetscCall(PetscViewerASCIIPopTab(viewer));
32234b254c5SRichard Tran Mills   }
32334b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
32434b254c5SRichard Tran Mills }
32534b254c5SRichard Tran Mills 
32634b254c5SRichard Tran Mills /*@
32734b254c5SRichard Tran Mills   PetscRegressorLinearGetKSP - Returns the `KSP` context for a `PETSCREGRESSORLINEAR` object.
32834b254c5SRichard Tran Mills 
32934b254c5SRichard Tran Mills   Not Collective, but if the `PetscRegressor` is parallel, then the `KSP` object is parallel
33034b254c5SRichard Tran Mills 
33134b254c5SRichard Tran Mills   Input Parameter:
33234b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context
33334b254c5SRichard Tran Mills 
33434b254c5SRichard Tran Mills   Output Parameter:
33534b254c5SRichard Tran Mills . ksp - the `KSP` context
33634b254c5SRichard Tran Mills 
33734b254c5SRichard Tran Mills   Level: beginner
33834b254c5SRichard Tran Mills 
33934b254c5SRichard Tran Mills   Note:
34034b254c5SRichard Tran Mills   This routine will always return a `KSP`, but, depending on the type of the linear regressor and the options that are set, the regressor may actually use a `Tao` object instead of this `KSP`.
34134b254c5SRichard Tran Mills 
34234b254c5SRichard Tran Mills .seealso: `PetscRegressorGetTao()`
34334b254c5SRichard Tran Mills @*/
PetscRegressorLinearGetKSP(PetscRegressor regressor,KSP * ksp)34434b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearGetKSP(PetscRegressor regressor, KSP *ksp)
34534b254c5SRichard Tran Mills {
34634b254c5SRichard Tran Mills   PetscFunctionBegin;
34734b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
34834b254c5SRichard Tran Mills   PetscAssertPointer(ksp, 2);
34934b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetKSP_C", (PetscRegressor, KSP *), (regressor, ksp));
35034b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
35134b254c5SRichard Tran Mills }
35234b254c5SRichard Tran Mills 
35334b254c5SRichard Tran Mills /*@
35434b254c5SRichard Tran Mills   PetscRegressorLinearGetCoefficients - Get a vector of the fitted coefficients from a linear regression model
35534b254c5SRichard Tran Mills 
35634b254c5SRichard Tran Mills   Not Collective but the vector is parallel
35734b254c5SRichard Tran Mills 
35834b254c5SRichard Tran Mills   Input Parameter:
35934b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context
36034b254c5SRichard Tran Mills 
36134b254c5SRichard Tran Mills   Output Parameter:
36234b254c5SRichard Tran Mills . coefficients - the vector of the coefficients
36334b254c5SRichard Tran Mills 
36434b254c5SRichard Tran Mills   Level: beginner
36534b254c5SRichard Tran Mills 
36634b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorLinearGetIntercept()`, `PETSCREGRESSORLINEAR`, `Vec`
36734b254c5SRichard Tran Mills @*/
PetscRegressorLinearGetCoefficients(PetscRegressor regressor,Vec * coefficients)36834b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetCoefficients(PetscRegressor regressor, Vec *coefficients)
36934b254c5SRichard Tran Mills {
37034b254c5SRichard Tran Mills   PetscFunctionBegin;
37134b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
37234b254c5SRichard Tran Mills   PetscAssertPointer(coefficients, 2);
37334b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetCoefficients_C", (PetscRegressor, Vec *), (regressor, coefficients));
37434b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
37534b254c5SRichard Tran Mills }
37634b254c5SRichard Tran Mills 
37734b254c5SRichard Tran Mills /*@
37834b254c5SRichard Tran Mills   PetscRegressorLinearGetIntercept - Get the intercept from a linear regression model
37934b254c5SRichard Tran Mills 
38034b254c5SRichard Tran Mills   Not Collective
38134b254c5SRichard Tran Mills 
38234b254c5SRichard Tran Mills   Input Parameter:
38334b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context
38434b254c5SRichard Tran Mills 
38534b254c5SRichard Tran Mills   Output Parameter:
38634b254c5SRichard Tran Mills . intercept - the intercept
38734b254c5SRichard Tran Mills 
38834b254c5SRichard Tran Mills   Level: beginner
38934b254c5SRichard Tran Mills 
39034b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorLinearSetFitIntercept()`, `PetscRegressorLinearGetCoefficients()`, `PETSCREGRESSORLINEAR`
39134b254c5SRichard Tran Mills @*/
PetscRegressorLinearGetIntercept(PetscRegressor regressor,PetscScalar * intercept)39234b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetIntercept(PetscRegressor regressor, PetscScalar *intercept)
39334b254c5SRichard Tran Mills {
39434b254c5SRichard Tran Mills   PetscFunctionBegin;
39534b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
39634b254c5SRichard Tran Mills   PetscAssertPointer(intercept, 2);
39734b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetIntercept_C", (PetscRegressor, PetscScalar *), (regressor, intercept));
39834b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
39934b254c5SRichard Tran Mills }
40034b254c5SRichard Tran Mills 
40134b254c5SRichard Tran Mills /*@C
40234b254c5SRichard Tran Mills   PetscRegressorLinearSetType - Sets the type of linear regression to be performed
40334b254c5SRichard Tran Mills 
40434b254c5SRichard Tran Mills   Logically Collective
40534b254c5SRichard Tran Mills 
40634b254c5SRichard Tran Mills   Input Parameters:
40734b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context (should be of type `PETSCREGRESSORLINEAR`)
40834b254c5SRichard Tran Mills - type      - a known linear regression method
40934b254c5SRichard Tran Mills 
41034b254c5SRichard Tran Mills   Options Database Key:
41134b254c5SRichard Tran Mills . -regressor_linear_type - Sets the linear regression method; use -help for a list of available methods
41234b254c5SRichard Tran Mills    (for instance "-regressor_linear_type ols" or "-regressor_linear_type lasso")
41334b254c5SRichard Tran Mills 
41434b254c5SRichard Tran Mills   Level: intermediate
41534b254c5SRichard Tran Mills 
416789736e1SBarry Smith .seealso: `PetscRegressorLinearGetType()`, `PetscRegressorLinearType`, `PetscRegressorSetType()`, `REGRESSOR_LINEAR_OLS`,
417789736e1SBarry Smith           `REGRESSOR_LINEAR_LASSO`, `REGRESSOR_LINEAR_RIDGE`
41834b254c5SRichard Tran Mills @*/
PetscRegressorLinearSetType(PetscRegressor regressor,PetscRegressorLinearType type)41934b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetType(PetscRegressor regressor, PetscRegressorLinearType type)
42034b254c5SRichard Tran Mills {
42134b254c5SRichard Tran Mills   PetscFunctionBegin;
42234b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
42334b254c5SRichard Tran Mills   PetscValidLogicalCollectiveEnum(regressor, type, 2);
42434b254c5SRichard Tran Mills   PetscTryMethod(regressor, "PetscRegressorLinearSetType_C", (PetscRegressor, PetscRegressorLinearType), (regressor, type));
42534b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
42634b254c5SRichard Tran Mills }
42734b254c5SRichard Tran Mills 
42834b254c5SRichard Tran Mills /*@
42934b254c5SRichard Tran Mills   PetscRegressorLinearGetType - Return the type for the `PETSCREGRESSORLINEAR` solver
43034b254c5SRichard Tran Mills 
43134b254c5SRichard Tran Mills   Input Parameter:
43234b254c5SRichard Tran Mills . regressor - the `PetscRegressor` solver context
43334b254c5SRichard Tran Mills 
43434b254c5SRichard Tran Mills   Output Parameter:
43534b254c5SRichard Tran Mills . type - `PETSCREGRESSORLINEAR` type
43634b254c5SRichard Tran Mills 
43734b254c5SRichard Tran Mills   Level: advanced
43834b254c5SRichard Tran Mills 
43934b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PETSCREGRESSORLINEAR`, `PetscRegressorLinearSetType()`, `PetscRegressorLinearType`
44034b254c5SRichard Tran Mills @*/
PetscRegressorLinearGetType(PetscRegressor regressor,PetscRegressorLinearType * type)44134b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearGetType(PetscRegressor regressor, PetscRegressorLinearType *type)
44234b254c5SRichard Tran Mills {
44334b254c5SRichard Tran Mills   PetscFunctionBegin;
44434b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
44534b254c5SRichard Tran Mills   PetscAssertPointer(type, 2);
44634b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetType_C", (PetscRegressor, PetscRegressorLinearType *), (regressor, type));
44734b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
44834b254c5SRichard Tran Mills }
44934b254c5SRichard Tran Mills 
PetscRegressorFit_Linear(PetscRegressor regressor)45034b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorFit_Linear(PetscRegressor regressor)
45134b254c5SRichard Tran Mills {
45234b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
45334b254c5SRichard Tran Mills   KSP                    ksp;
45434b254c5SRichard Tran Mills   PetscScalar            target_mean, *column_means_global, *column_means_local, column_means_dot_coefficients;
45534b254c5SRichard Tran Mills   Vec                    column_means;
45634b254c5SRichard Tran Mills   PetscInt               m, N, istart, i, kspits;
45734b254c5SRichard Tran Mills 
45834b254c5SRichard Tran Mills   PetscFunctionBegin;
45934b254c5SRichard Tran Mills   if (linear->use_ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
46034b254c5SRichard Tran Mills   ksp = linear->ksp;
46134b254c5SRichard Tran Mills 
46234b254c5SRichard Tran Mills   /* Solve the least-squares problem (previously set up in PetscRegressorSetUp_Linear()) without finding the intercept. */
46334b254c5SRichard Tran Mills   if (linear->use_ksp) {
46434b254c5SRichard Tran Mills     PetscCall(KSPSolve(ksp, linear->rhs, linear->coefficients));
46534b254c5SRichard Tran Mills     PetscCall(KSPGetIterationNumber(ksp, &kspits));
46634b254c5SRichard Tran Mills     linear->ksp_its += kspits;
46734b254c5SRichard Tran Mills     linear->ksp_tot_its += kspits;
46834b254c5SRichard Tran Mills   } else {
46934b254c5SRichard Tran Mills     PetscCall(TaoSolve(regressor->tao));
47034b254c5SRichard Tran Mills   }
47134b254c5SRichard Tran Mills 
47234b254c5SRichard Tran Mills   /* Calculate the intercept. */
47334b254c5SRichard Tran Mills   if (linear->fit_intercept) {
47434b254c5SRichard Tran Mills     PetscCall(MatGetSize(regressor->training, NULL, &N));
47534b254c5SRichard Tran Mills     PetscCall(PetscMalloc1(N, &column_means_global));
47634b254c5SRichard Tran Mills     PetscCall(VecMean(regressor->target, &target_mean));
47734b254c5SRichard Tran Mills     /* We need the means of all columns of regressor->training, placed into a Vec compatible with linear->coefficients.
478*8c5add6aSPierre Jolivet      * Note the potential scalability issue: MatGetColumnMeans() computes means of ALL columns. */
47934b254c5SRichard Tran Mills     PetscCall(MatGetColumnMeans(regressor->training, column_means_global));
48034b254c5SRichard Tran Mills     /* TODO: Calculation of the Vec and matrix column means should probably go into the SetUp phase, and also be placed
48134b254c5SRichard Tran Mills      *       into a routine that is callable from outside of PetscRegressorFit_Linear(), because we'll want to do the same
48234b254c5SRichard Tran Mills      *       thing for other models, such as ridge and LASSO regression, and should avoid code duplication.
48334b254c5SRichard Tran Mills      *       What we are calling 'target_mean' and 'column_means' should be stashed in the base linear regressor struct,
48434b254c5SRichard Tran Mills      *       and perhaps renamed to make it clear they are offsets that should be applied (though the current naming
48534b254c5SRichard Tran Mills      *       makes sense since it makes it clear where these come from.) */
48634b254c5SRichard Tran Mills     PetscCall(VecDuplicate(linear->coefficients, &column_means));
48734b254c5SRichard Tran Mills     PetscCall(VecGetLocalSize(column_means, &m));
48834b254c5SRichard Tran Mills     PetscCall(VecGetOwnershipRange(column_means, &istart, NULL));
48934b254c5SRichard Tran Mills     PetscCall(VecGetArrayWrite(column_means, &column_means_local));
49034b254c5SRichard Tran Mills     for (i = 0; i < m; i++) column_means_local[i] = column_means_global[istart + i];
49134b254c5SRichard Tran Mills     PetscCall(VecRestoreArrayWrite(column_means, &column_means_local));
49234b254c5SRichard Tran Mills     PetscCall(VecDot(column_means, linear->coefficients, &column_means_dot_coefficients));
49334b254c5SRichard Tran Mills     PetscCall(VecDestroy(&column_means));
49434b254c5SRichard Tran Mills     PetscCall(PetscFree(column_means_global));
49534b254c5SRichard Tran Mills     linear->intercept = target_mean - column_means_dot_coefficients;
49634b254c5SRichard Tran Mills   } else {
49734b254c5SRichard Tran Mills     linear->intercept = 0.0;
49834b254c5SRichard Tran Mills   }
49934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
50034b254c5SRichard Tran Mills }
50134b254c5SRichard Tran Mills 
PetscRegressorPredict_Linear(PetscRegressor regressor,Mat X,Vec y)50234b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorPredict_Linear(PetscRegressor regressor, Mat X, Vec y)
50334b254c5SRichard Tran Mills {
50434b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
50534b254c5SRichard Tran Mills 
50634b254c5SRichard Tran Mills   PetscFunctionBegin;
50734b254c5SRichard Tran Mills   PetscCall(MatMult(X, linear->coefficients, y));
50834b254c5SRichard Tran Mills   PetscCall(VecShift(y, linear->intercept));
50934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
51034b254c5SRichard Tran Mills }
51134b254c5SRichard Tran Mills 
51234b254c5SRichard Tran Mills /*MC
51334b254c5SRichard Tran Mills      PETSCREGRESSORLINEAR - Linear regression model (ordinary least squares or regularized variants)
51434b254c5SRichard Tran Mills 
   Options Database Keys:
+  -regressor_linear_type          - The linear regression method, one of `ols`, `lasso`, or `ridge` (see `PetscRegressorLinearSetType()`)
.  -regressor_linear_fit_intercept - Calculate the intercept for the linear model
-  -regressor_linear_use_ksp       - Use `KSP` instead of `Tao` for linear model fitting (non-regularized variants only)
51834b254c5SRichard Tran Mills 
51934b254c5SRichard Tran Mills    Level: beginner
52034b254c5SRichard Tran Mills 
5210664cd31SRichard Tran Mills    Notes:
5220664cd31SRichard Tran Mills    By "linear" we mean that the model is linear in its coefficients, but not necessarily in its input features.
5230664cd31SRichard Tran Mills    One can use the linear regressor to fit polynomial functions by training the model with a design matrix that
5240664cd31SRichard Tran Mills    is a nonlinear function of the input data.
5250664cd31SRichard Tran Mills 
52634b254c5SRichard Tran Mills    This is the default regressor in `PetscRegressor`.
52734b254c5SRichard Tran Mills 
52834b254c5SRichard Tran Mills .seealso: `PetscRegressorCreate()`, `PetscRegressor`, `PetscRegressorSetType()`
52934b254c5SRichard Tran Mills M*/
PetscRegressorCreate_Linear(PetscRegressor regressor)53034b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorCreate_Linear(PetscRegressor regressor)
53134b254c5SRichard Tran Mills {
53234b254c5SRichard Tran Mills   PetscRegressor_Linear *linear;
53334b254c5SRichard Tran Mills 
53434b254c5SRichard Tran Mills   PetscFunctionBegin;
53534b254c5SRichard Tran Mills   PetscCall(PetscNew(&linear));
53634b254c5SRichard Tran Mills   regressor->data = (void *)linear;
53734b254c5SRichard Tran Mills 
53834b254c5SRichard Tran Mills   regressor->ops->setup          = PetscRegressorSetUp_Linear;
53934b254c5SRichard Tran Mills   regressor->ops->reset          = PetscRegressorReset_Linear;
54034b254c5SRichard Tran Mills   regressor->ops->destroy        = PetscRegressorDestroy_Linear;
54134b254c5SRichard Tran Mills   regressor->ops->setfromoptions = PetscRegressorSetFromOptions_Linear;
54234b254c5SRichard Tran Mills   regressor->ops->view           = PetscRegressorView_Linear;
54334b254c5SRichard Tran Mills   regressor->ops->fit            = PetscRegressorFit_Linear;
54434b254c5SRichard Tran Mills   regressor->ops->predict        = PetscRegressorPredict_Linear;
54534b254c5SRichard Tran Mills 
54634b254c5SRichard Tran Mills   linear->intercept     = 0.0;
54734b254c5SRichard Tran Mills   linear->fit_intercept = PETSC_TRUE;  /* Default to calculating the intercept. */
54834b254c5SRichard Tran Mills   linear->use_ksp       = PETSC_FALSE; /* Do not default to using KSP for solving the model-fitting problem (use TAO instead). */
54934b254c5SRichard Tran Mills   linear->type          = REGRESSOR_LINEAR_OLS;
55034b254c5SRichard Tran Mills   /* Above, manually set the default linear regressor type.
55134b254c5SRichard Tran Mills        We don't use PetscRegressorLinearSetType() here, because that expects the SetUp event to already have happened. */
55234b254c5SRichard Tran Mills 
55334b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", PetscRegressorLinearSetFitIntercept_Linear));
55434b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", PetscRegressorLinearSetUseKSP_Linear));
55534b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", PetscRegressorLinearGetKSP_Linear));
55634b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", PetscRegressorLinearGetCoefficients_Linear));
55734b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", PetscRegressorLinearGetIntercept_Linear));
55834b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", PetscRegressorLinearSetType_Linear));
55934b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", PetscRegressorLinearGetType_Linear));
56034b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
56134b254c5SRichard Tran Mills }
562