xref: /petsc/src/ml/regressor/tests/ex_sharks.c (revision c12c126234ed623246a63bfa78c9f75a3aa00323)
/* Example inspired by the toy example in the blog post by Dr. Atakan Ekiz at
 * https://www.r-bloggers.com/2020/06/understanding-lasso-and-ridge-regression-2/
 * Here we wish to predict the number of shark attacks (that is, this number is our response variable),
 * using the following predictor variables:
 * - percentage of swimmers who watched the movie Jaws
 * - the number of swimmers in the water
 * - the average temperature of the day
 * - the price of your favorite tech stock of the day (totally uncorrelated variable) */
9*34b254c5SRichard Tran Mills 
/* Help text handed to PetscInitialize() in main(). It describes what this
 * program does: fit a linear regressor to the shark-attack data and print the
 * fitted model and its predictions. */
static char help[] = "Fits a linear PetscRegressor to a small shark-attack data set and prints the intercept, coefficients, and predicted values.\n\n";
11*34b254c5SRichard Tran Mills 
12*34b254c5SRichard Tran Mills #include <petscregressor.h>
13*34b254c5SRichard Tran Mills 
main(int argc,char ** args)14*34b254c5SRichard Tran Mills int main(int argc, char **args)
15*34b254c5SRichard Tran Mills {
16*34b254c5SRichard Tran Mills   PetscRegressor regressor;
17*34b254c5SRichard Tran Mills   PetscMPIInt    rank;
18*34b254c5SRichard Tran Mills   Mat            X;
19*34b254c5SRichard Tran Mills   Vec            y, y_predicted, coefficients;
20*34b254c5SRichard Tran Mills   PetscScalar    intercept;
21*34b254c5SRichard Tran Mills 
22*34b254c5SRichard Tran Mills   PetscScalar y_array[20] = {98, 53, 39, 127, 73, 42, 71, 61, 83, 74, 85, 82, 62, 60, 43, 69, 67, 69, 85, 3}; // Number of shark attacks
23*34b254c5SRichard Tran Mills 
24*34b254c5SRichard Tran Mills   PetscScalar X_array[80] = {37.92934, 513, 92.89899, 137.2139, // % watched Jaws, #swimmers, temperature, stock price
25*34b254c5SRichard Tran Mills                              52.77429, 451, 87.86271, 145.7987, //
26*34b254c5SRichard Tran Mills                              60.84441, 456, 88.28927, 149.7299, //
27*34b254c5SRichard Tran Mills                              26.54302, 546, 89.43875, 147.1180, //
28*34b254c5SRichard Tran Mills                              54.29125, 431, 88.01132, 124.3068, //
29*34b254c5SRichard Tran Mills                              55.06056, 355, 88.06297, 114.1730, //
30*34b254c5SRichard Tran Mills                              44.25260, 557, 87.78536, 112.5773, //
31*34b254c5SRichard Tran Mills                              44.53368, 398, 87.49603, 125.1628, //
32*34b254c5SRichard Tran Mills                              44.35548, 498, 88.95234, 124.8483, //
33*34b254c5SRichard Tran Mills                              41.09962, 406, 89.00630, 115.9223, //
34*34b254c5SRichard Tran Mills                              45.22807, 610, 86.38794, 148.1111, //
35*34b254c5SRichard Tran Mills                              40.01614, 452, 88.83585, 131.7050, //
36*34b254c5SRichard Tran Mills                              42.23746, 429, 87.78222, 106.3717, //
37*34b254c5SRichard Tran Mills                              50.64459, 450, 87.97008, 121.1523, //
38*34b254c5SRichard Tran Mills                              59.59494, 337, 89.67538, 145.7158, //
39*34b254c5SRichard Tran Mills                              48.89715, 383, 91.12611, 123.3896, //
40*34b254c5SRichard Tran Mills                              44.88990, 282, 93.29563, 145.4085, //
41*34b254c5SRichard Tran Mills                              40.88805, 366, 88.45329, 129.8872, //
42*34b254c5SRichard Tran Mills                              41.62828, 471, 93.21182, 131.5871, //
43*34b254c5SRichard Tran Mills                              74.15835, 453, 87.68438, 143.4579};
44*34b254c5SRichard Tran Mills 
45*34b254c5SRichard Tran Mills   PetscInt rows_ix[20] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19};
46*34b254c5SRichard Tran Mills   PetscInt cols_ix[4]  = {0, 1, 2, 3};
47*34b254c5SRichard Tran Mills 
48*34b254c5SRichard Tran Mills   PetscCall(PetscInitialize(&argc, &args, (char *)0, help));
49*34b254c5SRichard Tran Mills   PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank));
50*34b254c5SRichard Tran Mills 
51*34b254c5SRichard Tran Mills   PetscCall(VecCreate(PETSC_COMM_WORLD, &y));
52*34b254c5SRichard Tran Mills   PetscCall(VecSetSizes(y, PETSC_DECIDE, 20));
53*34b254c5SRichard Tran Mills   PetscCall(VecSetFromOptions(y));
54*34b254c5SRichard Tran Mills   PetscCall(VecDuplicate(y, &y_predicted));
55*34b254c5SRichard Tran Mills   PetscCall(MatCreate(PETSC_COMM_WORLD, &X));
56*34b254c5SRichard Tran Mills   PetscCall(MatSetSizes(X, PETSC_DECIDE, PETSC_DECIDE, 20, 4));
57*34b254c5SRichard Tran Mills   PetscCall(MatSetFromOptions(X));
58*34b254c5SRichard Tran Mills   PetscCall(MatSetUp(X));
59*34b254c5SRichard Tran Mills 
60*34b254c5SRichard Tran Mills   if (!rank) {
61*34b254c5SRichard Tran Mills     PetscCall(VecSetValues(y, 20, rows_ix, y_array, INSERT_VALUES));
62*34b254c5SRichard Tran Mills     PetscCall(MatSetValues(X, 20, rows_ix, 4, cols_ix, X_array, ADD_VALUES));
63*34b254c5SRichard Tran Mills   }
64*34b254c5SRichard Tran Mills   PetscCall(VecAssemblyBegin(y));
65*34b254c5SRichard Tran Mills   PetscCall(VecAssemblyEnd(y));
66*34b254c5SRichard Tran Mills   PetscCall(MatAssemblyBegin(X, MAT_FINAL_ASSEMBLY));
67*34b254c5SRichard Tran Mills   PetscCall(MatAssemblyEnd(X, MAT_FINAL_ASSEMBLY));
68*34b254c5SRichard Tran Mills 
69*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorCreate(PETSC_COMM_WORLD, &regressor));
70*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorSetType(regressor, PETSCREGRESSORLINEAR));
71*34b254c5SRichard Tran Mills   PetscRegressorSetFromOptions(regressor);
72*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorFit(regressor, X, y));
73*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorPredict(regressor, X, y_predicted));
74*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearGetIntercept(regressor, &intercept));
75*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearGetCoefficients(regressor, &coefficients));
76*34b254c5SRichard Tran Mills 
77*34b254c5SRichard Tran Mills   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Intercept is %lf\n", intercept));
78*34b254c5SRichard Tran Mills   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Coefficients are\n"));
79*34b254c5SRichard Tran Mills   PetscCall(VecView(coefficients, PETSC_VIEWER_STDOUT_WORLD));
80*34b254c5SRichard Tran Mills   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Predicted values are\n"));
81*34b254c5SRichard Tran Mills   PetscCall(VecView(y_predicted, PETSC_VIEWER_STDOUT_WORLD));
82*34b254c5SRichard Tran Mills 
83*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorDestroy(&regressor));
84*34b254c5SRichard Tran Mills 
85*34b254c5SRichard Tran Mills   PetscCall(PetscFinalize());
86*34b254c5SRichard Tran Mills   return 0;
87*34b254c5SRichard Tran Mills }
88