xref: /petsc/src/tao/unconstrained/impls/lmvm/tests/ex1.c (revision 5fe01c21d727d5e61b58091b225b37a3f96d9fc7)
/* Help text shown by PetscInitialize() for -help. It is static because it is
   only used in this translation unit; a file-scope const object in C otherwise
   has external linkage. */
static const char help[] = "Test TAOLMVM on a least-squares problem";
2 
3 #include <petsctao.h>
4 #include <petscdevice.h>
5 
6 typedef struct _n_AppCtx {
7   Mat A;
8   Vec b;
9   Vec r;
10 } AppCtx;
11 
LSObjAndGrad(Tao tao,Vec x,PetscReal * obj,Vec g,void * _ctx)12 static PetscErrorCode LSObjAndGrad(Tao tao, Vec x, PetscReal *obj, Vec g, void *_ctx)
13 {
14   PetscFunctionBegin;
15   AppCtx *ctx = (AppCtx *)_ctx;
16   PetscCall(VecAXPBY(ctx->r, -1.0, 0.0, ctx->b));
17   PetscCall(MatMultAdd(ctx->A, x, ctx->r, ctx->r));
18   PetscCall(VecDotRealPart(ctx->r, ctx->r, obj));
19   *obj *= 0.5;
20   PetscCall(MatMultTranspose(ctx->A, ctx->r, g));
21   PetscFunctionReturn(PETSC_SUCCESS);
22 }
23 
main(int argc,char ** argv)24 int main(int argc, char **argv)
25 {
26   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
27   MPI_Comm  comm = PETSC_COMM_WORLD;
28   AppCtx    ctx;
29   Vec       sol;
30   PetscBool flg, cuda = PETSC_FALSE;
31 
32   PetscInt M = 10;
33   PetscInt N = 10;
34   PetscOptionsBegin(comm, "", help, "TAO");
35   PetscCall(PetscOptionsInt("-m", "data size", NULL, M, &M, NULL));
36   PetscCall(PetscOptionsInt("-n", "data size", NULL, N, &N, NULL));
37   PetscCall(PetscOptionsGetBool(NULL, NULL, "-cuda", &cuda, &flg));
38   PetscOptionsEnd();
39 
40   if (cuda) {
41     VecType vec_type;
42     PetscCall(VecCreateSeqCUDA(comm, N, &ctx.b));
43     PetscCall(VecGetType(ctx.b, &vec_type));
44     PetscCall(MatCreateDenseFromVecType(comm, vec_type, M, N, PETSC_DECIDE, PETSC_DECIDE, -1, NULL, &ctx.A));
45     PetscCall(MatCreateVecs(ctx.A, &sol, NULL));
46   } else {
47     PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, M, N, NULL, &ctx.A));
48     PetscCall(MatCreateVecs(ctx.A, &sol, &ctx.b));
49   }
50   PetscCall(VecDuplicate(ctx.b, &ctx.r));
51   PetscCall(VecZeroEntries(sol));
52 
53   PetscRandom rand;
54   PetscCall(PetscRandomCreate(comm, &rand));
55   PetscCall(PetscRandomSetFromOptions(rand));
56   PetscCall(MatSetRandom(ctx.A, rand));
57   PetscCall(VecSetRandom(ctx.b, rand));
58   PetscCall(PetscRandomDestroy(&rand));
59 
60   Tao tao;
61   PetscCall(TaoCreate(comm, &tao));
62   PetscCall(TaoSetSolution(tao, sol));
63   PetscCall(TaoSetObjectiveAndGradient(tao, NULL, LSObjAndGrad, &ctx));
64   PetscCall(TaoSetType(tao, TAOLMVM));
65   PetscCall(TaoSetFromOptions(tao));
66   PetscCall(TaoSolve(tao));
67   PetscCall(TaoDestroy(&tao));
68 
69   PetscCall(VecDestroy(&ctx.r));
70   PetscCall(VecDestroy(&sol));
71   PetscCall(VecDestroy(&ctx.b));
72   PetscCall(MatDestroy(&ctx.A));
73 
74   PetscCall(PetscFinalize());
75   return 0;
76 }
77 
78 /*TEST
79 
80   build:
81     requires: !complex !__float128 !single !defined(PETSC_USE_64BIT_INDICES)
82 
83   test:
84     suffix: 0
85     args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_type lmvmbfgs
86 
87   test:
88     suffix: 1
89     args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_type lmvmdbfgs
90 
91   test:
92     suffix: 2
93     args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_type lmvmdbfgs -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_lbfgs_type {{inplace reorder}}
94 
95 TEST*/
96