/* Program description passed to PetscInitialize() and PetscOptionsBegin() in main();
   only used in this file, so it gets internal linkage. */
static const char help[] = "Test TAOLMVM on a least-squares problem";
2
3 #include <petsctao.h>
4 #include <petscdevice.h>
5
6 typedef struct _n_AppCtx {
7 Mat A;
8 Vec b;
9 Vec r;
10 } AppCtx;
11
LSObjAndGrad(Tao tao,Vec x,PetscReal * obj,Vec g,void * _ctx)12 static PetscErrorCode LSObjAndGrad(Tao tao, Vec x, PetscReal *obj, Vec g, void *_ctx)
13 {
14 PetscFunctionBegin;
15 AppCtx *ctx = (AppCtx *)_ctx;
16 PetscCall(VecAXPBY(ctx->r, -1.0, 0.0, ctx->b));
17 PetscCall(MatMultAdd(ctx->A, x, ctx->r, ctx->r));
18 PetscCall(VecDotRealPart(ctx->r, ctx->r, obj));
19 *obj *= 0.5;
20 PetscCall(MatMultTranspose(ctx->A, ctx->r, g));
21 PetscFunctionReturn(PETSC_SUCCESS);
22 }
23
main(int argc,char ** argv)24 int main(int argc, char **argv)
25 {
26 PetscCall(PetscInitialize(&argc, &argv, NULL, help));
27 MPI_Comm comm = PETSC_COMM_WORLD;
28 AppCtx ctx;
29 Vec sol;
30 PetscBool flg, cuda = PETSC_FALSE;
31
32 PetscInt M = 10;
33 PetscInt N = 10;
34 PetscOptionsBegin(comm, "", help, "TAO");
35 PetscCall(PetscOptionsInt("-m", "data size", NULL, M, &M, NULL));
36 PetscCall(PetscOptionsInt("-n", "data size", NULL, N, &N, NULL));
37 PetscCall(PetscOptionsGetBool(NULL, NULL, "-cuda", &cuda, &flg));
38 PetscOptionsEnd();
39
40 if (cuda) {
41 VecType vec_type;
42 PetscCall(VecCreateSeqCUDA(comm, N, &ctx.b));
43 PetscCall(VecGetType(ctx.b, &vec_type));
44 PetscCall(MatCreateDenseFromVecType(comm, vec_type, M, N, PETSC_DECIDE, PETSC_DECIDE, -1, NULL, &ctx.A));
45 PetscCall(MatCreateVecs(ctx.A, &sol, NULL));
46 } else {
47 PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, M, N, NULL, &ctx.A));
48 PetscCall(MatCreateVecs(ctx.A, &sol, &ctx.b));
49 }
50 PetscCall(VecDuplicate(ctx.b, &ctx.r));
51 PetscCall(VecZeroEntries(sol));
52
53 PetscRandom rand;
54 PetscCall(PetscRandomCreate(comm, &rand));
55 PetscCall(PetscRandomSetFromOptions(rand));
56 PetscCall(MatSetRandom(ctx.A, rand));
57 PetscCall(VecSetRandom(ctx.b, rand));
58 PetscCall(PetscRandomDestroy(&rand));
59
60 Tao tao;
61 PetscCall(TaoCreate(comm, &tao));
62 PetscCall(TaoSetSolution(tao, sol));
63 PetscCall(TaoSetObjectiveAndGradient(tao, NULL, LSObjAndGrad, &ctx));
64 PetscCall(TaoSetType(tao, TAOLMVM));
65 PetscCall(TaoSetFromOptions(tao));
66 PetscCall(TaoSolve(tao));
67 PetscCall(TaoDestroy(&tao));
68
69 PetscCall(VecDestroy(&ctx.r));
70 PetscCall(VecDestroy(&sol));
71 PetscCall(VecDestroy(&ctx.b));
72 PetscCall(MatDestroy(&ctx.A));
73
74 PetscCall(PetscFinalize());
75 return 0;
76 }
77
78 /*TEST
79
80 build:
81 requires: !complex !__float128 !single !defined(PETSC_USE_64BIT_INDICES)
82
83 test:
84 suffix: 0
85 args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_type lmvmbfgs
86
87 test:
88 suffix: 1
89 args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_type lmvmdbfgs
90
91 test:
92 suffix: 2
93 args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_type lmvmdbfgs -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_lbfgs_type {{inplace reorder}}
94
95 TEST*/
96