xref: /petsc/src/tao/tutorials/ex4.c (revision 7a46b59513fac54acb317385fa9864cc0988b1fb)
1 static char help[] = "Simple example to test separable objective optimizers.\n";
2 
3 #include <petsc.h>
4 #include <petsctao.h>
5 #include <petscvec.h>
6 #include <petscmath.h>
7 
8 #define NWORKLEFT 4
9 #define NWORKRIGHT 12
10 
11 typedef struct _UserCtx
12 {
13   PetscInt    m;                     /* The row dimension of F */
14   PetscInt    n;                     /* The column dimension of F */
15   PetscInt    matops;                /* Matrix format. 0 for stencil, 1 for random */
16   PetscInt    iter;                  /* Numer of iterations for ADMM */
17   PetscReal   hStart;                /* Starting point for Taylor test */
18   PetscReal   hFactor;               /* Taylor test step factor */
19   PetscReal   hMin;                  /* Taylor test end goal */
20   PetscReal   alpha;                 /* regularization constant applied to || x ||_p */
21   PetscReal   eps;                   /* small constant for approximating gradient of || x ||_1 */
22   PetscReal   mu;                    /* the augmented Lagrangian term in ADMM */
23   PetscReal   abstol;
24   PetscReal   reltol;
25   Mat         F;                     /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */
26   Mat         W;                     /* Workspace matrix. ATA */
27   Mat         Hm;                    /* Hessian Misfit*/
28   Mat         Hr;                    /* Hessian Reg*/
29   Vec         d;                     /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */
30   Vec         workLeft[NWORKLEFT];   /* Workspace for temporary vec */
31   Vec         workRight[NWORKRIGHT]; /* Workspace for temporary vec */
32   NormType    p;
33   PetscRandom rctx;
34   PetscBool   taylor;                /* Flag to determine whether to run Taylor test or not */
35   PetscBool   use_admm;              /* Flag to determine whether to run Taylor test or not */
36 }* UserCtx;
37 
38 static PetscErrorCode CreateRHS(UserCtx ctx)
39 {
40   PetscFunctionBegin;
41   /* build the rhs d in ctx */
42   PetscCall(VecCreate(PETSC_COMM_WORLD,&(ctx->d)));
43   PetscCall(VecSetSizes(ctx->d,PETSC_DECIDE,ctx->m));
44   PetscCall(VecSetFromOptions(ctx->d));
45   PetscCall(VecSetRandom(ctx->d,ctx->rctx));
46   PetscFunctionReturn(0);
47 }
48 
49 static PetscErrorCode CreateMatrix(UserCtx ctx)
50 {
51   PetscInt       Istart,Iend,i,j,Ii,gridN,I_n, I_s, I_e, I_w;
52 #if defined(PETSC_USE_LOG)
53   PetscLogStage  stage;
54 #endif
55 
56   PetscFunctionBegin;
57   /* build the matrix F in ctx */
58   PetscCall(MatCreate(PETSC_COMM_WORLD, &(ctx->F)));
59   PetscCall(MatSetSizes(ctx->F,PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n));
60   PetscCall(MatSetType(ctx->F,MATAIJ)); /* TODO: Decide specific SetType other than dummy*/
61   PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/
62   PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL));
63   PetscCall(MatSetUp(ctx->F));
64   PetscCall(MatGetOwnershipRange(ctx->F,&Istart,&Iend));
65   PetscCall(PetscLogStageRegister("Assembly", &stage));
66   PetscCall(PetscLogStagePush(stage));
67 
68   /* Set matrix elements in  2-D five point stencil format. */
69   if (!(ctx->matops)) {
70     PetscCheck(ctx->m == ctx->n,PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square");
71     gridN = (PetscInt) PetscSqrtReal((PetscReal) ctx->m);
72     PetscCheck(gridN * gridN == ctx->m,PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square");
73     for (Ii=Istart; Ii<Iend; Ii++) {
74       i   = Ii / gridN; j = Ii % gridN;
75       I_n = i * gridN + j + 1;
76       if (j + 1 >= gridN) I_n = -1;
77       I_s = i * gridN + j - 1;
78       if (j - 1 < 0) I_s = -1;
79       I_e = (i + 1) * gridN + j;
80       if (i + 1 >= gridN) I_e = -1;
81       I_w = (i - 1) * gridN + j;
82       if (i - 1 < 0) I_w = -1;
83       PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES));
84       PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES));
85       PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES));
86       PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES));
87       PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES));
88     }
89   } else PetscCall(MatSetRandom(ctx->F, ctx->rctx));
90   PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY));
91   PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY));
92   PetscCall(PetscLogStagePop());
93   /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */
94   if (!(ctx->matops)) {
95     PetscCall(MatSetOption(ctx->F,MAT_SYMMETRIC,PETSC_TRUE));
96   }
97   PetscCall(MatTransposeMatMult(ctx->F,ctx->F, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(ctx->W)));
98   /* Setup Hessian Workspace in same shape as W */
99   PetscCall(MatDuplicate(ctx->W,MAT_DO_NOT_COPY_VALUES,&(ctx->Hm)));
100   PetscCall(MatDuplicate(ctx->W,MAT_DO_NOT_COPY_VALUES,&(ctx->Hr)));
101   PetscFunctionReturn(0);
102 }
103 
104 static PetscErrorCode SetupWorkspace(UserCtx ctx)
105 {
106   PetscInt       i;
107 
108   PetscFunctionBegin;
109   PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0]));
110   for (i=1; i<NWORKLEFT; i++) {
111     PetscCall(VecDuplicate(ctx->workLeft[0], &(ctx->workLeft[i])));
112   }
113   for (i=1; i<NWORKRIGHT; i++) {
114     PetscCall(VecDuplicate(ctx->workRight[0], &(ctx->workRight[i])));
115   }
116   PetscFunctionReturn(0);
117 }
118 
119 static PetscErrorCode ConfigureContext(UserCtx ctx)
120 {
121   PetscFunctionBegin;
122   ctx->m        = 16;
123   ctx->n        = 16;
124   ctx->eps      = 1.e-3;
125   ctx->abstol   = 1.e-4;
126   ctx->reltol   = 1.e-2;
127   ctx->hStart   = 1.;
128   ctx->hMin     = 1.e-3;
129   ctx->hFactor  = 0.5;
130   ctx->alpha    = 1.;
131   ctx->mu       = 1.0;
132   ctx->matops   = 0;
133   ctx->iter     = 10;
134   ctx->p        = NORM_2;
135   ctx->taylor   = PETSC_TRUE;
136   ctx->use_admm = PETSC_FALSE;
137   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "ex4.c");
138   PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &(ctx->m), NULL));
139   PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &(ctx->n), NULL));
140   PetscCall(PetscOptionsInt("-matrix_format","Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &(ctx->matops), NULL));
141   PetscCall(PetscOptionsInt("-iter","Iteration number ADMM", "ex4.c", ctx->iter, &(ctx->iter), NULL));
142   PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &(ctx->alpha), NULL));
143   PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &(ctx->eps), NULL));
144   PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &(ctx->mu), NULL));
145   PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &(ctx->hStart), NULL));
146   PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &(ctx->hFactor), NULL));
147   PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &(ctx->hMin), NULL));
148   PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &(ctx->abstol), NULL));
149   PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &(ctx->reltol), NULL));
150   PetscCall(PetscOptionsBool("-taylor","Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &(ctx->taylor), NULL));
151   PetscCall(PetscOptionsBool("-use_admm","Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &(ctx->use_admm), NULL));
152   PetscCall(PetscOptionsEnum("-p","Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *) &(ctx->p), NULL));
153   PetscOptionsEnd();
154   /* Creating random ctx */
155   PetscCall(PetscRandomCreate(PETSC_COMM_WORLD,&(ctx->rctx)));
156   PetscCall(PetscRandomSetFromOptions(ctx->rctx));
157   PetscCall(CreateMatrix(ctx));
158   PetscCall(CreateRHS(ctx));
159   PetscCall(SetupWorkspace(ctx));
160   PetscFunctionReturn(0);
161 }
162 
163 static PetscErrorCode DestroyContext(UserCtx *ctx)
164 {
165   PetscInt       i;
166 
167   PetscFunctionBegin;
168   PetscCall(MatDestroy(&((*ctx)->F)));
169   PetscCall(MatDestroy(&((*ctx)->W)));
170   PetscCall(MatDestroy(&((*ctx)->Hm)));
171   PetscCall(MatDestroy(&((*ctx)->Hr)));
172   PetscCall(VecDestroy(&((*ctx)->d)));
173   for (i=0; i<NWORKLEFT; i++) {
174     PetscCall(VecDestroy(&((*ctx)->workLeft[i])));
175   }
176   for (i=0; i<NWORKRIGHT; i++) {
177     PetscCall(VecDestroy(&((*ctx)->workRight[i])));
178   }
179   PetscCall(PetscRandomDestroy(&((*ctx)->rctx)));
180   PetscCall(PetscFree(*ctx));
181   PetscFunctionReturn(0);
182 }
183 
184 /* compute (1/2) * ||F x - d||^2 */
185 static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx)
186 {
187   UserCtx        ctx = (UserCtx) _ctx;
188   Vec            y;
189 
190   PetscFunctionBegin;
191   y    = ctx->workLeft[0];
192   PetscCall(MatMult(ctx->F, x, y));
193   PetscCall(VecAXPY(y, -1., ctx->d));
194   PetscCall(VecDot(y, y, J));
195   *J  *= 0.5;
196   PetscFunctionReturn(0);
197 }
198 
199 /* compute V = FTFx - FTd */
200 static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx)
201 {
202   UserCtx        ctx = (UserCtx) _ctx;
203   Vec            FTFx, FTd;
204 
205   PetscFunctionBegin;
206   /* work1 is A^T Ax, work2 is Ab, W is A^T A*/
207   FTFx = ctx->workRight[0];
208   FTd  = ctx->workRight[1];
209   PetscCall(MatMult(ctx->W,x,FTFx));
210   PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd));
211   PetscCall(VecWAXPY(V, -1., FTd, FTFx));
212   PetscFunctionReturn(0);
213 }
214 
215 /* returns FTF */
216 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
217 {
218   UserCtx        ctx = (UserCtx) _ctx;
219 
220   PetscFunctionBegin;
221   if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
222   if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN));
223   PetscFunctionReturn(0);
224 }
225 
226 /* computes augment Lagrangian objective (with scaled dual):
227  * 0.5 * ||F x - d||^2  + 0.5 * mu ||x - z + u||^2 */
228 static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx)
229 {
230   UserCtx        ctx = (UserCtx) _ctx;
231   PetscReal      mu, workNorm, misfit;
232   Vec            z, u, temp;
233 
234   PetscFunctionBegin;
235   mu   = ctx->mu;
236   z    = ctx->workRight[5];
237   u    = ctx->workRight[6];
238   temp = ctx->workRight[10];
239   /* misfit = f(x) */
240   PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx));
241   PetscCall(VecCopy(x,temp));
242   /* temp = x - z + u */
243   PetscCall(VecAXPBYPCZ(temp,-1.,1.,1.,z,u));
244   /* workNorm = ||x - z + u||^2 */
245   PetscCall(VecDot(temp, temp, &workNorm));
246   /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */
247   *J = misfit + 0.5 * mu * workNorm;
248   PetscFunctionReturn(0);
249 }
250 
251 /* computes FTFx - FTd  mu*(x - z + u) */
252 static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx)
253 {
254   UserCtx        ctx = (UserCtx) _ctx;
255   PetscReal      mu;
256   Vec            z, u, temp;
257 
258   PetscFunctionBegin;
259   mu   = ctx->mu;
260   z    = ctx->workRight[5];
261   u    = ctx->workRight[6];
262   temp = ctx->workRight[10];
263   PetscCall(GradientMisfit(tao, x, V, _ctx));
264   PetscCall(VecCopy(x, temp));
265   /* temp = x - z + u */
266   PetscCall(VecAXPBYPCZ(temp,-1.,1.,1.,z,u));
267   /* V =  FTFx - FTd  mu*(x - z + u) */
268   PetscCall(VecAXPY(V, mu, temp));
269   PetscFunctionReturn(0);
270 }
271 
272 /* returns FTF + diag(mu) */
273 static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
274 {
275   UserCtx        ctx = (UserCtx) _ctx;
276 
277   PetscFunctionBegin;
278   PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
279   PetscCall(MatShift(H, ctx->mu));
280   if (Hpre != H) {
281     PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
282   }
283   PetscFunctionReturn(0);
284 }
285 
286 /* computes || x ||_p (mult by 0.5 in case of NORM_2) */
287 static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx)
288 {
289   UserCtx        ctx = (UserCtx) _ctx;
290   PetscReal      norm;
291 
292   PetscFunctionBegin;
293   *J = 0;
294   PetscCall(VecNorm (x, ctx->p, &norm));
295   if (ctx->p == NORM_2) norm = 0.5 * norm * norm;
296   *J = ctx->alpha * norm;
297   PetscFunctionReturn(0);
298 }
299 
300 /* NORM_2 Case: return x
301  * NORM_1 Case: x/(|x| + eps)
302  * Else: TODO */
303 static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx)
304 {
305   UserCtx        ctx = (UserCtx) _ctx;
306   PetscReal      eps = ctx->eps;
307 
308   PetscFunctionBegin;
309   if (ctx->p == NORM_2) {
310     PetscCall(VecCopy(x, V));
311   } else if (ctx->p == NORM_1) {
312     PetscCall(VecCopy(x, ctx->workRight[1]));
313     PetscCall(VecAbs(ctx->workRight[1]));
314     PetscCall(VecShift(ctx->workRight[1], eps));
315     PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1]));
316   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
317   PetscFunctionReturn(0);
318 }
319 
320 /* NORM_2 Case: returns diag(mu)
321  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)))  */
322 static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
323 {
324   UserCtx        ctx = (UserCtx) _ctx;
325   PetscReal      eps = ctx->eps;
326   Vec            copy1,copy2,copy3;
327 
328   PetscFunctionBegin;
329   if (ctx->p == NORM_2) {
330     /* Identity matrix scaled by mu */
331     PetscCall(MatZeroEntries(H));
332     PetscCall(MatShift(H,ctx->mu));
333     if (Hpre != H) {
334       PetscCall(MatZeroEntries(Hpre));
335       PetscCall(MatShift(Hpre,ctx->mu));
336     }
337   } else if (ctx->p == NORM_1) {
338     /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */
339     copy1 = ctx->workRight[1];
340     copy2 = ctx->workRight[2];
341     copy3 = ctx->workRight[3];
342     /* copy1 : 1/sqrt(x_i^2 + eps) */
343     PetscCall(VecCopy(x, copy1));
344     PetscCall(VecPow(copy1,2));
345     PetscCall(VecShift(copy1, eps));
346     PetscCall(VecSqrtAbs(copy1));
347     PetscCall(VecReciprocal(copy1));
348     /* copy2:  x_i^2.*/
349     PetscCall(VecCopy(x,copy2));
350     PetscCall(VecPow(copy2,2));
351     /* copy3: abs(x_i^2 + eps) */
352     PetscCall(VecCopy(x,copy3));
353     PetscCall(VecPow(copy3,2));
354     PetscCall(VecShift(copy3, eps));
355     PetscCall(VecAbs(copy3));
356     /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */
357     PetscCall(VecPointwiseDivide(copy2, copy2,copy3));
358     PetscCall(VecScale(copy2, -1.));
359     PetscCall(VecShift(copy2, 1.));
360     PetscCall(VecAXPY(copy1,1.,copy2));
361     PetscCall(VecScale(copy1, ctx->mu));
362     PetscCall(MatZeroEntries(H));
363     PetscCall(MatDiagonalSet(H, copy1,INSERT_VALUES));
364     if (Hpre != H) {
365       PetscCall(MatZeroEntries(Hpre));
366       PetscCall(MatDiagonalSet(Hpre, copy1,INSERT_VALUES));
367     }
368   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
369   PetscFunctionReturn(0);
370 }
371 
372 /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2
373  * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */
374 static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx)
375 {
376   UserCtx        ctx = (UserCtx) _ctx;
377   PetscReal      mu, workNorm, reg;
378   Vec            x, u, temp;
379 
380   PetscFunctionBegin;
381   mu   = ctx->mu;
382   x    = ctx->workRight[4];
383   u    = ctx->workRight[6];
384   temp = ctx->workRight[10];
385   PetscCall(ObjectiveRegularization(tao, z, &reg, _ctx));
386   PetscCall(VecCopy(z,temp));
387   /* temp = x + u -z */
388   PetscCall(VecAXPBYPCZ(temp,1.,1.,-1.,x,u));
389   /* workNorm = ||x + u - z ||^2 */
390   PetscCall(VecDot(temp, temp, &workNorm));
391   *J   = reg + 0.5 * mu * workNorm;
392   PetscFunctionReturn(0);
393 }
394 
395 /* NORM_2 Case: x - mu*(x + u - z)
396  * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z)
397  * Else: TODO */
398 static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx)
399 {
400   UserCtx        ctx = (UserCtx) _ctx;
401   PetscReal      mu;
402   Vec            x, u, temp;
403 
404   PetscFunctionBegin;
405   mu   = ctx->mu;
406   x    = ctx->workRight[4];
407   u    = ctx->workRight[6];
408   temp = ctx->workRight[10];
409   PetscCall(GradientRegularization(tao, z, V, _ctx));
410   PetscCall(VecCopy(z, temp));
411   /* temp = x + u -z */
412   PetscCall(VecAXPBYPCZ(temp,1.,1.,-1.,x,u));
413   PetscCall(VecAXPY(V, -mu, temp));
414   PetscFunctionReturn(0);
415 }
416 
417 /* NORM_2 Case: returns diag(mu)
418  * NORM_1 Case: FTF + diag(mu) */
419 static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
420 {
421   UserCtx        ctx = (UserCtx) _ctx;
422 
423   PetscFunctionBegin;
424   if (ctx->p == NORM_2) {
425     /* Identity matrix scaled by mu */
426     PetscCall(MatZeroEntries(H));
427     PetscCall(MatShift(H,ctx->mu));
428     if (Hpre != H) {
429       PetscCall(MatZeroEntries(Hpre));
430       PetscCall(MatShift(Hpre,ctx->mu));
431     }
432   } else if (ctx->p == NORM_1) {
433     PetscCall(HessianMisfit(tao, x, H, Hpre, (void*) ctx));
434     PetscCall(MatShift(H, ctx->mu));
435     if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu));
436   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
437   PetscFunctionReturn(0);
438 }
439 
440 /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p
441 *  NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */
442 static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, void *ctx)
443 {
444   PetscReal      Jm, Jr;
445 
446   PetscFunctionBegin;
447   PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx));
448   PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx));
449   *J   = Jm + Jr;
450   PetscFunctionReturn(0);
451 }
452 
453 /* NORM_2 Case: FTFx - FTd + x
454  * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */
455 static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, void *ctx)
456 {
457   UserCtx        cntx = (UserCtx) ctx;
458 
459   PetscFunctionBegin;
460   PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx));
461   PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx));
462   PetscCall(VecWAXPY(V,1,cntx->workRight[2],cntx->workRight[3]));
463   PetscFunctionReturn(0);
464 }
465 
466 /* NORM_2 Case: diag(mu) + FTF
467  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF  */
468 static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx)
469 {
470   Mat            tempH;
471 
472   PetscFunctionBegin;
473   PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH));
474   PetscCall(HessianMisfit(tao, x, H, H, ctx));
475   PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx));
476   PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN));
477   if (Hpre != H) {
478     PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
479   }
480   PetscCall(MatDestroy(&tempH));
481   PetscFunctionReturn(0);
482 }
483 
484 static PetscErrorCode TaoSolveADMM(UserCtx ctx,  Vec x)
485 {
486   PetscInt       i;
487   PetscReal      u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm;
488   Tao            tao1,tao2;
489   Vec            xk,z,u,diff,zold,zdiff,temp;
490   PetscReal      mu;
491 
492   PetscFunctionBegin;
493   xk    = ctx->workRight[4];
494   z     = ctx->workRight[5];
495   u     = ctx->workRight[6];
496   diff  = ctx->workRight[7];
497   zold  = ctx->workRight[8];
498   zdiff = ctx->workRight[9];
499   temp  = ctx->workRight[11];
500   mu    = ctx->mu;
501   PetscCall(VecSet(u, 0.));
502   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1));
503   PetscCall(TaoSetType(tao1,TAONLS));
504   PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void*) ctx));
505   PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void*) ctx));
506   PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void*) ctx));
507   PetscCall(VecSet(xk, 0.));
508   PetscCall(TaoSetSolution(tao1, xk));
509   PetscCall(TaoSetOptionsPrefix(tao1, "misfit_"));
510   PetscCall(TaoSetFromOptions(tao1));
511   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2));
512   if (ctx->p == NORM_2) {
513     PetscCall(TaoSetType(tao2,TAONLS));
514     PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void*) ctx));
515     PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void*) ctx));
516     PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void*) ctx));
517   }
518   PetscCall(VecSet(z, 0.));
519   PetscCall(TaoSetSolution(tao2, z));
520   PetscCall(TaoSetOptionsPrefix(tao2, "reg_"));
521   PetscCall(TaoSetFromOptions(tao2));
522 
523   for (i=0; i<ctx->iter; i++) {
524     PetscCall(VecCopy(z,zold));
525     PetscCall(TaoSolve(tao1)); /* Updates xk */
526     if (ctx->p == NORM_1) {
527       PetscCall(VecWAXPY(temp,1.,xk,u));
528       PetscCall(TaoSoftThreshold(temp,-ctx->alpha/mu,ctx->alpha/mu,z));
529     } else {
530       PetscCall(TaoSolve(tao2)); /* Update zk */
531     }
532     /* u = u + xk -z */
533     PetscCall(VecAXPBYPCZ(u,1.,-1.,1.,xk,z));
534     /* r_norm : norm(x-z) */
535     PetscCall(VecWAXPY(diff,-1.,z,xk));
536     PetscCall(VecNorm(diff,NORM_2,&r_norm));
537     /* s_norm : norm(-mu(z-zold)) */
538     PetscCall(VecWAXPY(zdiff, -1.,zold,z));
539     PetscCall(VecNorm(zdiff,NORM_2,&s_norm));
540     s_norm = s_norm * mu;
541     /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/
542     PetscCall(VecNorm(xk,NORM_2,&x_norm));
543     PetscCall(VecNorm(z,NORM_2,&z_norm));
544     primal = PetscSqrtReal(ctx->n)*ctx->abstol + ctx->reltol*PetscMax(x_norm,z_norm);
545     /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/
546     PetscCall(VecNorm(u,NORM_2,&u_norm));
547     dual = PetscSqrtReal(ctx->n)*ctx->abstol + ctx->reltol*u_norm*mu;
548     PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1),"Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double) r_norm, (double) s_norm));
549     if (r_norm < primal && s_norm < dual) break;
550   }
551   PetscCall(VecCopy(xk, x));
552   PetscCall(TaoDestroy(&tao1));
553   PetscCall(TaoDestroy(&tao2));
554   PetscFunctionReturn(0);
555 }
556 
557 /* Second order Taylor remainder convergence test */
558 static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C)
559 {
560   PetscReal      h,J,temp;
561   PetscInt       i,j;
562   PetscInt       numValues;
563   PetscReal      Jx,Jxhat_comp,Jxhat_pred;
564   PetscReal      *Js, *hs;
565   PetscReal      gdotdx;
566   PetscReal      minrate = PETSC_MAX_REAL;
567   MPI_Comm       comm = PetscObjectComm((PetscObject)x);
568   Vec            g, dx, xhat;
569 
570   PetscFunctionBegin;
571   PetscCall(VecDuplicate(x, &g));
572   PetscCall(VecDuplicate(x, &xhat));
573   /* choose a perturbation direction */
574   PetscCall(VecDuplicate(x, &dx));
575   PetscCall(VecSetRandom(dx,ctx->rctx));
576   /* evaluate objective at x: J(x) */
577   PetscCall(TaoComputeObjective(tao, x, &Jx));
578   /* evaluate gradient at x, save in vector g */
579   PetscCall(TaoComputeGradient(tao, x, g));
580   PetscCall(VecDot(g, dx, &gdotdx));
581 
582   for (numValues=0, h=ctx->hStart; h>=ctx->hMin; h*=ctx->hFactor) numValues++;
583   PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs));
584   for (i=0, h=ctx->hStart; h>=ctx->hMin; h*=ctx->hFactor, i++) {
585     PetscCall(VecWAXPY(xhat, h, dx, x));
586     PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp));
587     /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */
588     Jxhat_pred = Jx + h * gdotdx;
589     /* Vector to dJdm scalar? Dot?*/
590     J     = PetscAbsReal(Jxhat_comp - Jxhat_pred);
591     PetscCall(PetscPrintf (comm, "J(xhat): %g, predicted: %g, diff %g\n", (double) Jxhat_comp,(double) Jxhat_pred, (double) J));
592     Js[i] = J;
593     hs[i] = h;
594   }
595   for (j=1; j<numValues; j++) {
596     temp    = PetscLogReal(Js[j]/Js[j-1]) / PetscLogReal (hs[j]/hs[j-1]);
597     PetscCall(PetscPrintf (comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j-1, (double) temp));
598     minrate = PetscMin(minrate, temp);
599   }
600   /* If O is not ~2, then the test is wrong */
601   PetscCall(PetscFree2(Js, hs));
602   *C   = minrate;
603   PetscCall(VecDestroy(&dx));
604   PetscCall(VecDestroy(&xhat));
605   PetscCall(VecDestroy(&g));
606   PetscFunctionReturn(0);
607 }
608 
609 int main(int argc, char ** argv)
610 {
611   UserCtx ctx;
612   Tao     tao;
613   Vec     x;
614   Mat     H;
615 
616   PetscCall(PetscInitialize(&argc, &argv, NULL,help));
617   PetscCall(PetscNew(&ctx));
618   PetscCall(ConfigureContext(ctx));
619   /* Define two functions that could pass as objectives to TaoSetObjective(): one
620    * for the misfit component, and one for the regularization component */
621   /* ObjectiveMisfit() and ObjectiveRegularization() */
622 
623   /* Define a single function that calls both components adds them together: the complete objective,
624    * in the absence of a Tao implementation that handles separability */
625   /* ObjectiveComplete() */
626   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
627   PetscCall(TaoSetType(tao,TAONM));
628   PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void*) ctx));
629   PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void*) ctx));
630   PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H));
631   PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void*) ctx));
632   PetscCall(MatCreateVecs(ctx->F, NULL, &x));
633   PetscCall(VecSet(x, 0.));
634   PetscCall(TaoSetSolution(tao, x));
635   PetscCall(TaoSetFromOptions(tao));
636   if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx,x));
637   else PetscCall(TaoSolve(tao));
638   /* examine solution */
639   PetscCall(VecViewFromOptions(x, NULL, "-view_sol"));
640   if (ctx->taylor) {
641     PetscReal rate;
642     PetscCall(TaylorTest(ctx, tao, x, &rate));
643   }
644   PetscCall(MatDestroy(&H));
645   PetscCall(TaoDestroy(&tao));
646   PetscCall(VecDestroy(&x));
647   PetscCall(DestroyContext(&ctx));
648   PetscCall(PetscFinalize());
649   return 0;
650 }
651 
652 /*TEST
653 
654   build:
655     requires: !complex
656 
657   test:
658     suffix: 0
659     args:
660 
661   test:
662     suffix: l1_1
663     args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1
664 
665   test:
666     suffix: hessian_1
667     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls
668 
669   test:
670     suffix: hessian_2
671     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls
672 
673   test:
674     suffix: nm_1
675     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50
676 
677   test:
678     suffix: nm_2
679     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50
680 
681   test:
682     suffix: lmvm_1
683     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40
684 
685   test:
686     suffix: lmvm_2
687     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15
688 
689   test:
690     suffix: soft_threshold_admm_1
691     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm
692 
693   test:
694     suffix: hessian_admm_1
695     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls
696 
697   test:
698     suffix: hessian_admm_2
699     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls
700 
701   test:
702     suffix: nm_admm_1
703     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm
704 
705   test:
706     suffix: nm_admm_2
707     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7
708 
709   test:
710     suffix: lmvm_admm_1
711     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
712 
713   test:
714     suffix: lmvm_admm_2
715     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
716 
717 TEST*/
718