xref: /petsc/src/ml/regressor/impls/linear/linearimpl.h (revision c12c126234ed623246a63bfa78c9f75a3aa00323)
1*34b254c5SRichard Tran Mills #pragma once
2*34b254c5SRichard Tran Mills #include <petsc/private/regressorimpl.h>
3*34b254c5SRichard Tran Mills #include <petscksp.h>
4*34b254c5SRichard Tran Mills #include <petsctao.h>
5*34b254c5SRichard Tran Mills 
6*34b254c5SRichard Tran Mills /* We define this header, since it serves as a "base" for all linear models.
   It expands to a list of struct member declarations (model type, fitted parameters,
   the operator/right-hand side/residual of the linear problem, and fitting options).
   Expand it inside a struct definition and follow it with a semicolon: the last
   declaration in the macro has no trailing ';', so the expansion site must supply it. */
7*34b254c5SRichard Tran Mills #define REGRESSOR_LINEAR_HEADER \
8*34b254c5SRichard Tran Mills   PetscRegressorLinearType type; \
9*34b254c5SRichard Tran Mills   /* Parameters of the fitted regression model */ \
10*34b254c5SRichard Tran Mills   Vec         coefficients; \
11*34b254c5SRichard Tran Mills   PetscScalar intercept; \
12*34b254c5SRichard Tran Mills \
13*34b254c5SRichard Tran Mills   Mat X;        /* Operator of the linear model; often the training data matrix, but might be a MATCOMPOSITE */ \
14*34b254c5SRichard Tran Mills   Mat C;        /* Centering matrix */ \
15*34b254c5SRichard Tran Mills   Vec rhs;      /* Right-hand side of the linear model; often the target vector, but may be the mean-centered version */ \
16*34b254c5SRichard Tran Mills   Vec residual; /* Residual for our model, or the loss vector */ \
17*34b254c5SRichard Tran Mills   /* Various options */ \
18*34b254c5SRichard Tran Mills   PetscBool fit_intercept; /* Calculate intercept ("bias" or "offset") if true. Assume centered data if false. */ \
19*34b254c5SRichard Tran Mills   PetscBool use_ksp        /* Use KSP for the model-fitting problem; otherwise we will use TAO. */
20*34b254c5SRichard Tran Mills 
/* Implementation data for the linear regressor: the shared linear-model fields
   from REGRESSOR_LINEAR_HEADER, followed by state for a KSP-based fit. */
21*34b254c5SRichard Tran Mills typedef struct {
22*34b254c5SRichard Tran Mills   REGRESSOR_LINEAR_HEADER; /* Shared linear-model fields; the macro leaves the trailing ';' to the expansion site */
23*34b254c5SRichard Tran Mills 
24*34b254c5SRichard Tran Mills   PetscInt ksp_its, ksp_tot_its; /* NOTE(review): presumably KSP iterations of the latest solve and the accumulated total -- confirm against the implementation */
25*34b254c5SRichard Tran Mills   KSP      ksp; /* KSP object; presumably the solver used when use_ksp is true -- confirm against the implementation */
26*34b254c5SRichard Tran Mills   Mat      XtX; /* Normal matrix formed from X */
27*34b254c5SRichard Tran Mills } PetscRegressor_Linear;
28