/* Help text: handed to PetscInitialize() and used as the title of the options block in main(). */
const char help[] = "Profile the performance of MATLMVM MatSolve() in a loop";
2
3 #include <petscksp.h>
4 #include <petscmath.h>
5
main(int argc,char ** argv)6 int main(int argc, char **argv)
7 {
8 PetscInt n = 1000;
9 PetscInt n_epochs = 10;
10 PetscInt n_iters = 10;
11 Vec x, g, dx, df, p;
12 PetscRandom rand;
13 PetscLogStage matsolve_loop, main_stage;
14 Mat B, J0;
15
16 PetscCall(PetscInitialize(&argc, &argv, NULL, help));
17 PetscCall(KSPInitializePackage());
18 PetscOptionsBegin(PETSC_COMM_WORLD, NULL, help, "KSP");
19 PetscCall(PetscOptionsInt("-n", "Vector size", __FILE__, n, &n, NULL));
20 PetscCall(PetscOptionsInt("-epochs", "Number of epochs", __FILE__, n_epochs, &n_epochs, NULL));
21 PetscCall(PetscOptionsInt("-iters", "Number of iterations per epoch", __FILE__, n_iters, &n_iters, NULL));
22 PetscOptionsEnd();
23 PetscCall(VecCreateMPI(PETSC_COMM_WORLD, PETSC_DETERMINE, n, &x));
24 PetscCall(VecSetFromOptions(x));
25 PetscCall(VecDuplicate(x, &g));
26 PetscCall(VecDuplicate(x, &dx));
27 PetscCall(VecDuplicate(x, &df));
28 PetscCall(VecDuplicate(x, &p));
29 PetscCall(MatCreate(PETSC_COMM_WORLD, &B));
30 PetscCall(MatSetType(B, MATLMVMBFGS));
31 PetscCall(MatLMVMAllocate(B, x, g));
32 PetscCall(MatSetFromOptions(B));
33 PetscCall(MatLMVMGetJ0(B, &J0));
34 PetscCall(MatZeroEntries(J0));
35 PetscCall(MatShift(J0, 1.0));
36 PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &rand));
37 PetscCall(PetscRandomSetInterval(rand, -1.0, 1.0));
38 PetscCall(PetscRandomSetFromOptions(rand));
39 PetscCall(PetscLogStageRegister("LMVM MatSolve Loop", &matsolve_loop));
40 PetscCall(PetscLogStageGetId("Main Stage", &main_stage));
41 PetscCall(PetscLogStageSetVisible(main_stage, PETSC_FALSE));
42 for (PetscInt epoch = 0; epoch < n_epochs + 1; epoch++) {
43 PetscScalar dot;
44 PetscReal xscale, fscale, absdot;
45 PetscInt history_size;
46
47 PetscCall(VecSetRandom(dx, rand));
48 PetscCall(VecSetRandom(df, rand));
49 PetscCall(VecDot(dx, df, &dot));
50 absdot = PetscAbsScalar(dot);
51 PetscCall(VecSetRandom(x, rand));
52 PetscCall(VecSetRandom(g, rand));
53 xscale = 1.0;
54 fscale = absdot / PetscRealPart(dot);
55 PetscCall(MatLMVMGetHistorySize(B, &history_size));
56
57 PetscCall(MatLMVMUpdate(B, x, g));
58 for (PetscInt iter = 0; iter < history_size; iter++, xscale *= -1.0, fscale *= -1.0) {
59 PetscCall(VecAXPY(x, xscale, dx));
60 PetscCall(VecAXPY(g, fscale, df));
61 PetscCall(MatLMVMUpdate(B, x, g));
62 PetscCall(MatSolve(B, g, p));
63 }
64 if (epoch > 0) PetscCall(PetscLogStagePush(matsolve_loop));
65 for (PetscInt iter = 0; iter < n_iters; iter++, xscale *= -1.0, fscale *= -1.0) {
66 PetscCall(VecAXPY(x, xscale, dx));
67 PetscCall(VecAXPY(g, fscale, df));
68 PetscCall(MatLMVMUpdate(B, x, g));
69 PetscCall(MatSolve(B, g, p));
70 }
71 PetscCall(MatLMVMReset(B, PETSC_FALSE));
72 if (epoch > 0) PetscCall(PetscLogStagePop());
73 }
74 PetscCall(PetscViewerPushFormat(PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD), PETSC_VIEWER_ASCII_INFO_DETAIL));
75 PetscCall(MatView(B, PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD)));
76 PetscCall(PetscViewerPopFormat(PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD)));
77 PetscCall(PetscRandomDestroy(&rand));
78 PetscCall(MatDestroy(&B));
79 PetscCall(VecDestroy(&p));
80 PetscCall(VecDestroy(&df));
81 PetscCall(VecDestroy(&dx));
82 PetscCall(VecDestroy(&g));
83 PetscCall(VecDestroy(&x));
84 PetscCall(PetscFinalize());
85 return 0;
86 }
87
/*TEST

  test:
    suffix: 0
    args: -mat_lmvm_scale_type none

TEST*/
95