1 const char help[] = "Profile the performance of MATLMVM MatSolve() in a loop"; 2 3 #include <petscksp.h> 4 #include <petscmath.h> 5 6 int main(int argc, char **argv) 7 { 8 PetscInt n = 1000; 9 PetscInt n_epochs = 10; 10 PetscInt n_iters = 10; 11 Vec x, g, dx, df, p; 12 PetscRandom rand; 13 PetscLogStage matsolve_loop, main_stage; 14 Mat B, J0; 15 16 PetscCall(PetscInitialize(&argc, &argv, NULL, help)); 17 PetscCall(KSPInitializePackage()); 18 PetscOptionsBegin(PETSC_COMM_WORLD, NULL, help, "KSP"); 19 PetscCall(PetscOptionsInt("-n", "Vector size", __FILE__, n, &n, NULL)); 20 PetscCall(PetscOptionsInt("-epochs", "Number of epochs", __FILE__, n_epochs, &n_epochs, NULL)); 21 PetscCall(PetscOptionsInt("-iters", "Number of iterations per epoch", __FILE__, n_iters, &n_iters, NULL)); 22 PetscOptionsEnd(); 23 PetscCall(VecCreateMPI(PETSC_COMM_WORLD, PETSC_DETERMINE, n, &x)); 24 PetscCall(VecSetFromOptions(x)); 25 PetscCall(VecDuplicate(x, &g)); 26 PetscCall(VecDuplicate(x, &dx)); 27 PetscCall(VecDuplicate(x, &df)); 28 PetscCall(VecDuplicate(x, &p)); 29 PetscCall(MatCreate(PETSC_COMM_WORLD, &B)); 30 PetscCall(MatSetType(B, MATLMVMBFGS)); 31 PetscCall(MatLMVMAllocate(B, x, g)); 32 PetscCall(MatSetFromOptions(B)); 33 PetscCall(MatLMVMGetJ0(B, &J0)); 34 PetscCall(MatZeroEntries(J0)); 35 PetscCall(MatShift(J0, 1.0)); 36 PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &rand)); 37 PetscCall(PetscRandomSetInterval(rand, -1.0, 1.0)); 38 PetscCall(PetscRandomSetFromOptions(rand)); 39 PetscCall(PetscLogStageRegister("LMVM MatSolve Loop", &matsolve_loop)); 40 PetscCall(PetscLogStageGetId("Main Stage", &main_stage)); 41 PetscCall(PetscLogStageSetVisible(main_stage, PETSC_FALSE)); 42 for (PetscInt epoch = 0; epoch < n_epochs + 1; epoch++) { 43 PetscScalar dot; 44 PetscReal xscale, fscale, absdot; 45 PetscInt history_size; 46 47 PetscCall(VecSetRandom(dx, rand)); 48 PetscCall(VecSetRandom(df, rand)); 49 PetscCall(VecDot(dx, df, &dot)); 50 absdot = PetscAbsScalar(dot); 51 PetscCall(VecSetRandom(x, rand)); 52 PetscCall(VecSetRandom(g, rand)); 53 xscale = 1.0; 54 fscale = absdot / PetscRealPart(dot); 55 PetscCall(MatLMVMGetHistorySize(B, &history_size)); 56 57 PetscCall(MatLMVMUpdate(B, x, g)); 58 for (PetscInt iter = 0; iter < history_size; iter++, xscale *= -1.0, fscale *= -1.0) { 59 PetscCall(VecAXPY(x, xscale, dx)); 60 PetscCall(VecAXPY(g, fscale, df)); 61 PetscCall(MatLMVMUpdate(B, x, g)); 62 PetscCall(MatSolve(B, g, p)); 63 } 64 if (epoch > 0) PetscCall(PetscLogStagePush(matsolve_loop)); 65 for (PetscInt iter = 0; iter < n_iters; iter++, xscale *= -1.0, fscale *= -1.0) { 66 PetscCall(VecAXPY(x, xscale, dx)); 67 PetscCall(VecAXPY(g, fscale, df)); 68 PetscCall(MatLMVMUpdate(B, x, g)); 69 PetscCall(MatSolve(B, g, p)); 70 } 71 PetscCall(MatLMVMReset(B, PETSC_FALSE)); 72 if (epoch > 0) PetscCall(PetscLogStagePop()); 73 } 74 PetscCall(PetscViewerPushFormat(PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD), PETSC_VIEWER_ASCII_INFO_DETAIL)); 75 PetscCall(MatView(B, PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD))); 76 PetscCall(PetscViewerPopFormat(PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD))); 77 PetscCall(PetscRandomDestroy(&rand)); 78 PetscCall(MatDestroy(&B)); 79 PetscCall(VecDestroy(&p)); 80 PetscCall(VecDestroy(&df)); 81 PetscCall(VecDestroy(&dx)); 82 PetscCall(VecDestroy(&g)); 83 PetscCall(VecDestroy(&x)); 84 PetscCall(PetscFinalize()); 85 return 0; 86 } 87 88 /*TEST 89 90 test: 91 suffix: 0 92 args: -mat_lmvm_scale_type none 93 94 TEST*/ 95