1 const char help[] = "Profile the performance of MATLMVM MatSolve() in a loop"; 2 3 #include <petscksp.h> 4 #include <petscmath.h> 5 6 int main(int argc, char **argv) 7 { 8 PetscInt n = 1000; 9 PetscInt n_epochs = 10; 10 PetscInt n_iters = 10; 11 Vec x, g, dx, df, p; 12 PetscRandom rand; 13 PetscLogStage matsolve_loop, main_stage; 14 Mat B; 15 16 PetscCall(PetscInitialize(&argc, &argv, NULL, help)); 17 PetscOptionsBegin(PETSC_COMM_WORLD, NULL, help, "KSP"); 18 PetscCall(PetscOptionsInt("-n", "Vector size", __FILE__, n, &n, NULL)); 19 PetscCall(PetscOptionsInt("-epochs", "Number of epochs", __FILE__, n_epochs, &n_epochs, NULL)); 20 PetscCall(PetscOptionsInt("-iters", "Number of iterations per epoch", __FILE__, n_iters, &n_iters, NULL)); 21 PetscOptionsEnd(); 22 PetscCall(VecCreateMPI(PETSC_COMM_WORLD, PETSC_DETERMINE, n, &x)); 23 PetscCall(VecSetFromOptions(x)); 24 PetscCall(VecDuplicate(x, &g)); 25 PetscCall(VecDuplicate(x, &dx)); 26 PetscCall(VecDuplicate(x, &df)); 27 PetscCall(VecDuplicate(x, &p)); 28 PetscCall(MatCreateLMVMBFGS(PETSC_COMM_WORLD, PETSC_DETERMINE, n, &B)); 29 PetscCall(MatSetFromOptions(B)); 30 PetscCall(MatLMVMAllocate(B, x, g)); 31 PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &rand)); 32 PetscCall(PetscRandomSetInterval(rand, -1.0, 1.0)); 33 PetscCall(PetscRandomSetFromOptions(rand)); 34 PetscCall(PetscLogStageRegister("LMVM MatSolve Loop", &matsolve_loop)); 35 PetscCall(PetscLogStageGetId("Main Stage", &main_stage)); 36 PetscCall(PetscLogStageSetVisible(main_stage, PETSC_FALSE)); 37 for (PetscInt epoch = 0; epoch < n_epochs + 1; epoch++) { 38 PetscScalar dot; 39 PetscReal xscale, fscale, absdot; 40 PetscInt history_size; 41 42 PetscCall(VecSetRandom(dx, rand)); 43 PetscCall(VecSetRandom(df, rand)); 44 PetscCall(VecDot(dx, df, &dot)); 45 absdot = PetscAbsScalar(dot); 46 PetscCall(VecSetRandom(x, rand)); 47 PetscCall(VecSetRandom(g, rand)); 48 xscale = 1.0; 49 fscale = absdot / PetscRealPart(dot); 50 PetscCall(MatLMVMGetHistorySize(B, &history_size)); 51 52 PetscCall(MatLMVMUpdate(B, x, g)); 53 for (PetscInt iter = 0; iter < history_size; iter++, xscale *= -1.0, fscale *= -1.0) { 54 PetscCall(VecAXPY(x, xscale, dx)); 55 PetscCall(VecAXPY(g, fscale, df)); 56 PetscCall(MatLMVMUpdate(B, x, g)); 57 PetscCall(MatSolve(B, g, p)); 58 } 59 if (epoch > 0) PetscCall(PetscLogStagePush(matsolve_loop)); 60 for (PetscInt iter = 0; iter < n_iters; iter++, xscale *= -1.0, fscale *= -1.0) { 61 PetscCall(VecAXPY(x, xscale, dx)); 62 PetscCall(VecAXPY(g, fscale, df)); 63 PetscCall(MatLMVMUpdate(B, x, g)); 64 PetscCall(MatSolve(B, g, p)); 65 } 66 PetscCall(MatLMVMReset(B, PETSC_FALSE)); 67 if (epoch > 0) PetscCall(PetscLogStagePop()); 68 } 69 PetscCall(MatView(B, PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD))); 70 PetscCall(PetscRandomDestroy(&rand)); 71 PetscCall(MatDestroy(&B)); 72 PetscCall(VecDestroy(&p)); 73 PetscCall(VecDestroy(&df)); 74 PetscCall(VecDestroy(&dx)); 75 PetscCall(VecDestroy(&g)); 76 PetscCall(VecDestroy(&x)); 77 PetscCall(PetscFinalize()); 78 return 0; 79 } 80 81 /*TEST 82 83 test: 84 suffix: 0 85 args: -mat_lmvm_scale_type none 86 87 TEST*/ 88