xref: /petsc/src/ksp/ksp/utils/lmvm/tests/solve_performance.c (revision 58bddbc0aeb8e2276be3739270a4176cb222ba3a)
/* Program description; passed to PetscInitialize() and PetscOptionsBegin() in main(). */
const char help[] = "Profile the performance of MATLMVM MatSolve() in a loop";
2 
3 #include <petscksp.h>
4 #include <petscmath.h>
5 
main(int argc,char ** argv)6 int main(int argc, char **argv)
7 {
8   PetscInt      n        = 1000;
9   PetscInt      n_epochs = 10;
10   PetscInt      n_iters  = 10;
11   Vec           x, g, dx, df, p;
12   PetscRandom   rand;
13   PetscLogStage matsolve_loop, main_stage;
14   Mat           B, J0;
15 
16   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
17   PetscCall(KSPInitializePackage());
18   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, help, "KSP");
19   PetscCall(PetscOptionsInt("-n", "Vector size", __FILE__, n, &n, NULL));
20   PetscCall(PetscOptionsInt("-epochs", "Number of epochs", __FILE__, n_epochs, &n_epochs, NULL));
21   PetscCall(PetscOptionsInt("-iters", "Number of iterations per epoch", __FILE__, n_iters, &n_iters, NULL));
22   PetscOptionsEnd();
23   PetscCall(VecCreateMPI(PETSC_COMM_WORLD, PETSC_DETERMINE, n, &x));
24   PetscCall(VecSetFromOptions(x));
25   PetscCall(VecDuplicate(x, &g));
26   PetscCall(VecDuplicate(x, &dx));
27   PetscCall(VecDuplicate(x, &df));
28   PetscCall(VecDuplicate(x, &p));
29   PetscCall(MatCreate(PETSC_COMM_WORLD, &B));
30   PetscCall(MatSetType(B, MATLMVMBFGS));
31   PetscCall(MatLMVMAllocate(B, x, g));
32   PetscCall(MatSetFromOptions(B));
33   PetscCall(MatLMVMGetJ0(B, &J0));
34   PetscCall(MatZeroEntries(J0));
35   PetscCall(MatShift(J0, 1.0));
36   PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &rand));
37   PetscCall(PetscRandomSetInterval(rand, -1.0, 1.0));
38   PetscCall(PetscRandomSetFromOptions(rand));
39   PetscCall(PetscLogStageRegister("LMVM MatSolve Loop", &matsolve_loop));
40   PetscCall(PetscLogStageGetId("Main Stage", &main_stage));
41   PetscCall(PetscLogStageSetVisible(main_stage, PETSC_FALSE));
42   for (PetscInt epoch = 0; epoch < n_epochs + 1; epoch++) {
43     PetscScalar dot;
44     PetscReal   xscale, fscale, absdot;
45     PetscInt    history_size;
46 
47     PetscCall(VecSetRandom(dx, rand));
48     PetscCall(VecSetRandom(df, rand));
49     PetscCall(VecDot(dx, df, &dot));
50     absdot = PetscAbsScalar(dot);
51     PetscCall(VecSetRandom(x, rand));
52     PetscCall(VecSetRandom(g, rand));
53     xscale = 1.0;
54     fscale = absdot / PetscRealPart(dot);
55     PetscCall(MatLMVMGetHistorySize(B, &history_size));
56 
57     PetscCall(MatLMVMUpdate(B, x, g));
58     for (PetscInt iter = 0; iter < history_size; iter++, xscale *= -1.0, fscale *= -1.0) {
59       PetscCall(VecAXPY(x, xscale, dx));
60       PetscCall(VecAXPY(g, fscale, df));
61       PetscCall(MatLMVMUpdate(B, x, g));
62       PetscCall(MatSolve(B, g, p));
63     }
64     if (epoch > 0) PetscCall(PetscLogStagePush(matsolve_loop));
65     for (PetscInt iter = 0; iter < n_iters; iter++, xscale *= -1.0, fscale *= -1.0) {
66       PetscCall(VecAXPY(x, xscale, dx));
67       PetscCall(VecAXPY(g, fscale, df));
68       PetscCall(MatLMVMUpdate(B, x, g));
69       PetscCall(MatSolve(B, g, p));
70     }
71     PetscCall(MatLMVMReset(B, PETSC_FALSE));
72     if (epoch > 0) PetscCall(PetscLogStagePop());
73   }
74   PetscCall(PetscViewerPushFormat(PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD), PETSC_VIEWER_ASCII_INFO_DETAIL));
75   PetscCall(MatView(B, PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD)));
76   PetscCall(PetscViewerPopFormat(PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD)));
77   PetscCall(PetscRandomDestroy(&rand));
78   PetscCall(MatDestroy(&B));
79   PetscCall(VecDestroy(&p));
80   PetscCall(VecDestroy(&df));
81   PetscCall(VecDestroy(&dx));
82   PetscCall(VecDestroy(&g));
83   PetscCall(VecDestroy(&x));
84   PetscCall(PetscFinalize());
85   return 0;
86 }
87 
88 /*TEST
89 
90   test:
91     suffix: 0
92     args: -mat_lmvm_scale_type none
93 
94 TEST*/
95