xref: /petsc/src/tao/tutorials/ex4.c (revision 9371c9d470a9602b6d10a8bf50c9b2280a79e45a)
/* Help text shown by -help (passed to PetscInitialize() in main()). */
static char help[] = "Simple example to test separable objective optimizers.\n";

#include <petsc.h>
#include <petsctao.h>
#include <petscvec.h>
#include <petscmath.h>

/* Number of work vectors in UserCtx.workLeft (intended for F's row space, e.g. F x) */
#define NWORKLEFT  4
/* Number of work vectors in UserCtx.workRight (intended for F's column space, e.g. x, F^T d) */
#define NWORKRIGHT 12
10 
/* Problem data, solver parameters and work storage shared by every callback in this example.
 * UserCtx is a pointer to this struct (allocated with PetscNew() in main()). */
typedef struct _UserCtx {
  PetscInt    m;       /* The row dimension of F */
  PetscInt    n;       /* The column dimension of F */
  PetscInt    matops;  /* Matrix format. 0 for stencil, any other value for random */
  PetscInt    iter;    /* Maximum number of ADMM iterations (the loop may stop earlier) */
  PetscReal   hStart;  /* Initial (largest) step size h of the Taylor test */
  PetscReal   hFactor; /* Taylor test step factor: h is multiplied by this after every step */
  PetscReal   hMin;    /* Taylor test stops once h falls below this value */
  PetscReal   alpha;   /* regularization constant applied to || x ||_p */
  PetscReal   eps;     /* small constant for approximating gradient of || x ||_1 */
  PetscReal   mu;      /* penalty parameter of the augmented Lagrangian term in ADMM */
  PetscReal   abstol;  /* ADMM absolute stopping tolerance (multiplied by sqrt(n)) */
  PetscReal   reltol;  /* ADMM relative stopping tolerance */
  Mat         F;                     /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */
  Mat         W;                     /* Workspace matrix F^T F (Hessian of the misfit term) */
  Mat         Hm;                    /* Hessian workspace of the ADMM misfit (x) subproblem */
  Mat         Hr;                    /* Hessian workspace of the ADMM regularization (z) subproblem */
  Vec         d;                     /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */
  Vec         workLeft[NWORKLEFT];   /* Work vectors in F's row space; workLeft[0] receives F x in ObjectiveMisfit() */
  Vec         workRight[NWORKRIGHT]; /* Work vectors in F's column space; workRight[4..6] hold ADMM x, z, u */
  NormType    p;                     /* Norm of the regularization term; only NORM_1 and NORM_2 are supported */
  PetscRandom rctx;                  /* Random source for d, a random F and the Taylor test direction */
  PetscBool   taylor;   /* Flag to determine whether to run Taylor test or not */
  PetscBool   use_admm; /* Flag to solve with the ADMM splitting (TaoSolveADMM) instead of a single Tao solve */
} * UserCtx;
36 
37 static PetscErrorCode CreateRHS(UserCtx ctx) {
38   PetscFunctionBegin;
39   /* build the rhs d in ctx */
40   PetscCall(VecCreate(PETSC_COMM_WORLD, &(ctx->d)));
41   PetscCall(VecSetSizes(ctx->d, PETSC_DECIDE, ctx->m));
42   PetscCall(VecSetFromOptions(ctx->d));
43   PetscCall(VecSetRandom(ctx->d, ctx->rctx));
44   PetscFunctionReturn(0);
45 }
46 
47 static PetscErrorCode CreateMatrix(UserCtx ctx) {
48   PetscInt Istart, Iend, i, j, Ii, gridN, I_n, I_s, I_e, I_w;
49 #if defined(PETSC_USE_LOG)
50   PetscLogStage stage;
51 #endif
52 
53   PetscFunctionBegin;
54   /* build the matrix F in ctx */
55   PetscCall(MatCreate(PETSC_COMM_WORLD, &(ctx->F)));
56   PetscCall(MatSetSizes(ctx->F, PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n));
57   PetscCall(MatSetType(ctx->F, MATAIJ));                          /* TODO: Decide specific SetType other than dummy*/
58   PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/
59   PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL));
60   PetscCall(MatSetUp(ctx->F));
61   PetscCall(MatGetOwnershipRange(ctx->F, &Istart, &Iend));
62   PetscCall(PetscLogStageRegister("Assembly", &stage));
63   PetscCall(PetscLogStagePush(stage));
64 
65   /* Set matrix elements in  2-D five point stencil format. */
66   if (!(ctx->matops)) {
67     PetscCheck(ctx->m == ctx->n, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square");
68     gridN = (PetscInt)PetscSqrtReal((PetscReal)ctx->m);
69     PetscCheck(gridN * gridN == ctx->m, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square");
70     for (Ii = Istart; Ii < Iend; Ii++) {
71       i   = Ii / gridN;
72       j   = Ii % gridN;
73       I_n = i * gridN + j + 1;
74       if (j + 1 >= gridN) I_n = -1;
75       I_s = i * gridN + j - 1;
76       if (j - 1 < 0) I_s = -1;
77       I_e = (i + 1) * gridN + j;
78       if (i + 1 >= gridN) I_e = -1;
79       I_w = (i - 1) * gridN + j;
80       if (i - 1 < 0) I_w = -1;
81       PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES));
82       PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES));
83       PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES));
84       PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES));
85       PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES));
86     }
87   } else PetscCall(MatSetRandom(ctx->F, ctx->rctx));
88   PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY));
89   PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY));
90   PetscCall(PetscLogStagePop());
91   /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */
92   if (!(ctx->matops)) { PetscCall(MatSetOption(ctx->F, MAT_SYMMETRIC, PETSC_TRUE)); }
93   PetscCall(MatTransposeMatMult(ctx->F, ctx->F, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(ctx->W)));
94   /* Setup Hessian Workspace in same shape as W */
95   PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &(ctx->Hm)));
96   PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &(ctx->Hr)));
97   PetscFunctionReturn(0);
98 }
99 
100 static PetscErrorCode SetupWorkspace(UserCtx ctx) {
101   PetscInt i;
102 
103   PetscFunctionBegin;
104   PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0]));
105   for (i = 1; i < NWORKLEFT; i++) { PetscCall(VecDuplicate(ctx->workLeft[0], &(ctx->workLeft[i]))); }
106   for (i = 1; i < NWORKRIGHT; i++) { PetscCall(VecDuplicate(ctx->workRight[0], &(ctx->workRight[i]))); }
107   PetscFunctionReturn(0);
108 }
109 
110 static PetscErrorCode ConfigureContext(UserCtx ctx) {
111   PetscFunctionBegin;
112   ctx->m        = 16;
113   ctx->n        = 16;
114   ctx->eps      = 1.e-3;
115   ctx->abstol   = 1.e-4;
116   ctx->reltol   = 1.e-2;
117   ctx->hStart   = 1.;
118   ctx->hMin     = 1.e-3;
119   ctx->hFactor  = 0.5;
120   ctx->alpha    = 1.;
121   ctx->mu       = 1.0;
122   ctx->matops   = 0;
123   ctx->iter     = 10;
124   ctx->p        = NORM_2;
125   ctx->taylor   = PETSC_TRUE;
126   ctx->use_admm = PETSC_FALSE;
127   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "ex4.c");
128   PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &(ctx->m), NULL));
129   PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &(ctx->n), NULL));
130   PetscCall(PetscOptionsInt("-matrix_format", "Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &(ctx->matops), NULL));
131   PetscCall(PetscOptionsInt("-iter", "Iteration number ADMM", "ex4.c", ctx->iter, &(ctx->iter), NULL));
132   PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &(ctx->alpha), NULL));
133   PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &(ctx->eps), NULL));
134   PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &(ctx->mu), NULL));
135   PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &(ctx->hStart), NULL));
136   PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &(ctx->hFactor), NULL));
137   PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &(ctx->hMin), NULL));
138   PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &(ctx->abstol), NULL));
139   PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &(ctx->reltol), NULL));
140   PetscCall(PetscOptionsBool("-taylor", "Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &(ctx->taylor), NULL));
141   PetscCall(PetscOptionsBool("-use_admm", "Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &(ctx->use_admm), NULL));
142   PetscCall(PetscOptionsEnum("-p", "Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *)&(ctx->p), NULL));
143   PetscOptionsEnd();
144   /* Creating random ctx */
145   PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &(ctx->rctx)));
146   PetscCall(PetscRandomSetFromOptions(ctx->rctx));
147   PetscCall(CreateMatrix(ctx));
148   PetscCall(CreateRHS(ctx));
149   PetscCall(SetupWorkspace(ctx));
150   PetscFunctionReturn(0);
151 }
152 
153 static PetscErrorCode DestroyContext(UserCtx *ctx) {
154   PetscInt i;
155 
156   PetscFunctionBegin;
157   PetscCall(MatDestroy(&((*ctx)->F)));
158   PetscCall(MatDestroy(&((*ctx)->W)));
159   PetscCall(MatDestroy(&((*ctx)->Hm)));
160   PetscCall(MatDestroy(&((*ctx)->Hr)));
161   PetscCall(VecDestroy(&((*ctx)->d)));
162   for (i = 0; i < NWORKLEFT; i++) { PetscCall(VecDestroy(&((*ctx)->workLeft[i]))); }
163   for (i = 0; i < NWORKRIGHT; i++) { PetscCall(VecDestroy(&((*ctx)->workRight[i]))); }
164   PetscCall(PetscRandomDestroy(&((*ctx)->rctx)));
165   PetscCall(PetscFree(*ctx));
166   PetscFunctionReturn(0);
167 }
168 
169 /* compute (1/2) * ||F x - d||^2 */
170 static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx) {
171   UserCtx ctx = (UserCtx)_ctx;
172   Vec     y;
173 
174   PetscFunctionBegin;
175   y = ctx->workLeft[0];
176   PetscCall(MatMult(ctx->F, x, y));
177   PetscCall(VecAXPY(y, -1., ctx->d));
178   PetscCall(VecDot(y, y, J));
179   *J *= 0.5;
180   PetscFunctionReturn(0);
181 }
182 
183 /* compute V = FTFx - FTd */
184 static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx) {
185   UserCtx ctx = (UserCtx)_ctx;
186   Vec     FTFx, FTd;
187 
188   PetscFunctionBegin;
189   /* work1 is A^T Ax, work2 is Ab, W is A^T A*/
190   FTFx = ctx->workRight[0];
191   FTd  = ctx->workRight[1];
192   PetscCall(MatMult(ctx->W, x, FTFx));
193   PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd));
194   PetscCall(VecWAXPY(V, -1., FTd, FTFx));
195   PetscFunctionReturn(0);
196 }
197 
198 /* returns FTF */
199 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) {
200   UserCtx ctx = (UserCtx)_ctx;
201 
202   PetscFunctionBegin;
203   if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
204   if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN));
205   PetscFunctionReturn(0);
206 }
207 
208 /* computes augment Lagrangian objective (with scaled dual):
209  * 0.5 * ||F x - d||^2  + 0.5 * mu ||x - z + u||^2 */
210 static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx) {
211   UserCtx   ctx = (UserCtx)_ctx;
212   PetscReal mu, workNorm, misfit;
213   Vec       z, u, temp;
214 
215   PetscFunctionBegin;
216   mu   = ctx->mu;
217   z    = ctx->workRight[5];
218   u    = ctx->workRight[6];
219   temp = ctx->workRight[10];
220   /* misfit = f(x) */
221   PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx));
222   PetscCall(VecCopy(x, temp));
223   /* temp = x - z + u */
224   PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
225   /* workNorm = ||x - z + u||^2 */
226   PetscCall(VecDot(temp, temp, &workNorm));
227   /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */
228   *J = misfit + 0.5 * mu * workNorm;
229   PetscFunctionReturn(0);
230 }
231 
232 /* computes FTFx - FTd  mu*(x - z + u) */
233 static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx) {
234   UserCtx   ctx = (UserCtx)_ctx;
235   PetscReal mu;
236   Vec       z, u, temp;
237 
238   PetscFunctionBegin;
239   mu   = ctx->mu;
240   z    = ctx->workRight[5];
241   u    = ctx->workRight[6];
242   temp = ctx->workRight[10];
243   PetscCall(GradientMisfit(tao, x, V, _ctx));
244   PetscCall(VecCopy(x, temp));
245   /* temp = x - z + u */
246   PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
247   /* V =  FTFx - FTd  mu*(x - z + u) */
248   PetscCall(VecAXPY(V, mu, temp));
249   PetscFunctionReturn(0);
250 }
251 
252 /* returns FTF + diag(mu) */
253 static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) {
254   UserCtx ctx = (UserCtx)_ctx;
255 
256   PetscFunctionBegin;
257   PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
258   PetscCall(MatShift(H, ctx->mu));
259   if (Hpre != H) { PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN)); }
260   PetscFunctionReturn(0);
261 }
262 
263 /* computes || x ||_p (mult by 0.5 in case of NORM_2) */
264 static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx) {
265   UserCtx   ctx = (UserCtx)_ctx;
266   PetscReal norm;
267 
268   PetscFunctionBegin;
269   *J = 0;
270   PetscCall(VecNorm(x, ctx->p, &norm));
271   if (ctx->p == NORM_2) norm = 0.5 * norm * norm;
272   *J = ctx->alpha * norm;
273   PetscFunctionReturn(0);
274 }
275 
276 /* NORM_2 Case: return x
277  * NORM_1 Case: x/(|x| + eps)
278  * Else: TODO */
279 static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx) {
280   UserCtx   ctx = (UserCtx)_ctx;
281   PetscReal eps = ctx->eps;
282 
283   PetscFunctionBegin;
284   if (ctx->p == NORM_2) {
285     PetscCall(VecCopy(x, V));
286   } else if (ctx->p == NORM_1) {
287     PetscCall(VecCopy(x, ctx->workRight[1]));
288     PetscCall(VecAbs(ctx->workRight[1]));
289     PetscCall(VecShift(ctx->workRight[1], eps));
290     PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1]));
291   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
292   PetscFunctionReturn(0);
293 }
294 
295 /* NORM_2 Case: returns diag(mu)
296  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)))  */
297 static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) {
298   UserCtx   ctx = (UserCtx)_ctx;
299   PetscReal eps = ctx->eps;
300   Vec       copy1, copy2, copy3;
301 
302   PetscFunctionBegin;
303   if (ctx->p == NORM_2) {
304     /* Identity matrix scaled by mu */
305     PetscCall(MatZeroEntries(H));
306     PetscCall(MatShift(H, ctx->mu));
307     if (Hpre != H) {
308       PetscCall(MatZeroEntries(Hpre));
309       PetscCall(MatShift(Hpre, ctx->mu));
310     }
311   } else if (ctx->p == NORM_1) {
312     /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */
313     copy1 = ctx->workRight[1];
314     copy2 = ctx->workRight[2];
315     copy3 = ctx->workRight[3];
316     /* copy1 : 1/sqrt(x_i^2 + eps) */
317     PetscCall(VecCopy(x, copy1));
318     PetscCall(VecPow(copy1, 2));
319     PetscCall(VecShift(copy1, eps));
320     PetscCall(VecSqrtAbs(copy1));
321     PetscCall(VecReciprocal(copy1));
322     /* copy2:  x_i^2.*/
323     PetscCall(VecCopy(x, copy2));
324     PetscCall(VecPow(copy2, 2));
325     /* copy3: abs(x_i^2 + eps) */
326     PetscCall(VecCopy(x, copy3));
327     PetscCall(VecPow(copy3, 2));
328     PetscCall(VecShift(copy3, eps));
329     PetscCall(VecAbs(copy3));
330     /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */
331     PetscCall(VecPointwiseDivide(copy2, copy2, copy3));
332     PetscCall(VecScale(copy2, -1.));
333     PetscCall(VecShift(copy2, 1.));
334     PetscCall(VecAXPY(copy1, 1., copy2));
335     PetscCall(VecScale(copy1, ctx->mu));
336     PetscCall(MatZeroEntries(H));
337     PetscCall(MatDiagonalSet(H, copy1, INSERT_VALUES));
338     if (Hpre != H) {
339       PetscCall(MatZeroEntries(Hpre));
340       PetscCall(MatDiagonalSet(Hpre, copy1, INSERT_VALUES));
341     }
342   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
343   PetscFunctionReturn(0);
344 }
345 
346 /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2
347  * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */
348 static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx) {
349   UserCtx   ctx = (UserCtx)_ctx;
350   PetscReal mu, workNorm, reg;
351   Vec       x, u, temp;
352 
353   PetscFunctionBegin;
354   mu   = ctx->mu;
355   x    = ctx->workRight[4];
356   u    = ctx->workRight[6];
357   temp = ctx->workRight[10];
358   PetscCall(ObjectiveRegularization(tao, z, &reg, _ctx));
359   PetscCall(VecCopy(z, temp));
360   /* temp = x + u -z */
361   PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
362   /* workNorm = ||x + u - z ||^2 */
363   PetscCall(VecDot(temp, temp, &workNorm));
364   *J = reg + 0.5 * mu * workNorm;
365   PetscFunctionReturn(0);
366 }
367 
368 /* NORM_2 Case: x - mu*(x + u - z)
369  * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z)
370  * Else: TODO */
371 static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx) {
372   UserCtx   ctx = (UserCtx)_ctx;
373   PetscReal mu;
374   Vec       x, u, temp;
375 
376   PetscFunctionBegin;
377   mu   = ctx->mu;
378   x    = ctx->workRight[4];
379   u    = ctx->workRight[6];
380   temp = ctx->workRight[10];
381   PetscCall(GradientRegularization(tao, z, V, _ctx));
382   PetscCall(VecCopy(z, temp));
383   /* temp = x + u -z */
384   PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
385   PetscCall(VecAXPY(V, -mu, temp));
386   PetscFunctionReturn(0);
387 }
388 
389 /* NORM_2 Case: returns diag(mu)
390  * NORM_1 Case: FTF + diag(mu) */
391 static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) {
392   UserCtx ctx = (UserCtx)_ctx;
393 
394   PetscFunctionBegin;
395   if (ctx->p == NORM_2) {
396     /* Identity matrix scaled by mu */
397     PetscCall(MatZeroEntries(H));
398     PetscCall(MatShift(H, ctx->mu));
399     if (Hpre != H) {
400       PetscCall(MatZeroEntries(Hpre));
401       PetscCall(MatShift(Hpre, ctx->mu));
402     }
403   } else if (ctx->p == NORM_1) {
404     PetscCall(HessianMisfit(tao, x, H, Hpre, (void *)ctx));
405     PetscCall(MatShift(H, ctx->mu));
406     if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu));
407   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
408   PetscFunctionReturn(0);
409 }
410 
411 /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p
412 *  NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */
413 static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, void *ctx) {
414   PetscReal Jm, Jr;
415 
416   PetscFunctionBegin;
417   PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx));
418   PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx));
419   *J = Jm + Jr;
420   PetscFunctionReturn(0);
421 }
422 
423 /* NORM_2 Case: FTFx - FTd + x
424  * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */
425 static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, void *ctx) {
426   UserCtx cntx = (UserCtx)ctx;
427 
428   PetscFunctionBegin;
429   PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx));
430   PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx));
431   PetscCall(VecWAXPY(V, 1, cntx->workRight[2], cntx->workRight[3]));
432   PetscFunctionReturn(0);
433 }
434 
435 /* NORM_2 Case: diag(mu) + FTF
436  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF  */
437 static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx) {
438   Mat tempH;
439 
440   PetscFunctionBegin;
441   PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH));
442   PetscCall(HessianMisfit(tao, x, H, H, ctx));
443   PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx));
444   PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN));
445   if (Hpre != H) { PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN)); }
446   PetscCall(MatDestroy(&tempH));
447   PetscFunctionReturn(0);
448 }
449 
450 static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x) {
451   PetscInt  i;
452   PetscReal u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm;
453   Tao       tao1, tao2;
454   Vec       xk, z, u, diff, zold, zdiff, temp;
455   PetscReal mu;
456 
457   PetscFunctionBegin;
458   xk    = ctx->workRight[4];
459   z     = ctx->workRight[5];
460   u     = ctx->workRight[6];
461   diff  = ctx->workRight[7];
462   zold  = ctx->workRight[8];
463   zdiff = ctx->workRight[9];
464   temp  = ctx->workRight[11];
465   mu    = ctx->mu;
466   PetscCall(VecSet(u, 0.));
467   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1));
468   PetscCall(TaoSetType(tao1, TAONLS));
469   PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void *)ctx));
470   PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void *)ctx));
471   PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void *)ctx));
472   PetscCall(VecSet(xk, 0.));
473   PetscCall(TaoSetSolution(tao1, xk));
474   PetscCall(TaoSetOptionsPrefix(tao1, "misfit_"));
475   PetscCall(TaoSetFromOptions(tao1));
476   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2));
477   if (ctx->p == NORM_2) {
478     PetscCall(TaoSetType(tao2, TAONLS));
479     PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void *)ctx));
480     PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void *)ctx));
481     PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void *)ctx));
482   }
483   PetscCall(VecSet(z, 0.));
484   PetscCall(TaoSetSolution(tao2, z));
485   PetscCall(TaoSetOptionsPrefix(tao2, "reg_"));
486   PetscCall(TaoSetFromOptions(tao2));
487 
488   for (i = 0; i < ctx->iter; i++) {
489     PetscCall(VecCopy(z, zold));
490     PetscCall(TaoSolve(tao1)); /* Updates xk */
491     if (ctx->p == NORM_1) {
492       PetscCall(VecWAXPY(temp, 1., xk, u));
493       PetscCall(TaoSoftThreshold(temp, -ctx->alpha / mu, ctx->alpha / mu, z));
494     } else {
495       PetscCall(TaoSolve(tao2)); /* Update zk */
496     }
497     /* u = u + xk -z */
498     PetscCall(VecAXPBYPCZ(u, 1., -1., 1., xk, z));
499     /* r_norm : norm(x-z) */
500     PetscCall(VecWAXPY(diff, -1., z, xk));
501     PetscCall(VecNorm(diff, NORM_2, &r_norm));
502     /* s_norm : norm(-mu(z-zold)) */
503     PetscCall(VecWAXPY(zdiff, -1., zold, z));
504     PetscCall(VecNorm(zdiff, NORM_2, &s_norm));
505     s_norm = s_norm * mu;
506     /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/
507     PetscCall(VecNorm(xk, NORM_2, &x_norm));
508     PetscCall(VecNorm(z, NORM_2, &z_norm));
509     primal = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * PetscMax(x_norm, z_norm);
510     /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/
511     PetscCall(VecNorm(u, NORM_2, &u_norm));
512     dual = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * u_norm * mu;
513     PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1), "Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double)r_norm, (double)s_norm));
514     if (r_norm < primal && s_norm < dual) break;
515   }
516   PetscCall(VecCopy(xk, x));
517   PetscCall(TaoDestroy(&tao1));
518   PetscCall(TaoDestroy(&tao2));
519   PetscFunctionReturn(0);
520 }
521 
522 /* Second order Taylor remainder convergence test */
523 static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C) {
524   PetscReal  h, J, temp;
525   PetscInt   i, j;
526   PetscInt   numValues;
527   PetscReal  Jx, Jxhat_comp, Jxhat_pred;
528   PetscReal *Js, *hs;
529   PetscReal  gdotdx;
530   PetscReal  minrate = PETSC_MAX_REAL;
531   MPI_Comm   comm    = PetscObjectComm((PetscObject)x);
532   Vec        g, dx, xhat;
533 
534   PetscFunctionBegin;
535   PetscCall(VecDuplicate(x, &g));
536   PetscCall(VecDuplicate(x, &xhat));
537   /* choose a perturbation direction */
538   PetscCall(VecDuplicate(x, &dx));
539   PetscCall(VecSetRandom(dx, ctx->rctx));
540   /* evaluate objective at x: J(x) */
541   PetscCall(TaoComputeObjective(tao, x, &Jx));
542   /* evaluate gradient at x, save in vector g */
543   PetscCall(TaoComputeGradient(tao, x, g));
544   PetscCall(VecDot(g, dx, &gdotdx));
545 
546   for (numValues = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor) numValues++;
547   PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs));
548   for (i = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor, i++) {
549     PetscCall(VecWAXPY(xhat, h, dx, x));
550     PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp));
551     /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */
552     Jxhat_pred = Jx + h * gdotdx;
553     /* Vector to dJdm scalar? Dot?*/
554     J          = PetscAbsReal(Jxhat_comp - Jxhat_pred);
555     PetscCall(PetscPrintf(comm, "J(xhat): %g, predicted: %g, diff %g\n", (double)Jxhat_comp, (double)Jxhat_pred, (double)J));
556     Js[i] = J;
557     hs[i] = h;
558   }
559   for (j = 1; j < numValues; j++) {
560     temp = PetscLogReal(Js[j] / Js[j - 1]) / PetscLogReal(hs[j] / hs[j - 1]);
561     PetscCall(PetscPrintf(comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j - 1, (double)temp));
562     minrate = PetscMin(minrate, temp);
563   }
564   /* If O is not ~2, then the test is wrong */
565   PetscCall(PetscFree2(Js, hs));
566   *C = minrate;
567   PetscCall(VecDestroy(&dx));
568   PetscCall(VecDestroy(&xhat));
569   PetscCall(VecDestroy(&g));
570   PetscFunctionReturn(0);
571 }
572 
573 int main(int argc, char **argv) {
574   UserCtx ctx;
575   Tao     tao;
576   Vec     x;
577   Mat     H;
578 
579   PetscFunctionBeginUser;
580   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
581   PetscCall(PetscNew(&ctx));
582   PetscCall(ConfigureContext(ctx));
583   /* Define two functions that could pass as objectives to TaoSetObjective(): one
584    * for the misfit component, and one for the regularization component */
585   /* ObjectiveMisfit() and ObjectiveRegularization() */
586 
587   /* Define a single function that calls both components adds them together: the complete objective,
588    * in the absence of a Tao implementation that handles separability */
589   /* ObjectiveComplete() */
590   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
591   PetscCall(TaoSetType(tao, TAONM));
592   PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void *)ctx));
593   PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void *)ctx));
594   PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H));
595   PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void *)ctx));
596   PetscCall(MatCreateVecs(ctx->F, NULL, &x));
597   PetscCall(VecSet(x, 0.));
598   PetscCall(TaoSetSolution(tao, x));
599   PetscCall(TaoSetFromOptions(tao));
600   if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx, x));
601   else PetscCall(TaoSolve(tao));
602   /* examine solution */
603   PetscCall(VecViewFromOptions(x, NULL, "-view_sol"));
604   if (ctx->taylor) {
605     PetscReal rate;
606     PetscCall(TaylorTest(ctx, tao, x, &rate));
607   }
608   PetscCall(MatDestroy(&H));
609   PetscCall(TaoDestroy(&tao));
610   PetscCall(VecDestroy(&x));
611   PetscCall(DestroyContext(&ctx));
612   PetscCall(PetscFinalize());
613   return 0;
614 }
615 
616 /*TEST
617 
618   build:
619     requires: !complex
620 
621   test:
622     suffix: 0
623     args:
624 
625   test:
626     suffix: l1_1
627     args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1
628 
629   test:
630     suffix: hessian_1
631     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls
632 
633   test:
634     suffix: hessian_2
635     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls
636 
637   test:
638     suffix: nm_1
639     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50
640 
641   test:
642     suffix: nm_2
643     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50
644 
645   test:
646     suffix: lmvm_1
647     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40
648 
649   test:
650     suffix: lmvm_2
651     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15
652 
653   test:
654     suffix: soft_threshold_admm_1
655     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm
656 
657   test:
658     suffix: hessian_admm_1
659     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls
660 
661   test:
662     suffix: hessian_admm_2
663     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls
664 
665   test:
666     suffix: nm_admm_1
667     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm
668 
669   test:
670     suffix: nm_admm_2
671     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7
672 
673   test:
674     suffix: lmvm_admm_1
675     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
676 
677   test:
678     suffix: lmvm_admm_2
679     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
680 
681 TEST*/
682