1 #include <../src/ml/regressor/impls/linear/linearimpl.h> /*I "petscregressor.h" I*/ 2 #include <../src/tao/leastsquares/impls/brgn/brgn.h> /*I "petsctao.h" I*/ 3 4 const char *const PetscRegressorLinearTypes[] = {"ols", "lasso", "ridge", "RegressorLinearType", "REGRESSOR_LINEAR_", NULL}; 5 6 static PetscErrorCode PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor, PetscBool flg) 7 { 8 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 9 10 PetscFunctionBegin; 11 linear->fit_intercept = flg; 12 PetscFunctionReturn(PETSC_SUCCESS); 13 } 14 15 static PetscErrorCode PetscRegressorLinearSetType_Linear(PetscRegressor regressor, PetscRegressorLinearType type) 16 { 17 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 18 19 PetscFunctionBegin; 20 linear->type = type; 21 PetscFunctionReturn(PETSC_SUCCESS); 22 } 23 24 static PetscErrorCode PetscRegressorLinearGetType_Linear(PetscRegressor regressor, PetscRegressorLinearType *type) 25 { 26 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 27 28 PetscFunctionBegin; 29 *type = linear->type; 30 PetscFunctionReturn(PETSC_SUCCESS); 31 } 32 33 static PetscErrorCode PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor, PetscScalar *intercept) 34 { 35 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 36 37 PetscFunctionBegin; 38 *intercept = linear->intercept; 39 PetscFunctionReturn(PETSC_SUCCESS); 40 } 41 42 static PetscErrorCode PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor, Vec *coefficients) 43 { 44 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 45 46 PetscFunctionBegin; 47 *coefficients = linear->coefficients; 48 PetscFunctionReturn(PETSC_SUCCESS); 49 } 50 51 static PetscErrorCode PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor, KSP *ksp) 52 { 53 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 54 55 PetscFunctionBegin; 56 if (!linear->ksp) { 57 PetscCall(KSPCreate(PetscObjectComm((PetscObject)regressor), &linear->ksp)); 58 PetscCall(PetscObjectIncrementTabLevel((PetscObject)linear->ksp, (PetscObject)regressor, 1)); 59 PetscCall(PetscObjectSetOptions((PetscObject)linear->ksp, ((PetscObject)regressor)->options)); 60 } 61 *ksp = linear->ksp; 62 PetscFunctionReturn(PETSC_SUCCESS); 63 } 64 65 static PetscErrorCode PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor, PetscBool flg) 66 { 67 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 68 69 PetscFunctionBegin; 70 linear->use_ksp = flg; 71 PetscFunctionReturn(PETSC_SUCCESS); 72 } 73 74 static PetscErrorCode EvaluateResidual(Tao tao, Vec x, Vec f, void *ptr) 75 { 76 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr; 77 78 PetscFunctionBegin; 79 /* Evaluate f = A * x - b */ 80 PetscCall(MatMult(linear->X, x, f)); 81 PetscCall(VecAXPY(f, -1.0, linear->rhs)); 82 PetscFunctionReturn(PETSC_SUCCESS); 83 } 84 85 static PetscErrorCode EvaluateJacobian(Tao tao, Vec x, Mat J, Mat Jpre, void *ptr) 86 { 87 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr; 88 89 PetscFunctionBegin; 90 J = linear->X; 91 Jpre = linear->X; 92 PetscFunctionReturn(PETSC_SUCCESS); 93 } 94 95 static PetscErrorCode PetscRegressorSetUp_Linear(PetscRegressor regressor) 96 { 97 PetscInt M, N; 98 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 99 KSP ksp; 100 Tao tao; 101 102 PetscFunctionBegin; 103 PetscCall(MatGetSize(regressor->training, &M, &N)); 104 105 if (linear->fit_intercept) { 106 /* If we are fitting the intercept, we need to make A a composite matrix using MATCENTERING to preserve sparsity. 107 * Though there might be some cases we don't want to do this for, depending on what kind of matrix is passed in. (Probably bad idea for dense?) 108 * We will also need to ensure that the right-hand side passed to the KSP is also mean-centered, since we 109 * intend to compute the intercept separately from regression coefficients (that is, we will not be adding a 110 * column of all 1s to our design matrix). */ 111 PetscCall(MatCreateCentering(PetscObjectComm((PetscObject)regressor), PETSC_DECIDE, M, &linear->C)); 112 PetscCall(MatCreate(PetscObjectComm((PetscObject)regressor), &linear->X)); 113 PetscCall(MatSetSizes(linear->X, PETSC_DECIDE, PETSC_DECIDE, M, N)); 114 PetscCall(MatSetType(linear->X, MATCOMPOSITE)); 115 PetscCall(MatCompositeSetType(linear->X, MAT_COMPOSITE_MULTIPLICATIVE)); 116 PetscCall(MatCompositeAddMat(linear->X, regressor->training)); 117 PetscCall(MatCompositeAddMat(linear->X, linear->C)); 118 PetscCall(VecDuplicate(regressor->target, &linear->rhs)); 119 PetscCall(MatMult(linear->C, regressor->target, linear->rhs)); 120 } else { 121 // When not fitting intercept, we assume that the input data are already centered. 122 linear->X = regressor->training; 123 linear->rhs = regressor->target; 124 125 PetscCall(PetscObjectReference((PetscObject)linear->X)); 126 PetscCall(PetscObjectReference((PetscObject)linear->rhs)); 127 } 128 129 if (linear->coefficients) PetscCall(VecDestroy(&linear->coefficients)); 130 131 if (linear->use_ksp) { 132 PetscCheck(linear->type == REGRESSOR_LINEAR_OLS, PetscObjectComm((PetscObject)regressor), PETSC_ERR_ARG_WRONGSTATE, "KSP can be used to fit a linear regressor only when its type is OLS"); 133 134 if (!linear->ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp)); 135 ksp = linear->ksp; 136 137 PetscCall(MatCreateVecs(linear->X, &linear->coefficients, NULL)); 138 /* Set up the KSP to solve the least squares problem (without solving for intercept, as this is done separately) using KSPLSQR. */ 139 PetscCall(MatCreateNormal(linear->X, &linear->XtX)); 140 PetscCall(KSPSetType(ksp, KSPLSQR)); 141 PetscCall(KSPSetOperators(ksp, linear->X, linear->XtX)); 142 PetscCall(KSPSetOptionsPrefix(ksp, ((PetscObject)regressor)->prefix)); 143 PetscCall(KSPAppendOptionsPrefix(ksp, "regressor_linear_")); 144 PetscCall(KSPSetFromOptions(ksp)); 145 } else { 146 /* Note: Currently implementation creates TAO inside of implementations. 147 * Thus, all the prefix jobs are done inside implementations, not in interface */ 148 const char *prefix; 149 150 if (!regressor->tao) PetscCall(PetscRegressorGetTao(regressor, &tao)); 151 152 PetscCall(MatCreateVecs(linear->X, &linear->coefficients, &linear->residual)); 153 /* Set up the TAO object to solve the (regularized) least squares problem (without solving for intercept, which is done separately) using TAOBRGN. */ 154 PetscCall(TaoSetType(tao, TAOBRGN)); 155 PetscCall(TaoSetSolution(tao, linear->coefficients)); 156 PetscCall(TaoSetResidualRoutine(tao, linear->residual, EvaluateResidual, linear)); 157 PetscCall(TaoSetJacobianResidualRoutine(tao, linear->X, linear->X, EvaluateJacobian, linear)); 158 if (!linear->use_ksp) PetscCall(TaoBRGNSetRegularizerWeight(tao, regressor->regularizer_weight)); 159 // Set the regularization type and weight for the BRGN as linear->type dictates: 160 // TODO BRGN needs to be BRGNSetRegularizationType 161 // PetscOptionsSetValue no longer works due to functioning prefix system 162 PetscCall(PetscRegressorGetOptionsPrefix(regressor, &prefix)); 163 PetscCall(TaoSetOptionsPrefix(regressor->tao, prefix)); 164 PetscCall(TaoAppendOptionsPrefix(tao, "regressor_linear_")); 165 { 166 TAO_BRGN *gn = (TAO_BRGN *)regressor->tao->data; 167 168 switch (linear->type) { 169 case REGRESSOR_LINEAR_OLS: 170 regressor->regularizer_weight = 0.0; // OLS, by definition, uses a regularizer weight of 0 171 break; 172 case REGRESSOR_LINEAR_LASSO: 173 gn->reg_type = BRGN_REGULARIZATION_L1DICT; 174 break; 175 case REGRESSOR_LINEAR_RIDGE: 176 gn->reg_type = BRGN_REGULARIZATION_L2PURE; 177 break; 178 default: 179 break; 180 } 181 } 182 PetscCall(TaoSetFromOptions(tao)); 183 } 184 PetscFunctionReturn(PETSC_SUCCESS); 185 } 186 187 static PetscErrorCode PetscRegressorReset_Linear(PetscRegressor regressor) 188 { 189 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 190 191 PetscFunctionBegin; 192 /* Destroy the PETSc objects associated with the linear regressor implementation. */ 193 linear->ksp_its = 0; 194 linear->ksp_tot_its = 0; 195 196 PetscCall(MatDestroy(&linear->X)); 197 PetscCall(MatDestroy(&linear->XtX)); 198 PetscCall(MatDestroy(&linear->C)); 199 PetscCall(KSPDestroy(&linear->ksp)); 200 PetscCall(VecDestroy(&linear->coefficients)); 201 PetscCall(VecDestroy(&linear->rhs)); 202 PetscCall(VecDestroy(&linear->residual)); 203 PetscFunctionReturn(PETSC_SUCCESS); 204 } 205 206 static PetscErrorCode PetscRegressorDestroy_Linear(PetscRegressor regressor) 207 { 208 PetscFunctionBegin; 209 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", NULL)); 210 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", NULL)); 211 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", NULL)); 212 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", NULL)); 213 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", NULL)); 214 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", NULL)); 215 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", NULL)); 216 PetscCall(PetscRegressorReset_Linear(regressor)); 217 PetscCall(PetscFree(regressor->data)); 218 PetscFunctionReturn(PETSC_SUCCESS); 219 } 220 221 /*@ 222 PetscRegressorLinearSetFitIntercept - Set a flag to indicate that the intercept (also known as the "bias" or "offset") should 223 be calculated; data are assumed to be mean-centered if false. 224 225 Logically Collective 226 227 Input Parameters: 228 + regressor - the `PetscRegressor` context 229 - flg - `PETSC_TRUE` to calculate the intercept, `PETSC_FALSE` to assume mean-centered data (default is `PETSC_TRUE`) 230 231 Level: intermediate 232 233 Options Database Key: 234 . regressor_linear_fit_intercept <true,false> - fit the intercept 235 236 Note: 237 If the user indicates that the intercept should not be calculated, the intercept will be set to zero. 238 239 .seealso: `PetscRegressor`, `PetscRegressorFit()` 240 @*/ 241 PetscErrorCode PetscRegressorLinearSetFitIntercept(PetscRegressor regressor, PetscBool flg) 242 { 243 PetscFunctionBegin; 244 /* TODO: Add companion PetscRegressorLinearGetFitIntercept(), and put it in the .seealso: */ 245 PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 246 PetscValidLogicalCollectiveBool(regressor, flg, 2); 247 PetscTryMethod(regressor, "PetscRegressorLinearSetFitIntercept_C", (PetscRegressor, PetscBool), (regressor, flg)); 248 PetscFunctionReturn(PETSC_SUCCESS); 249 } 250 251 /*@ 252 PetscRegressorLinearSetUseKSP - Set a flag to indicate that a `KSP` object, instead of a `Tao` one, should be used 253 to fit the regressor 254 255 Logically Collective 256 257 Input Parameters: 258 + regressor - the `PetscRegressor` context 259 - flg - `PETSC_TRUE` to use a `KSP`, `PETSC_FALSE` to use a `Tao` object (default is false) 260 261 Options Database Key: 262 . regressor_linear_use_ksp <true,false> - use `KSP` 263 264 Level: intermediate 265 266 .seealso: `PetscRegressor`, `PetscRegressorLinearGetKSP()` 267 @*/ 268 PetscErrorCode PetscRegressorLinearSetUseKSP(PetscRegressor regressor, PetscBool flg) 269 { 270 PetscFunctionBegin; 271 /* TODO: Add companion PetscRegressorLinearGetUseKSP(), and put it in the .seealso: */ 272 PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 273 PetscValidLogicalCollectiveBool(regressor, flg, 2); 274 PetscTryMethod(regressor, "PetscRegressorLinearSetUseKSP_C", (PetscRegressor, PetscBool), (regressor, flg)); 275 PetscFunctionReturn(PETSC_SUCCESS); 276 } 277 278 static PetscErrorCode PetscRegressorSetFromOptions_Linear(PetscRegressor regressor, PetscOptionItems PetscOptionsObject) 279 { 280 PetscBool set, flg = PETSC_FALSE; 281 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 282 283 PetscFunctionBegin; 284 PetscOptionsHeadBegin(PetscOptionsObject, "PetscRegressor options for linear regressors"); 285 PetscCall(PetscOptionsBool("-regressor_linear_fit_intercept", "Calculate intercept for linear model", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set)); 286 if (set) PetscCall(PetscRegressorLinearSetFitIntercept(regressor, flg)); 287 PetscCall(PetscOptionsBool("-regressor_linear_use_ksp", "Use KSP instead of TAO for linear model fitting problem", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set)); 288 if (set) PetscCall(PetscRegressorLinearSetUseKSP(regressor, flg)); 289 PetscCall(PetscOptionsEnum("-regressor_linear_type", "Linear regression method", "PetscRegressorLinearTypes", PetscRegressorLinearTypes, (PetscEnum)linear->type, (PetscEnum *)&linear->type, NULL)); 290 PetscOptionsHeadEnd(); 291 PetscFunctionReturn(PETSC_SUCCESS); 292 } 293 294 static PetscErrorCode PetscRegressorView_Linear(PetscRegressor regressor, PetscViewer viewer) 295 { 296 PetscBool isascii; 297 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 298 299 PetscFunctionBegin; 300 PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii)); 301 if (isascii) { 302 PetscCall(PetscViewerASCIIPushTab(viewer)); 303 PetscCall(PetscViewerASCIIPrintf(viewer, "PetscRegressor Linear Type: %s\n", PetscRegressorLinearTypes[linear->type])); 304 if (linear->ksp) { 305 PetscCall(KSPView(linear->ksp, viewer)); 306 PetscCall(PetscViewerASCIIPrintf(viewer, "total KSP iterations: %" PetscInt_FMT "\n", linear->ksp_tot_its)); 307 } 308 if (linear->fit_intercept) PetscCall(PetscViewerASCIIPrintf(viewer, "Intercept=%g\n", (double)linear->intercept)); 309 PetscCall(PetscViewerASCIIPopTab(viewer)); 310 } 311 PetscFunctionReturn(PETSC_SUCCESS); 312 } 313 314 /*@ 315 PetscRegressorLinearGetKSP - Returns the `KSP` context for a `PETSCREGRESSORLINEAR` object. 316 317 Not Collective, but if the `PetscRegressor` is parallel, then the `KSP` object is parallel 318 319 Input Parameter: 320 . regressor - the `PetscRegressor` context 321 322 Output Parameter: 323 . ksp - the `KSP` context 324 325 Level: beginner 326 327 Note: 328 This routine will always return a `KSP`, but, depending on the type of the linear regressor and the options that are set, the regressor may actually use a `Tao` object instead of this `KSP`. 329 330 .seealso: `PetscRegressorGetTao()` 331 @*/ 332 PetscErrorCode PetscRegressorLinearGetKSP(PetscRegressor regressor, KSP *ksp) 333 { 334 PetscFunctionBegin; 335 PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 336 PetscAssertPointer(ksp, 2); 337 PetscUseMethod(regressor, "PetscRegressorLinearGetKSP_C", (PetscRegressor, KSP *), (regressor, ksp)); 338 PetscFunctionReturn(PETSC_SUCCESS); 339 } 340 341 /*@ 342 PetscRegressorLinearGetCoefficients - Get a vector of the fitted coefficients from a linear regression model 343 344 Not Collective but the vector is parallel 345 346 Input Parameter: 347 . regressor - the `PetscRegressor` context 348 349 Output Parameter: 350 . coefficients - the vector of the coefficients 351 352 Level: beginner 353 354 .seealso: `PetscRegressor`, `PetscRegressorLinearGetIntercept()`, `PETSCREGRESSORLINEAR`, `Vec` 355 @*/ 356 PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetCoefficients(PetscRegressor regressor, Vec *coefficients) 357 { 358 PetscFunctionBegin; 359 PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 360 PetscAssertPointer(coefficients, 2); 361 PetscUseMethod(regressor, "PetscRegressorLinearGetCoefficients_C", (PetscRegressor, Vec *), (regressor, coefficients)); 362 PetscFunctionReturn(PETSC_SUCCESS); 363 } 364 365 /*@ 366 PetscRegressorLinearGetIntercept - Get the intercept from a linear regression model 367 368 Not Collective 369 370 Input Parameter: 371 . regressor - the `PetscRegressor` context 372 373 Output Parameter: 374 . intercept - the intercept 375 376 Level: beginner 377 378 .seealso: `PetscRegressor`, `PetscRegressorLinearSetFitIntercept()`, `PetscRegressorLinearGetCoefficients()`, `PETSCREGRESSORLINEAR` 379 @*/ 380 PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetIntercept(PetscRegressor regressor, PetscScalar *intercept) 381 { 382 PetscFunctionBegin; 383 PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 384 PetscAssertPointer(intercept, 2); 385 PetscUseMethod(regressor, "PetscRegressorLinearGetIntercept_C", (PetscRegressor, PetscScalar *), (regressor, intercept)); 386 PetscFunctionReturn(PETSC_SUCCESS); 387 } 388 389 /*@C 390 PetscRegressorLinearSetType - Sets the type of linear regression to be performed 391 392 Logically Collective 393 394 Input Parameters: 395 + regressor - the `PetscRegressor` context (should be of type `PETSCREGRESSORLINEAR`) 396 - type - a known linear regression method 397 398 Options Database Key: 399 . -regressor_linear_type - Sets the linear regression method; use -help for a list of available methods 400 (for instance "-regressor_linear_type ols" or "-regressor_linear_type lasso") 401 402 Level: intermediate 403 404 .seealso: `PetscRegressorLinearGetType()`, `PetscRegressorLinearType`, `PetscRegressorSetType()` 405 @*/ 406 PetscErrorCode PetscRegressorLinearSetType(PetscRegressor regressor, PetscRegressorLinearType type) 407 { 408 PetscFunctionBegin; 409 PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 410 PetscValidLogicalCollectiveEnum(regressor, type, 2); 411 PetscTryMethod(regressor, "PetscRegressorLinearSetType_C", (PetscRegressor, PetscRegressorLinearType), (regressor, type)); 412 PetscFunctionReturn(PETSC_SUCCESS); 413 } 414 415 /*@ 416 PetscRegressorLinearGetType - Return the type for the `PETSCREGRESSORLINEAR` solver 417 418 Input Parameter: 419 . regressor - the `PetscRegressor` solver context 420 421 Output Parameter: 422 . type - `PETSCREGRESSORLINEAR` type 423 424 Level: advanced 425 426 .seealso: `PetscRegressor`, `PETSCREGRESSORLINEAR`, `PetscRegressorLinearSetType()`, `PetscRegressorLinearType` 427 @*/ 428 PetscErrorCode PetscRegressorLinearGetType(PetscRegressor regressor, PetscRegressorLinearType *type) 429 { 430 PetscFunctionBegin; 431 PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 432 PetscAssertPointer(type, 2); 433 PetscUseMethod(regressor, "PetscRegressorLinearGetType_C", (PetscRegressor, PetscRegressorLinearType *), (regressor, type)); 434 PetscFunctionReturn(PETSC_SUCCESS); 435 } 436 437 static PetscErrorCode PetscRegressorFit_Linear(PetscRegressor regressor) 438 { 439 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 440 KSP ksp; 441 PetscScalar target_mean, *column_means_global, *column_means_local, column_means_dot_coefficients; 442 Vec column_means; 443 PetscInt m, N, istart, i, kspits; 444 445 PetscFunctionBegin; 446 if (linear->use_ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp)); 447 ksp = linear->ksp; 448 449 /* Solve the least-squares problem (previously set up in PetscRegressorSetUp_Linear()) without finding the intercept. */ 450 if (linear->use_ksp) { 451 PetscCall(KSPSolve(ksp, linear->rhs, linear->coefficients)); 452 PetscCall(KSPGetIterationNumber(ksp, &kspits)); 453 linear->ksp_its += kspits; 454 linear->ksp_tot_its += kspits; 455 } else { 456 PetscCall(TaoSolve(regressor->tao)); 457 } 458 459 /* Calculate the intercept. */ 460 if (linear->fit_intercept) { 461 PetscCall(MatGetSize(regressor->training, NULL, &N)); 462 PetscCall(PetscMalloc1(N, &column_means_global)); 463 PetscCall(VecMean(regressor->target, &target_mean)); 464 /* We need the means of all columns of regressor->training, placed into a Vec compatible with linear->coefficients. 465 * Note the potential scalability issue: MatGetColumnMeans() computes means of ALL colummns. */ 466 PetscCall(MatGetColumnMeans(regressor->training, column_means_global)); 467 /* TODO: Calculation of the Vec and matrix column means should probably go into the SetUp phase, and also be placed 468 * into a routine that is callable from outside of PetscRegressorFit_Linear(), because we'll want to do the same 469 * thing for other models, such as ridge and LASSO regression, and should avoid code duplication. 470 * What we are calling 'target_mean' and 'column_means' should be stashed in the base linear regressor struct, 471 * and perhaps renamed to make it clear they are offsets that should be applied (though the current naming 472 * makes sense since it makes it clear where these come from.) */ 473 PetscCall(VecDuplicate(linear->coefficients, &column_means)); 474 PetscCall(VecGetLocalSize(column_means, &m)); 475 PetscCall(VecGetOwnershipRange(column_means, &istart, NULL)); 476 PetscCall(VecGetArrayWrite(column_means, &column_means_local)); 477 for (i = 0; i < m; i++) column_means_local[i] = column_means_global[istart + i]; 478 PetscCall(VecRestoreArrayWrite(column_means, &column_means_local)); 479 PetscCall(VecDot(column_means, linear->coefficients, &column_means_dot_coefficients)); 480 PetscCall(VecDestroy(&column_means)); 481 PetscCall(PetscFree(column_means_global)); 482 linear->intercept = target_mean - column_means_dot_coefficients; 483 } else { 484 linear->intercept = 0.0; 485 } 486 PetscFunctionReturn(PETSC_SUCCESS); 487 } 488 489 static PetscErrorCode PetscRegressorPredict_Linear(PetscRegressor regressor, Mat X, Vec y) 490 { 491 PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 492 493 PetscFunctionBegin; 494 PetscCall(MatMult(X, linear->coefficients, y)); 495 PetscCall(VecShift(y, linear->intercept)); 496 PetscFunctionReturn(PETSC_SUCCESS); 497 } 498 499 /*MC 500 PETSCREGRESSORLINEAR - Linear regression model (ordinary least squares or regularized variants) 501 502 Options Database: 503 + -regressor_linear_fit_intercept - Calculate the intercept for the linear model 504 - -regressor_linear_use_ksp - Use `KSP` instead of `Tao` for linear model fitting (non-regularized variants only) 505 506 Level: beginner 507 508 Note: 509 This is the default regressor in `PetscRegressor`. 510 511 .seealso: `PetscRegressorCreate()`, `PetscRegressor`, `PetscRegressorSetType()` 512 M*/ 513 PETSC_EXTERN PetscErrorCode PetscRegressorCreate_Linear(PetscRegressor regressor) 514 { 515 PetscRegressor_Linear *linear; 516 517 PetscFunctionBegin; 518 PetscCall(PetscNew(&linear)); 519 regressor->data = (void *)linear; 520 521 regressor->ops->setup = PetscRegressorSetUp_Linear; 522 regressor->ops->reset = PetscRegressorReset_Linear; 523 regressor->ops->destroy = PetscRegressorDestroy_Linear; 524 regressor->ops->setfromoptions = PetscRegressorSetFromOptions_Linear; 525 regressor->ops->view = PetscRegressorView_Linear; 526 regressor->ops->fit = PetscRegressorFit_Linear; 527 regressor->ops->predict = PetscRegressorPredict_Linear; 528 529 linear->intercept = 0.0; 530 linear->fit_intercept = PETSC_TRUE; /* Default to calculating the intercept. */ 531 linear->use_ksp = PETSC_FALSE; /* Do not default to using KSP for solving the model-fitting problem (use TAO instead). */ 532 linear->type = REGRESSOR_LINEAR_OLS; 533 /* Above, manually set the default linear regressor type. 534 We don't use PetscRegressorLinearSetType() here, because that expects the SetUp event to already have happened. */ 535 536 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", PetscRegressorLinearSetFitIntercept_Linear)); 537 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", PetscRegressorLinearSetUseKSP_Linear)); 538 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", PetscRegressorLinearGetKSP_Linear)); 539 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", PetscRegressorLinearGetCoefficients_Linear)); 540 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", PetscRegressorLinearGetIntercept_Linear)); 541 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", PetscRegressorLinearSetType_Linear)); 542 PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", PetscRegressorLinearGetType_Linear)); 543 PetscFunctionReturn(PETSC_SUCCESS); 544 } 545