1 const char help[] = "Test TAOLMVM on a least-squares problem"; 2 3 #include <petsctao.h> 4 #include <petscdevice.h> 5 6 typedef struct _n_AppCtx { 7 Mat A; 8 Vec b; 9 Vec r; 10 } AppCtx; 11 12 static PetscErrorCode LSObjAndGrad(Tao tao, Vec x, PetscReal *obj, Vec g, void *_ctx) 13 { 14 PetscFunctionBegin; 15 AppCtx *ctx = (AppCtx *)_ctx; 16 PetscCall(VecAXPBY(ctx->r, -1.0, 0.0, ctx->b)); 17 PetscCall(MatMultAdd(ctx->A, x, ctx->r, ctx->r)); 18 PetscCall(VecDotRealPart(ctx->r, ctx->r, obj)); 19 *obj *= 0.5; 20 PetscCall(MatMultTranspose(ctx->A, ctx->r, g)); 21 PetscFunctionReturn(PETSC_SUCCESS); 22 } 23 24 int main(int argc, char **argv) 25 { 26 PetscCall(PetscInitialize(&argc, &argv, NULL, help)); 27 MPI_Comm comm = PETSC_COMM_WORLD; 28 AppCtx ctx; 29 Vec sol; 30 PetscBool flg, cuda = PETSC_FALSE; 31 32 PetscInt M = 10; 33 PetscInt N = 10; 34 PetscOptionsBegin(comm, "", help, "TAO"); 35 PetscCall(PetscOptionsInt("-m", "data size", NULL, M, &M, NULL)); 36 PetscCall(PetscOptionsInt("-n", "data size", NULL, N, &N, NULL)); 37 PetscCall(PetscOptionsGetBool(NULL, NULL, "-cuda", &cuda, &flg)); 38 PetscOptionsEnd(); 39 40 if (cuda) { 41 VecType vec_type; 42 PetscCall(VecCreateSeqCUDA(comm, N, &ctx.b)); 43 PetscCall(VecGetType(ctx.b, &vec_type)); 44 PetscCall(MatCreateDenseFromVecType(comm, vec_type, M, N, PETSC_DECIDE, PETSC_DECIDE, -1, NULL, &ctx.A)); 45 PetscCall(MatCreateVecs(ctx.A, &sol, NULL)); 46 } else { 47 PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, M, N, NULL, &ctx.A)); 48 PetscCall(MatCreateVecs(ctx.A, &sol, &ctx.b)); 49 } 50 PetscCall(VecDuplicate(ctx.b, &ctx.r)); 51 PetscCall(VecZeroEntries(sol)); 52 53 PetscRandom rand; 54 PetscCall(PetscRandomCreate(comm, &rand)); 55 PetscCall(PetscRandomSetFromOptions(rand)); 56 PetscCall(MatSetRandom(ctx.A, rand)); 57 PetscCall(VecSetRandom(ctx.b, rand)); 58 PetscCall(PetscRandomDestroy(&rand)); 59 60 Tao tao; 61 PetscCall(TaoCreate(comm, &tao)); 62 PetscCall(TaoSetSolution(tao, sol)); 63 PetscCall(TaoSetObjectiveAndGradient(tao, NULL, LSObjAndGrad, &ctx)); 64 PetscCall(TaoSetType(tao, TAOLMVM)); 65 PetscCall(TaoSetFromOptions(tao)); 66 PetscCall(TaoSolve(tao)); 67 PetscCall(TaoDestroy(&tao)); 68 69 PetscCall(VecDestroy(&ctx.r)); 70 PetscCall(VecDestroy(&sol)); 71 PetscCall(VecDestroy(&ctx.b)); 72 PetscCall(MatDestroy(&ctx.A)); 73 74 PetscCall(PetscFinalize()); 75 return 0; 76 } 77 78 /*TEST 79 80 build: 81 requires: !complex !__float128 !single !defined(PETSC_USE_64BIT_INDICES) 82 83 test: 84 suffix: 0 85 args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_type lmvmbfgs 86 87 test: 88 suffix: 1 89 args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_type lmvmdbfgs 90 91 test: 92 suffix: 2 93 args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_type lmvmdbfgs -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_lbfgs_type {{inplace reorder}} 94 95 TEST*/ 96