xref: /petsc/src/tao/tutorials/ex4.c (revision 6a5217c03994f2d95bb2e6dbd8bed42381aeb015)
/* Help text shown by -help; read-only (PetscInitialize() takes it as const char[]) */
static const char help[] = "Simple example to test separable objective optimizers.\n";
2 
3 #include <petsc.h>
4 #include <petsctao.h>
5 #include <petscvec.h>
6 #include <petscmath.h>
7 
8 #define NWORKLEFT 4
9 #define NWORKRIGHT 12
10 
11 typedef struct _UserCtx
12 {
13   PetscInt    m;                     /* The row dimension of F */
14   PetscInt    n;                     /* The column dimension of F */
15   PetscInt    matops;                /* Matrix format. 0 for stencil, 1 for random */
16   PetscInt    iter;                  /* Numer of iterations for ADMM */
17   PetscReal   hStart;                /* Starting point for Taylor test */
18   PetscReal   hFactor;               /* Taylor test step factor */
19   PetscReal   hMin;                  /* Taylor test end goal */
20   PetscReal   alpha;                 /* regularization constant applied to || x ||_p */
21   PetscReal   eps;                   /* small constant for approximating gradient of || x ||_1 */
22   PetscReal   mu;                    /* the augmented Lagrangian term in ADMM */
23   PetscReal   abstol;
24   PetscReal   reltol;
25   Mat         F;                     /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */
26   Mat         W;                     /* Workspace matrix. ATA */
27   Mat         Hm;                    /* Hessian Misfit*/
28   Mat         Hr;                    /* Hessian Reg*/
29   Vec         d;                     /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */
30   Vec         workLeft[NWORKLEFT];   /* Workspace for temporary vec */
31   Vec         workRight[NWORKRIGHT]; /* Workspace for temporary vec */
32   NormType    p;
33   PetscRandom rctx;
34   PetscBool   taylor;                /* Flag to determine whether to run Taylor test or not */
35   PetscBool   use_admm;              /* Flag to determine whether to run Taylor test or not */
36 }* UserCtx;
37 
38 static PetscErrorCode CreateRHS(UserCtx ctx)
39 {
40   PetscFunctionBegin;
41   /* build the rhs d in ctx */
42   PetscCall(VecCreate(PETSC_COMM_WORLD,&(ctx->d)));
43   PetscCall(VecSetSizes(ctx->d,PETSC_DECIDE,ctx->m));
44   PetscCall(VecSetFromOptions(ctx->d));
45   PetscCall(VecSetRandom(ctx->d,ctx->rctx));
46   PetscFunctionReturn(0);
47 }
48 
49 static PetscErrorCode CreateMatrix(UserCtx ctx)
50 {
51   PetscInt       Istart,Iend,i,j,Ii,gridN,I_n, I_s, I_e, I_w;
52 #if defined(PETSC_USE_LOG)
53   PetscLogStage  stage;
54 #endif
55 
56   PetscFunctionBegin;
57   /* build the matrix F in ctx */
58   PetscCall(MatCreate(PETSC_COMM_WORLD, &(ctx->F)));
59   PetscCall(MatSetSizes(ctx->F,PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n));
60   PetscCall(MatSetType(ctx->F,MATAIJ)); /* TODO: Decide specific SetType other than dummy*/
61   PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/
62   PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL));
63   PetscCall(MatSetUp(ctx->F));
64   PetscCall(MatGetOwnershipRange(ctx->F,&Istart,&Iend));
65   PetscCall(PetscLogStageRegister("Assembly", &stage));
66   PetscCall(PetscLogStagePush(stage));
67 
68   /* Set matrix elements in  2-D five point stencil format. */
69   if (!(ctx->matops)) {
70     PetscCheck(ctx->m == ctx->n,PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square");
71     gridN = (PetscInt) PetscSqrtReal((PetscReal) ctx->m);
72     PetscCheck(gridN * gridN == ctx->m,PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square");
73     for (Ii=Istart; Ii<Iend; Ii++) {
74       i   = Ii / gridN; j = Ii % gridN;
75       I_n = i * gridN + j + 1;
76       if (j + 1 >= gridN) I_n = -1;
77       I_s = i * gridN + j - 1;
78       if (j - 1 < 0) I_s = -1;
79       I_e = (i + 1) * gridN + j;
80       if (i + 1 >= gridN) I_e = -1;
81       I_w = (i - 1) * gridN + j;
82       if (i - 1 < 0) I_w = -1;
83       PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES));
84       PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES));
85       PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES));
86       PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES));
87       PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES));
88     }
89   } else PetscCall(MatSetRandom(ctx->F, ctx->rctx));
90   PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY));
91   PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY));
92   PetscCall(PetscLogStagePop());
93   /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */
94   if (!(ctx->matops)) {
95     PetscCall(MatSetOption(ctx->F,MAT_SYMMETRIC,PETSC_TRUE));
96   }
97   PetscCall(MatTransposeMatMult(ctx->F,ctx->F, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(ctx->W)));
98   /* Setup Hessian Workspace in same shape as W */
99   PetscCall(MatDuplicate(ctx->W,MAT_DO_NOT_COPY_VALUES,&(ctx->Hm)));
100   PetscCall(MatDuplicate(ctx->W,MAT_DO_NOT_COPY_VALUES,&(ctx->Hr)));
101   PetscFunctionReturn(0);
102 }
103 
104 static PetscErrorCode SetupWorkspace(UserCtx ctx)
105 {
106   PetscInt       i;
107 
108   PetscFunctionBegin;
109   PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0]));
110   for (i=1; i<NWORKLEFT; i++) {
111     PetscCall(VecDuplicate(ctx->workLeft[0], &(ctx->workLeft[i])));
112   }
113   for (i=1; i<NWORKRIGHT; i++) {
114     PetscCall(VecDuplicate(ctx->workRight[0], &(ctx->workRight[i])));
115   }
116   PetscFunctionReturn(0);
117 }
118 
119 static PetscErrorCode ConfigureContext(UserCtx ctx)
120 {
121   PetscErrorCode ierr;
122 
123   PetscFunctionBegin;
124   ctx->m        = 16;
125   ctx->n        = 16;
126   ctx->eps      = 1.e-3;
127   ctx->abstol   = 1.e-4;
128   ctx->reltol   = 1.e-2;
129   ctx->hStart   = 1.;
130   ctx->hMin     = 1.e-3;
131   ctx->hFactor  = 0.5;
132   ctx->alpha    = 1.;
133   ctx->mu       = 1.0;
134   ctx->matops   = 0;
135   ctx->iter     = 10;
136   ctx->p        = NORM_2;
137   ctx->taylor   = PETSC_TRUE;
138   ctx->use_admm = PETSC_FALSE;
139   ierr = PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "ex4.c");PetscCall(ierr);
140   PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &(ctx->m), NULL));
141   PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &(ctx->n), NULL));
142   PetscCall(PetscOptionsInt("-matrix_format","Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &(ctx->matops), NULL));
143   PetscCall(PetscOptionsInt("-iter","Iteration number ADMM", "ex4.c", ctx->iter, &(ctx->iter), NULL));
144   PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &(ctx->alpha), NULL));
145   PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &(ctx->eps), NULL));
146   PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &(ctx->mu), NULL));
147   PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &(ctx->hStart), NULL));
148   PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &(ctx->hFactor), NULL));
149   PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &(ctx->hMin), NULL));
150   PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &(ctx->abstol), NULL));
151   PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &(ctx->reltol), NULL));
152   PetscCall(PetscOptionsBool("-taylor","Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &(ctx->taylor), NULL));
153   PetscCall(PetscOptionsBool("-use_admm","Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &(ctx->use_admm), NULL));
154   PetscCall(PetscOptionsEnum("-p","Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *) &(ctx->p), NULL));
155   ierr = PetscOptionsEnd();PetscCall(ierr);
156   /* Creating random ctx */
157   PetscCall(PetscRandomCreate(PETSC_COMM_WORLD,&(ctx->rctx)));
158   PetscCall(PetscRandomSetFromOptions(ctx->rctx));
159   PetscCall(CreateMatrix(ctx));
160   PetscCall(CreateRHS(ctx));
161   PetscCall(SetupWorkspace(ctx));
162   PetscFunctionReturn(0);
163 }
164 
165 static PetscErrorCode DestroyContext(UserCtx *ctx)
166 {
167   PetscInt       i;
168 
169   PetscFunctionBegin;
170   PetscCall(MatDestroy(&((*ctx)->F)));
171   PetscCall(MatDestroy(&((*ctx)->W)));
172   PetscCall(MatDestroy(&((*ctx)->Hm)));
173   PetscCall(MatDestroy(&((*ctx)->Hr)));
174   PetscCall(VecDestroy(&((*ctx)->d)));
175   for (i=0; i<NWORKLEFT; i++) {
176     PetscCall(VecDestroy(&((*ctx)->workLeft[i])));
177   }
178   for (i=0; i<NWORKRIGHT; i++) {
179     PetscCall(VecDestroy(&((*ctx)->workRight[i])));
180   }
181   PetscCall(PetscRandomDestroy(&((*ctx)->rctx)));
182   PetscCall(PetscFree(*ctx));
183   PetscFunctionReturn(0);
184 }
185 
186 /* compute (1/2) * ||F x - d||^2 */
187 static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx)
188 {
189   UserCtx        ctx = (UserCtx) _ctx;
190   Vec            y;
191 
192   PetscFunctionBegin;
193   y    = ctx->workLeft[0];
194   PetscCall(MatMult(ctx->F, x, y));
195   PetscCall(VecAXPY(y, -1., ctx->d));
196   PetscCall(VecDot(y, y, J));
197   *J  *= 0.5;
198   PetscFunctionReturn(0);
199 }
200 
201 /* compute V = FTFx - FTd */
202 static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx)
203 {
204   UserCtx        ctx = (UserCtx) _ctx;
205   Vec            FTFx, FTd;
206 
207   PetscFunctionBegin;
208   /* work1 is A^T Ax, work2 is Ab, W is A^T A*/
209   FTFx = ctx->workRight[0];
210   FTd  = ctx->workRight[1];
211   PetscCall(MatMult(ctx->W,x,FTFx));
212   PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd));
213   PetscCall(VecWAXPY(V, -1., FTd, FTFx));
214   PetscFunctionReturn(0);
215 }
216 
217 /* returns FTF */
218 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
219 {
220   UserCtx        ctx = (UserCtx) _ctx;
221 
222   PetscFunctionBegin;
223   if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
224   if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN));
225   PetscFunctionReturn(0);
226 }
227 
228 /* computes augment Lagrangian objective (with scaled dual):
229  * 0.5 * ||F x - d||^2  + 0.5 * mu ||x - z + u||^2 */
230 static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx)
231 {
232   UserCtx        ctx = (UserCtx) _ctx;
233   PetscReal      mu, workNorm, misfit;
234   Vec            z, u, temp;
235 
236   PetscFunctionBegin;
237   mu   = ctx->mu;
238   z    = ctx->workRight[5];
239   u    = ctx->workRight[6];
240   temp = ctx->workRight[10];
241   /* misfit = f(x) */
242   PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx));
243   PetscCall(VecCopy(x,temp));
244   /* temp = x - z + u */
245   PetscCall(VecAXPBYPCZ(temp,-1.,1.,1.,z,u));
246   /* workNorm = ||x - z + u||^2 */
247   PetscCall(VecDot(temp, temp, &workNorm));
248   /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */
249   *J = misfit + 0.5 * mu * workNorm;
250   PetscFunctionReturn(0);
251 }
252 
253 /* computes FTFx - FTd  mu*(x - z + u) */
254 static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx)
255 {
256   UserCtx        ctx = (UserCtx) _ctx;
257   PetscReal      mu;
258   Vec            z, u, temp;
259 
260   PetscFunctionBegin;
261   mu   = ctx->mu;
262   z    = ctx->workRight[5];
263   u    = ctx->workRight[6];
264   temp = ctx->workRight[10];
265   PetscCall(GradientMisfit(tao, x, V, _ctx));
266   PetscCall(VecCopy(x, temp));
267   /* temp = x - z + u */
268   PetscCall(VecAXPBYPCZ(temp,-1.,1.,1.,z,u));
269   /* V =  FTFx - FTd  mu*(x - z + u) */
270   PetscCall(VecAXPY(V, mu, temp));
271   PetscFunctionReturn(0);
272 }
273 
274 /* returns FTF + diag(mu) */
275 static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
276 {
277   UserCtx        ctx = (UserCtx) _ctx;
278 
279   PetscFunctionBegin;
280   PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
281   PetscCall(MatShift(H, ctx->mu));
282   if (Hpre != H) {
283     PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
284   }
285   PetscFunctionReturn(0);
286 }
287 
288 /* computes || x ||_p (mult by 0.5 in case of NORM_2) */
289 static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx)
290 {
291   UserCtx        ctx = (UserCtx) _ctx;
292   PetscReal      norm;
293 
294   PetscFunctionBegin;
295   *J = 0;
296   PetscCall(VecNorm (x, ctx->p, &norm));
297   if (ctx->p == NORM_2) norm = 0.5 * norm * norm;
298   *J = ctx->alpha * norm;
299   PetscFunctionReturn(0);
300 }
301 
302 /* NORM_2 Case: return x
303  * NORM_1 Case: x/(|x| + eps)
304  * Else: TODO */
305 static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx)
306 {
307   UserCtx        ctx = (UserCtx) _ctx;
308   PetscReal      eps = ctx->eps;
309 
310   PetscFunctionBegin;
311   if (ctx->p == NORM_2) {
312     PetscCall(VecCopy(x, V));
313   } else if (ctx->p == NORM_1) {
314     PetscCall(VecCopy(x, ctx->workRight[1]));
315     PetscCall(VecAbs(ctx->workRight[1]));
316     PetscCall(VecShift(ctx->workRight[1], eps));
317     PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1]));
318   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
319   PetscFunctionReturn(0);
320 }
321 
322 /* NORM_2 Case: returns diag(mu)
323  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)))  */
324 static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
325 {
326   UserCtx        ctx = (UserCtx) _ctx;
327   PetscReal      eps = ctx->eps;
328   Vec            copy1,copy2,copy3;
329 
330   PetscFunctionBegin;
331   if (ctx->p == NORM_2) {
332     /* Identity matrix scaled by mu */
333     PetscCall(MatZeroEntries(H));
334     PetscCall(MatShift(H,ctx->mu));
335     if (Hpre != H) {
336       PetscCall(MatZeroEntries(Hpre));
337       PetscCall(MatShift(Hpre,ctx->mu));
338     }
339   } else if (ctx->p == NORM_1) {
340     /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */
341     copy1 = ctx->workRight[1];
342     copy2 = ctx->workRight[2];
343     copy3 = ctx->workRight[3];
344     /* copy1 : 1/sqrt(x_i^2 + eps) */
345     PetscCall(VecCopy(x, copy1));
346     PetscCall(VecPow(copy1,2));
347     PetscCall(VecShift(copy1, eps));
348     PetscCall(VecSqrtAbs(copy1));
349     PetscCall(VecReciprocal(copy1));
350     /* copy2:  x_i^2.*/
351     PetscCall(VecCopy(x,copy2));
352     PetscCall(VecPow(copy2,2));
353     /* copy3: abs(x_i^2 + eps) */
354     PetscCall(VecCopy(x,copy3));
355     PetscCall(VecPow(copy3,2));
356     PetscCall(VecShift(copy3, eps));
357     PetscCall(VecAbs(copy3));
358     /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */
359     PetscCall(VecPointwiseDivide(copy2, copy2,copy3));
360     PetscCall(VecScale(copy2, -1.));
361     PetscCall(VecShift(copy2, 1.));
362     PetscCall(VecAXPY(copy1,1.,copy2));
363     PetscCall(VecScale(copy1, ctx->mu));
364     PetscCall(MatZeroEntries(H));
365     PetscCall(MatDiagonalSet(H, copy1,INSERT_VALUES));
366     if (Hpre != H) {
367       PetscCall(MatZeroEntries(Hpre));
368       PetscCall(MatDiagonalSet(Hpre, copy1,INSERT_VALUES));
369     }
370   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
371   PetscFunctionReturn(0);
372 }
373 
374 /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2
375  * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */
376 static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx)
377 {
378   UserCtx        ctx = (UserCtx) _ctx;
379   PetscReal      mu, workNorm, reg;
380   Vec            x, u, temp;
381 
382   PetscFunctionBegin;
383   mu   = ctx->mu;
384   x    = ctx->workRight[4];
385   u    = ctx->workRight[6];
386   temp = ctx->workRight[10];
387   PetscCall(ObjectiveRegularization(tao, z, &reg, _ctx));
388   PetscCall(VecCopy(z,temp));
389   /* temp = x + u -z */
390   PetscCall(VecAXPBYPCZ(temp,1.,1.,-1.,x,u));
391   /* workNorm = ||x + u - z ||^2 */
392   PetscCall(VecDot(temp, temp, &workNorm));
393   *J   = reg + 0.5 * mu * workNorm;
394   PetscFunctionReturn(0);
395 }
396 
397 /* NORM_2 Case: x - mu*(x + u - z)
398  * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z)
399  * Else: TODO */
400 static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx)
401 {
402   UserCtx        ctx = (UserCtx) _ctx;
403   PetscReal      mu;
404   Vec            x, u, temp;
405 
406   PetscFunctionBegin;
407   mu   = ctx->mu;
408   x    = ctx->workRight[4];
409   u    = ctx->workRight[6];
410   temp = ctx->workRight[10];
411   PetscCall(GradientRegularization(tao, z, V, _ctx));
412   PetscCall(VecCopy(z, temp));
413   /* temp = x + u -z */
414   PetscCall(VecAXPBYPCZ(temp,1.,1.,-1.,x,u));
415   PetscCall(VecAXPY(V, -mu, temp));
416   PetscFunctionReturn(0);
417 }
418 
419 /* NORM_2 Case: returns diag(mu)
420  * NORM_1 Case: FTF + diag(mu) */
421 static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
422 {
423   UserCtx        ctx = (UserCtx) _ctx;
424 
425   PetscFunctionBegin;
426   if (ctx->p == NORM_2) {
427     /* Identity matrix scaled by mu */
428     PetscCall(MatZeroEntries(H));
429     PetscCall(MatShift(H,ctx->mu));
430     if (Hpre != H) {
431       PetscCall(MatZeroEntries(Hpre));
432       PetscCall(MatShift(Hpre,ctx->mu));
433     }
434   } else if (ctx->p == NORM_1) {
435     PetscCall(HessianMisfit(tao, x, H, Hpre, (void*) ctx));
436     PetscCall(MatShift(H, ctx->mu));
437     if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu));
438   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
439   PetscFunctionReturn(0);
440 }
441 
442 /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p
443 *  NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */
444 static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, void *ctx)
445 {
446   PetscReal      Jm, Jr;
447 
448   PetscFunctionBegin;
449   PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx));
450   PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx));
451   *J   = Jm + Jr;
452   PetscFunctionReturn(0);
453 }
454 
455 /* NORM_2 Case: FTFx - FTd + x
456  * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */
457 static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, void *ctx)
458 {
459   UserCtx        cntx = (UserCtx) ctx;
460 
461   PetscFunctionBegin;
462   PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx));
463   PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx));
464   PetscCall(VecWAXPY(V,1,cntx->workRight[2],cntx->workRight[3]));
465   PetscFunctionReturn(0);
466 }
467 
468 /* NORM_2 Case: diag(mu) + FTF
469  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF  */
470 static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx)
471 {
472   Mat            tempH;
473 
474   PetscFunctionBegin;
475   PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH));
476   PetscCall(HessianMisfit(tao, x, H, H, ctx));
477   PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx));
478   PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN));
479   if (Hpre != H) {
480     PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
481   }
482   PetscCall(MatDestroy(&tempH));
483   PetscFunctionReturn(0);
484 }
485 
486 static PetscErrorCode TaoSolveADMM(UserCtx ctx,  Vec x)
487 {
488   PetscInt       i;
489   PetscReal      u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm;
490   Tao            tao1,tao2;
491   Vec            xk,z,u,diff,zold,zdiff,temp;
492   PetscReal      mu;
493 
494   PetscFunctionBegin;
495   xk    = ctx->workRight[4];
496   z     = ctx->workRight[5];
497   u     = ctx->workRight[6];
498   diff  = ctx->workRight[7];
499   zold  = ctx->workRight[8];
500   zdiff = ctx->workRight[9];
501   temp  = ctx->workRight[11];
502   mu    = ctx->mu;
503   PetscCall(VecSet(u, 0.));
504   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1));
505   PetscCall(TaoSetType(tao1,TAONLS));
506   PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void*) ctx));
507   PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void*) ctx));
508   PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void*) ctx));
509   PetscCall(VecSet(xk, 0.));
510   PetscCall(TaoSetSolution(tao1, xk));
511   PetscCall(TaoSetOptionsPrefix(tao1, "misfit_"));
512   PetscCall(TaoSetFromOptions(tao1));
513   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2));
514   if (ctx->p == NORM_2) {
515     PetscCall(TaoSetType(tao2,TAONLS));
516     PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void*) ctx));
517     PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void*) ctx));
518     PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void*) ctx));
519   }
520   PetscCall(VecSet(z, 0.));
521   PetscCall(TaoSetSolution(tao2, z));
522   PetscCall(TaoSetOptionsPrefix(tao2, "reg_"));
523   PetscCall(TaoSetFromOptions(tao2));
524 
525   for (i=0; i<ctx->iter; i++) {
526     PetscCall(VecCopy(z,zold));
527     PetscCall(TaoSolve(tao1)); /* Updates xk */
528     if (ctx->p == NORM_1) {
529       PetscCall(VecWAXPY(temp,1.,xk,u));
530       PetscCall(TaoSoftThreshold(temp,-ctx->alpha/mu,ctx->alpha/mu,z));
531     } else {
532       PetscCall(TaoSolve(tao2)); /* Update zk */
533     }
534     /* u = u + xk -z */
535     PetscCall(VecAXPBYPCZ(u,1.,-1.,1.,xk,z));
536     /* r_norm : norm(x-z) */
537     PetscCall(VecWAXPY(diff,-1.,z,xk));
538     PetscCall(VecNorm(diff,NORM_2,&r_norm));
539     /* s_norm : norm(-mu(z-zold)) */
540     PetscCall(VecWAXPY(zdiff, -1.,zold,z));
541     PetscCall(VecNorm(zdiff,NORM_2,&s_norm));
542     s_norm = s_norm * mu;
543     /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/
544     PetscCall(VecNorm(xk,NORM_2,&x_norm));
545     PetscCall(VecNorm(z,NORM_2,&z_norm));
546     primal = PetscSqrtReal(ctx->n)*ctx->abstol + ctx->reltol*PetscMax(x_norm,z_norm);
547     /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/
548     PetscCall(VecNorm(u,NORM_2,&u_norm));
549     dual = PetscSqrtReal(ctx->n)*ctx->abstol + ctx->reltol*u_norm*mu;
550     PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1),"Iter %D : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double) r_norm, (double) s_norm));
551     if (r_norm < primal && s_norm < dual) break;
552   }
553   PetscCall(VecCopy(xk, x));
554   PetscCall(TaoDestroy(&tao1));
555   PetscCall(TaoDestroy(&tao2));
556   PetscFunctionReturn(0);
557 }
558 
559 /* Second order Taylor remainder convergence test */
560 static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C)
561 {
562   PetscReal      h,J,temp;
563   PetscInt       i,j;
564   PetscInt       numValues;
565   PetscReal      Jx,Jxhat_comp,Jxhat_pred;
566   PetscReal      *Js, *hs;
567   PetscReal      gdotdx;
568   PetscReal      minrate = PETSC_MAX_REAL;
569   MPI_Comm       comm = PetscObjectComm((PetscObject)x);
570   Vec            g, dx, xhat;
571 
572   PetscFunctionBegin;
573   PetscCall(VecDuplicate(x, &g));
574   PetscCall(VecDuplicate(x, &xhat));
575   /* choose a perturbation direction */
576   PetscCall(VecDuplicate(x, &dx));
577   PetscCall(VecSetRandom(dx,ctx->rctx));
578   /* evaluate objective at x: J(x) */
579   PetscCall(TaoComputeObjective(tao, x, &Jx));
580   /* evaluate gradient at x, save in vector g */
581   PetscCall(TaoComputeGradient(tao, x, g));
582   PetscCall(VecDot(g, dx, &gdotdx));
583 
584   for (numValues=0, h=ctx->hStart; h>=ctx->hMin; h*=ctx->hFactor) numValues++;
585   PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs));
586   for (i=0, h=ctx->hStart; h>=ctx->hMin; h*=ctx->hFactor, i++) {
587     PetscCall(VecWAXPY(xhat, h, dx, x));
588     PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp));
589     /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */
590     Jxhat_pred = Jx + h * gdotdx;
591     /* Vector to dJdm scalar? Dot?*/
592     J     = PetscAbsReal(Jxhat_comp - Jxhat_pred);
593     PetscCall(PetscPrintf (comm, "J(xhat): %g, predicted: %g, diff %g\n", (double) Jxhat_comp,(double) Jxhat_pred, (double) J));
594     Js[i] = J;
595     hs[i] = h;
596   }
597   for (j=1; j<numValues; j++) {
598     temp    = PetscLogReal(Js[j]/Js[j-1]) / PetscLogReal (hs[j]/hs[j-1]);
599     PetscCall(PetscPrintf (comm, "Convergence rate step %D: %g\n", j-1, (double) temp));
600     minrate = PetscMin(minrate, temp);
601   }
602   /* If O is not ~2, then the test is wrong */
603   PetscCall(PetscFree2(Js, hs));
604   *C   = minrate;
605   PetscCall(VecDestroy(&dx));
606   PetscCall(VecDestroy(&xhat));
607   PetscCall(VecDestroy(&g));
608   PetscFunctionReturn(0);
609 }
610 
611 int main(int argc, char ** argv)
612 {
613   UserCtx ctx;
614   Tao     tao;
615   Vec     x;
616   Mat     H;
617 
618   PetscCall(PetscInitialize(&argc, &argv, NULL,help));
619   PetscCall(PetscNew(&ctx));
620   PetscCall(ConfigureContext(ctx));
621   /* Define two functions that could pass as objectives to TaoSetObjective(): one
622    * for the misfit component, and one for the regularization component */
623   /* ObjectiveMisfit() and ObjectiveRegularization() */
624 
625   /* Define a single function that calls both components adds them together: the complete objective,
626    * in the absence of a Tao implementation that handles separability */
627   /* ObjectiveComplete() */
628   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
629   PetscCall(TaoSetType(tao,TAONM));
630   PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void*) ctx));
631   PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void*) ctx));
632   PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H));
633   PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void*) ctx));
634   PetscCall(MatCreateVecs(ctx->F, NULL, &x));
635   PetscCall(VecSet(x, 0.));
636   PetscCall(TaoSetSolution(tao, x));
637   PetscCall(TaoSetFromOptions(tao));
638   if (ctx->use_admm) {
639     PetscCall(TaoSolveADMM(ctx,x));
640   } else PetscCall(TaoSolve(tao));
641   /* examine solution */
642   PetscCall(VecViewFromOptions(x, NULL, "-view_sol"));
643   if (ctx->taylor) {
644     PetscReal rate;
645     PetscCall(TaylorTest(ctx, tao, x, &rate));
646   }
647   PetscCall(MatDestroy(&H));
648   PetscCall(TaoDestroy(&tao));
649   PetscCall(VecDestroy(&x));
650   PetscCall(DestroyContext(&ctx));
651   PetscCall(PetscFinalize());
652   return 0;
653 }
654 
655 /*TEST
656 
657   build:
658     requires: !complex
659 
660   test:
661     suffix: 0
662     args:
663 
664   test:
665     suffix: l1_1
666     args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1
667 
668   test:
669     suffix: hessian_1
670     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls
671 
672   test:
673     suffix: hessian_2
674     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls
675 
676   test:
677     suffix: nm_1
678     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50
679 
680   test:
681     suffix: nm_2
682     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50
683 
684   test:
685     suffix: lmvm_1
686     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40
687 
688   test:
689     suffix: lmvm_2
690     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15
691 
692   test:
693     suffix: soft_threshold_admm_1
694     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm
695 
696   test:
697     suffix: hessian_admm_1
698     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls
699 
700   test:
701     suffix: hessian_admm_2
702     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls
703 
704   test:
705     suffix: nm_admm_1
706     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm
707 
708   test:
709     suffix: nm_admm_2
710     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7
711 
712   test:
713     suffix: lmvm_admm_1
714     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
715 
716   test:
717     suffix: lmvm_admm_2
718     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
719 
720 TEST*/
721