xref: /petsc/src/tao/tutorials/ex4.c (revision 3f02e49b19195914bf17f317a25cb39636853415)
/* Help text for this example; main() hands it to PetscInitialize(). */
static char help[] = "Simple example to test separable objective optimizers.\n";
2 
3 #include <petsc.h>
4 #include <petsctao.h>
5 #include <petscvec.h>
6 #include <petscmath.h>
7 
/* Length of the workLeft[] scratch array in the user context; only workLeft[0] is
 * referenced in this file (it holds F x - d inside ObjectiveMisfit()). */
#define NWORKLEFT  4
/* Length of the workRight[] scratch array in the user context; slots 0 through 11 are
 * all referenced by the callbacks and by TaoSolveADMM(). */
#define NWORKRIGHT 12
10 
/* Application context shared by every objective/gradient/Hessian callback in this example.
 * It is filled in by ConfigureContext() and released by DestroyContext(). */
typedef struct _UserCtx {
  PetscInt    m;       /* The row dimension of F */
  PetscInt    n;       /* The column dimension of F */
  PetscInt    matops;  /* Matrix format. 0 for stencil, 1 for random */
  PetscInt    iter;    /* Maximum number of ADMM iterations */
  PetscReal   hStart;  /* Starting point for Taylor test */
  PetscReal   hFactor; /* Taylor test step factor */
  PetscReal   hMin;    /* Taylor test end goal */
  PetscReal   alpha;   /* regularization constant applied to || x ||_p */
  PetscReal   eps;     /* small constant for approximating gradient of || x ||_1 */
  PetscReal   mu;      /* the augmented Lagrangian term in ADMM */
  PetscReal   abstol;  /* Absolute stopping tolerance for ADMM */
  PetscReal   reltol;  /* Relative stopping tolerance for ADMM */
  Mat         F;                     /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */
  Mat         W;                     /* Workspace matrix holding F^T F */
  Mat         Hm;                    /* Hessian of the misfit subproblem (used by TaoSolveADMM) */
  Mat         Hr;                    /* Hessian of the regularization subproblem (used by TaoSolveADMM) */
  Vec         d;                     /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */
  Vec         workLeft[NWORKLEFT];   /* Scratch vectors created in SetupWorkspace() */
  Vec         workRight[NWORKRIGHT]; /* Scratch vectors created in SetupWorkspace() */
  NormType    p;        /* Norm of the regularization term; the callbacks handle NORM_1 and NORM_2 only */
  PetscRandom rctx;     /* Random context used for d, the random variant of F, and the Taylor test direction */
  PetscBool   soft;     /* Flag to call TaoSoftThreshold() on the solution at the end of main() */
  PetscBool   taylor;   /* Flag to determine whether to run Taylor test or not */
  PetscBool   use_admm; /* Flag to solve with the hand-written ADMM loop instead of a single TaoSolve() */
} *UserCtx;
37 
38 static PetscErrorCode CreateRHS(UserCtx ctx)
39 {
40   PetscFunctionBegin;
41   /* build the rhs d in ctx */
42   PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->d));
43   PetscCall(VecSetSizes(ctx->d, PETSC_DECIDE, ctx->m));
44   PetscCall(VecSetFromOptions(ctx->d));
45   PetscCall(VecSetRandom(ctx->d, ctx->rctx));
46   PetscFunctionReturn(PETSC_SUCCESS);
47 }
48 
49 static PetscErrorCode CreateMatrix(UserCtx ctx)
50 {
51   PetscInt      Istart, Iend, i, j, Ii, gridN, I_n, I_s, I_e, I_w;
52   PetscLogStage stage;
53 
54   PetscFunctionBegin;
55   /* build the matrix F in ctx */
56   PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->F));
57   PetscCall(MatSetSizes(ctx->F, PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n));
58   PetscCall(MatSetType(ctx->F, MATAIJ));                          /* TODO: Decide specific SetType other than dummy*/
59   PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/
60   PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL));
61   PetscCall(MatSetUp(ctx->F));
62   PetscCall(MatGetOwnershipRange(ctx->F, &Istart, &Iend));
63   PetscCall(PetscLogStageRegister("Assembly", &stage));
64   PetscCall(PetscLogStagePush(stage));
65 
66   /* Set matrix elements in  2-D five point stencil format. */
67   if (!ctx->matops) {
68     PetscCheck(ctx->m == ctx->n, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square");
69     gridN = (PetscInt)PetscSqrtReal((PetscReal)ctx->m);
70     PetscCheck(gridN * gridN == ctx->m, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square");
71     for (Ii = Istart; Ii < Iend; Ii++) {
72       i   = Ii / gridN;
73       j   = Ii % gridN;
74       I_n = i * gridN + j + 1;
75       if (j + 1 >= gridN) I_n = -1;
76       I_s = i * gridN + j - 1;
77       if (j - 1 < 0) I_s = -1;
78       I_e = (i + 1) * gridN + j;
79       if (i + 1 >= gridN) I_e = -1;
80       I_w = (i - 1) * gridN + j;
81       if (i - 1 < 0) I_w = -1;
82       PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES));
83       PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES));
84       PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES));
85       PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES));
86       PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES));
87     }
88   } else PetscCall(MatSetRandom(ctx->F, ctx->rctx));
89   PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY));
90   PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY));
91   PetscCall(PetscLogStagePop());
92   /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */
93   if (!ctx->matops) PetscCall(MatSetOption(ctx->F, MAT_SYMMETRIC, PETSC_TRUE));
94   PetscCall(MatTransposeMatMult(ctx->F, ctx->F, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &ctx->W));
95   /* Setup Hessian Workspace in same shape as W */
96   PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hm));
97   PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hr));
98   PetscFunctionReturn(PETSC_SUCCESS);
99 }
100 
101 static PetscErrorCode SetupWorkspace(UserCtx ctx)
102 {
103   PetscInt i;
104 
105   PetscFunctionBegin;
106   PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0]));
107   for (i = 1; i < NWORKLEFT; i++) PetscCall(VecDuplicate(ctx->workLeft[0], &ctx->workLeft[i]));
108   for (i = 1; i < NWORKRIGHT; i++) PetscCall(VecDuplicate(ctx->workRight[0], &ctx->workRight[i]));
109   PetscFunctionReturn(PETSC_SUCCESS);
110 }
111 
112 static PetscErrorCode ConfigureContext(UserCtx ctx)
113 {
114   PetscFunctionBegin;
115   ctx->m        = 16;
116   ctx->n        = 16;
117   ctx->eps      = 1.e-3;
118   ctx->abstol   = 1.e-4;
119   ctx->reltol   = 1.e-2;
120   ctx->hStart   = 1.;
121   ctx->hMin     = 1.e-3;
122   ctx->hFactor  = 0.5;
123   ctx->alpha    = 1.;
124   ctx->mu       = 1.0;
125   ctx->matops   = 0;
126   ctx->iter     = 10;
127   ctx->p        = NORM_2;
128   ctx->soft     = PETSC_FALSE;
129   ctx->taylor   = PETSC_TRUE;
130   ctx->use_admm = PETSC_FALSE;
131   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "ex4.c");
132   PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &ctx->m, NULL));
133   PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &ctx->n, NULL));
134   PetscCall(PetscOptionsInt("-matrix_format", "Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &ctx->matops, NULL));
135   PetscCall(PetscOptionsInt("-iter", "Iteration number ADMM", "ex4.c", ctx->iter, &ctx->iter, NULL));
136   PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &ctx->alpha, NULL));
137   PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &ctx->eps, NULL));
138   PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &ctx->mu, NULL));
139   PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &ctx->hStart, NULL));
140   PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &ctx->hFactor, NULL));
141   PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &ctx->hMin, NULL));
142   PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &ctx->abstol, NULL));
143   PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &ctx->reltol, NULL));
144   PetscCall(PetscOptionsBool("-taylor", "Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &ctx->taylor, NULL));
145   PetscCall(PetscOptionsBool("-soft", "Flag for testing soft threshold no-op case. Default is false.", "ex4.c", ctx->soft, &ctx->soft, NULL));
146   PetscCall(PetscOptionsBool("-use_admm", "Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &ctx->use_admm, NULL));
147   PetscCall(PetscOptionsEnum("-p", "Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *)&ctx->p, NULL));
148   PetscOptionsEnd();
149   /* Creating random ctx */
150   PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &ctx->rctx));
151   PetscCall(PetscRandomSetFromOptions(ctx->rctx));
152   PetscCall(CreateMatrix(ctx));
153   PetscCall(CreateRHS(ctx));
154   PetscCall(SetupWorkspace(ctx));
155   PetscFunctionReturn(PETSC_SUCCESS);
156 }
157 
158 static PetscErrorCode DestroyContext(UserCtx *ctx)
159 {
160   PetscInt i;
161 
162   PetscFunctionBegin;
163   PetscCall(MatDestroy(&(*ctx)->F));
164   PetscCall(MatDestroy(&(*ctx)->W));
165   PetscCall(MatDestroy(&(*ctx)->Hm));
166   PetscCall(MatDestroy(&(*ctx)->Hr));
167   PetscCall(VecDestroy(&(*ctx)->d));
168   for (i = 0; i < NWORKLEFT; i++) PetscCall(VecDestroy(&(*ctx)->workLeft[i]));
169   for (i = 0; i < NWORKRIGHT; i++) PetscCall(VecDestroy(&(*ctx)->workRight[i]));
170   PetscCall(PetscRandomDestroy(&(*ctx)->rctx));
171   PetscCall(PetscFree(*ctx));
172   PetscFunctionReturn(PETSC_SUCCESS);
173 }
174 
175 /* compute (1/2) * ||F x - d||^2 */
176 static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx)
177 {
178   UserCtx ctx = (UserCtx)_ctx;
179   Vec     y;
180 
181   PetscFunctionBegin;
182   y = ctx->workLeft[0];
183   PetscCall(MatMult(ctx->F, x, y));
184   PetscCall(VecAXPY(y, -1., ctx->d));
185   PetscCall(VecDot(y, y, J));
186   *J *= 0.5;
187   PetscFunctionReturn(PETSC_SUCCESS);
188 }
189 
190 /* compute V = FTFx - FTd */
191 static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx)
192 {
193   UserCtx ctx = (UserCtx)_ctx;
194   Vec     FTFx, FTd;
195 
196   PetscFunctionBegin;
197   /* work1 is A^T Ax, work2 is Ab, W is A^T A*/
198   FTFx = ctx->workRight[0];
199   FTd  = ctx->workRight[1];
200   PetscCall(MatMult(ctx->W, x, FTFx));
201   PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd));
202   PetscCall(VecWAXPY(V, -1., FTd, FTFx));
203   PetscFunctionReturn(PETSC_SUCCESS);
204 }
205 
206 /* returns FTF */
207 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
208 {
209   UserCtx ctx = (UserCtx)_ctx;
210 
211   PetscFunctionBegin;
212   if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
213   if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN));
214   PetscFunctionReturn(PETSC_SUCCESS);
215 }
216 
217 /* computes augment Lagrangian objective (with scaled dual):
218  * 0.5 * ||F x - d||^2  + 0.5 * mu ||x - z + u||^2 */
219 static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx)
220 {
221   UserCtx   ctx = (UserCtx)_ctx;
222   PetscReal mu, workNorm, misfit;
223   Vec       z, u, temp;
224 
225   PetscFunctionBegin;
226   mu   = ctx->mu;
227   z    = ctx->workRight[5];
228   u    = ctx->workRight[6];
229   temp = ctx->workRight[10];
230   /* misfit = f(x) */
231   PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx));
232   PetscCall(VecCopy(x, temp));
233   /* temp = x - z + u */
234   PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
235   /* workNorm = ||x - z + u||^2 */
236   PetscCall(VecDot(temp, temp, &workNorm));
237   /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */
238   *J = misfit + 0.5 * mu * workNorm;
239   PetscFunctionReturn(PETSC_SUCCESS);
240 }
241 
242 /* computes FTFx - FTd  mu*(x - z + u) */
243 static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx)
244 {
245   UserCtx   ctx = (UserCtx)_ctx;
246   PetscReal mu;
247   Vec       z, u, temp;
248 
249   PetscFunctionBegin;
250   mu   = ctx->mu;
251   z    = ctx->workRight[5];
252   u    = ctx->workRight[6];
253   temp = ctx->workRight[10];
254   PetscCall(GradientMisfit(tao, x, V, _ctx));
255   PetscCall(VecCopy(x, temp));
256   /* temp = x - z + u */
257   PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
258   /* V =  FTFx - FTd  mu*(x - z + u) */
259   PetscCall(VecAXPY(V, mu, temp));
260   PetscFunctionReturn(PETSC_SUCCESS);
261 }
262 
263 /* returns FTF + diag(mu) */
264 static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
265 {
266   UserCtx ctx = (UserCtx)_ctx;
267 
268   PetscFunctionBegin;
269   PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
270   PetscCall(MatShift(H, ctx->mu));
271   if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
272   PetscFunctionReturn(PETSC_SUCCESS);
273 }
274 
275 /* computes || x ||_p (mult by 0.5 in case of NORM_2) */
276 static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx)
277 {
278   UserCtx   ctx = (UserCtx)_ctx;
279   PetscReal norm;
280 
281   PetscFunctionBegin;
282   *J = 0;
283   PetscCall(VecNorm(x, ctx->p, &norm));
284   if (ctx->p == NORM_2) norm = 0.5 * norm * norm;
285   *J = ctx->alpha * norm;
286   PetscFunctionReturn(PETSC_SUCCESS);
287 }
288 
289 /* NORM_2 Case: return x
290  * NORM_1 Case: x/(|x| + eps)
291  * Else: TODO */
292 static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx)
293 {
294   UserCtx   ctx = (UserCtx)_ctx;
295   PetscReal eps = ctx->eps;
296 
297   PetscFunctionBegin;
298   if (ctx->p == NORM_2) {
299     PetscCall(VecCopy(x, V));
300   } else if (ctx->p == NORM_1) {
301     PetscCall(VecCopy(x, ctx->workRight[1]));
302     PetscCall(VecAbs(ctx->workRight[1]));
303     PetscCall(VecShift(ctx->workRight[1], eps));
304     PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1]));
305   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
306   PetscFunctionReturn(PETSC_SUCCESS);
307 }
308 
309 /* NORM_2 Case: returns diag(mu)
310  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)))  */
311 static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
312 {
313   UserCtx   ctx = (UserCtx)_ctx;
314   PetscReal eps = ctx->eps;
315   Vec       copy1, copy2, copy3;
316 
317   PetscFunctionBegin;
318   if (ctx->p == NORM_2) {
319     /* Identity matrix scaled by mu */
320     PetscCall(MatZeroEntries(H));
321     PetscCall(MatShift(H, ctx->mu));
322     if (Hpre != H) {
323       PetscCall(MatZeroEntries(Hpre));
324       PetscCall(MatShift(Hpre, ctx->mu));
325     }
326   } else if (ctx->p == NORM_1) {
327     /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */
328     copy1 = ctx->workRight[1];
329     copy2 = ctx->workRight[2];
330     copy3 = ctx->workRight[3];
331     /* copy1 : 1/sqrt(x_i^2 + eps) */
332     PetscCall(VecCopy(x, copy1));
333     PetscCall(VecPow(copy1, 2));
334     PetscCall(VecShift(copy1, eps));
335     PetscCall(VecSqrtAbs(copy1));
336     PetscCall(VecReciprocal(copy1));
337     /* copy2:  x_i^2.*/
338     PetscCall(VecCopy(x, copy2));
339     PetscCall(VecPow(copy2, 2));
340     /* copy3: abs(x_i^2 + eps) */
341     PetscCall(VecCopy(x, copy3));
342     PetscCall(VecPow(copy3, 2));
343     PetscCall(VecShift(copy3, eps));
344     PetscCall(VecAbs(copy3));
345     /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */
346     PetscCall(VecPointwiseDivide(copy2, copy2, copy3));
347     PetscCall(VecScale(copy2, -1.));
348     PetscCall(VecShift(copy2, 1.));
349     PetscCall(VecAXPY(copy1, 1., copy2));
350     PetscCall(VecScale(copy1, ctx->mu));
351     PetscCall(MatZeroEntries(H));
352     PetscCall(MatDiagonalSet(H, copy1, INSERT_VALUES));
353     if (Hpre != H) {
354       PetscCall(MatZeroEntries(Hpre));
355       PetscCall(MatDiagonalSet(Hpre, copy1, INSERT_VALUES));
356     }
357   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
358   PetscFunctionReturn(PETSC_SUCCESS);
359 }
360 
361 /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2
362  * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */
363 static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx)
364 {
365   UserCtx   ctx = (UserCtx)_ctx;
366   PetscReal mu, workNorm, reg;
367   Vec       x, u, temp;
368 
369   PetscFunctionBegin;
370   mu   = ctx->mu;
371   x    = ctx->workRight[4];
372   u    = ctx->workRight[6];
373   temp = ctx->workRight[10];
374   PetscCall(ObjectiveRegularization(tao, z, &reg, _ctx));
375   PetscCall(VecCopy(z, temp));
376   /* temp = x + u -z */
377   PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
378   /* workNorm = ||x + u - z ||^2 */
379   PetscCall(VecDot(temp, temp, &workNorm));
380   *J = reg + 0.5 * mu * workNorm;
381   PetscFunctionReturn(PETSC_SUCCESS);
382 }
383 
384 /* NORM_2 Case: x - mu*(x + u - z)
385  * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z)
386  * Else: TODO */
387 static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx)
388 {
389   UserCtx   ctx = (UserCtx)_ctx;
390   PetscReal mu;
391   Vec       x, u, temp;
392 
393   PetscFunctionBegin;
394   mu   = ctx->mu;
395   x    = ctx->workRight[4];
396   u    = ctx->workRight[6];
397   temp = ctx->workRight[10];
398   PetscCall(GradientRegularization(tao, z, V, _ctx));
399   PetscCall(VecCopy(z, temp));
400   /* temp = x + u -z */
401   PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
402   PetscCall(VecAXPY(V, -mu, temp));
403   PetscFunctionReturn(PETSC_SUCCESS);
404 }
405 
406 /* NORM_2 Case: returns diag(mu)
407  * NORM_1 Case: FTF + diag(mu) */
408 static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
409 {
410   UserCtx ctx = (UserCtx)_ctx;
411 
412   PetscFunctionBegin;
413   if (ctx->p == NORM_2) {
414     /* Identity matrix scaled by mu */
415     PetscCall(MatZeroEntries(H));
416     PetscCall(MatShift(H, ctx->mu));
417     if (Hpre != H) {
418       PetscCall(MatZeroEntries(Hpre));
419       PetscCall(MatShift(Hpre, ctx->mu));
420     }
421   } else if (ctx->p == NORM_1) {
422     PetscCall(HessianMisfit(tao, x, H, Hpre, (void *)ctx));
423     PetscCall(MatShift(H, ctx->mu));
424     if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu));
425   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
426   PetscFunctionReturn(PETSC_SUCCESS);
427 }
428 
429 /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p
430 *  NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */
431 static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, void *ctx)
432 {
433   PetscReal Jm, Jr;
434 
435   PetscFunctionBegin;
436   PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx));
437   PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx));
438   *J = Jm + Jr;
439   PetscFunctionReturn(PETSC_SUCCESS);
440 }
441 
442 /* NORM_2 Case: FTFx - FTd + x
443  * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */
444 static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, void *ctx)
445 {
446   UserCtx cntx = (UserCtx)ctx;
447 
448   PetscFunctionBegin;
449   PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx));
450   PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx));
451   PetscCall(VecWAXPY(V, 1, cntx->workRight[2], cntx->workRight[3]));
452   PetscFunctionReturn(PETSC_SUCCESS);
453 }
454 
455 /* NORM_2 Case: diag(mu) + FTF
456  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF  */
457 static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx)
458 {
459   Mat tempH;
460 
461   PetscFunctionBegin;
462   PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH));
463   PetscCall(HessianMisfit(tao, x, H, H, ctx));
464   PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx));
465   PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN));
466   if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
467   PetscCall(MatDestroy(&tempH));
468   PetscFunctionReturn(PETSC_SUCCESS);
469 }
470 
471 static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x)
472 {
473   PetscInt  i;
474   PetscReal u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm;
475   Tao       tao1, tao2;
476   Vec       xk, z, u, diff, zold, zdiff, temp;
477   PetscReal mu;
478 
479   PetscFunctionBegin;
480   xk    = ctx->workRight[4];
481   z     = ctx->workRight[5];
482   u     = ctx->workRight[6];
483   diff  = ctx->workRight[7];
484   zold  = ctx->workRight[8];
485   zdiff = ctx->workRight[9];
486   temp  = ctx->workRight[11];
487   mu    = ctx->mu;
488   PetscCall(VecSet(u, 0.));
489   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1));
490   PetscCall(TaoSetType(tao1, TAONLS));
491   PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void *)ctx));
492   PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void *)ctx));
493   PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void *)ctx));
494   PetscCall(VecSet(xk, 0.));
495   PetscCall(TaoSetSolution(tao1, xk));
496   PetscCall(TaoSetOptionsPrefix(tao1, "misfit_"));
497   PetscCall(TaoSetFromOptions(tao1));
498   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2));
499   if (ctx->p == NORM_2) {
500     PetscCall(TaoSetType(tao2, TAONLS));
501     PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void *)ctx));
502     PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void *)ctx));
503     PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void *)ctx));
504   }
505   PetscCall(VecSet(z, 0.));
506   PetscCall(TaoSetSolution(tao2, z));
507   PetscCall(TaoSetOptionsPrefix(tao2, "reg_"));
508   PetscCall(TaoSetFromOptions(tao2));
509 
510   for (i = 0; i < ctx->iter; i++) {
511     PetscCall(VecCopy(z, zold));
512     PetscCall(TaoSolve(tao1)); /* Updates xk */
513     if (ctx->p == NORM_1) {
514       PetscCall(VecWAXPY(temp, 1., xk, u));
515       PetscCall(TaoSoftThreshold(temp, -ctx->alpha / mu, ctx->alpha / mu, z));
516     } else {
517       PetscCall(TaoSolve(tao2)); /* Update zk */
518     }
519     /* u = u + xk -z */
520     PetscCall(VecAXPBYPCZ(u, 1., -1., 1., xk, z));
521     /* r_norm : norm(x-z) */
522     PetscCall(VecWAXPY(diff, -1., z, xk));
523     PetscCall(VecNorm(diff, NORM_2, &r_norm));
524     /* s_norm : norm(-mu(z-zold)) */
525     PetscCall(VecWAXPY(zdiff, -1., zold, z));
526     PetscCall(VecNorm(zdiff, NORM_2, &s_norm));
527     s_norm = s_norm * mu;
528     /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/
529     PetscCall(VecNorm(xk, NORM_2, &x_norm));
530     PetscCall(VecNorm(z, NORM_2, &z_norm));
531     primal = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * PetscMax(x_norm, z_norm);
532     /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/
533     PetscCall(VecNorm(u, NORM_2, &u_norm));
534     dual = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * u_norm * mu;
535     PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1), "Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double)r_norm, (double)s_norm));
536     if (r_norm < primal && s_norm < dual) break;
537   }
538   PetscCall(VecCopy(xk, x));
539   PetscCall(TaoDestroy(&tao1));
540   PetscCall(TaoDestroy(&tao2));
541   PetscFunctionReturn(PETSC_SUCCESS);
542 }
543 
544 /* Second order Taylor remainder convergence test */
545 static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C)
546 {
547   PetscReal  h, J, temp;
548   PetscInt   i, j;
549   PetscInt   numValues;
550   PetscReal  Jx, Jxhat_comp, Jxhat_pred;
551   PetscReal *Js, *hs;
552   PetscReal  gdotdx;
553   PetscReal  minrate = PETSC_MAX_REAL;
554   MPI_Comm   comm    = PetscObjectComm((PetscObject)x);
555   Vec        g, dx, xhat;
556 
557   PetscFunctionBegin;
558   PetscCall(VecDuplicate(x, &g));
559   PetscCall(VecDuplicate(x, &xhat));
560   /* choose a perturbation direction */
561   PetscCall(VecDuplicate(x, &dx));
562   PetscCall(VecSetRandom(dx, ctx->rctx));
563   /* evaluate objective at x: J(x) */
564   PetscCall(TaoComputeObjective(tao, x, &Jx));
565   /* evaluate gradient at x, save in vector g */
566   PetscCall(TaoComputeGradient(tao, x, g));
567   PetscCall(VecDot(g, dx, &gdotdx));
568 
569   for (numValues = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor) numValues++;
570   PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs));
571   for (i = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor, i++) {
572     PetscCall(VecWAXPY(xhat, h, dx, x));
573     PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp));
574     /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */
575     Jxhat_pred = Jx + h * gdotdx;
576     /* Vector to dJdm scalar? Dot?*/
577     J = PetscAbsReal(Jxhat_comp - Jxhat_pred);
578     PetscCall(PetscPrintf(comm, "J(xhat): %g, predicted: %g, diff %g\n", (double)Jxhat_comp, (double)Jxhat_pred, (double)J));
579     Js[i] = J;
580     hs[i] = h;
581   }
582   for (j = 1; j < numValues; j++) {
583     temp = PetscLogReal(Js[j] / Js[j - 1]) / PetscLogReal(hs[j] / hs[j - 1]);
584     PetscCall(PetscPrintf(comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j - 1, (double)temp));
585     minrate = PetscMin(minrate, temp);
586   }
587   /* If O is not ~2, then the test is wrong */
588   PetscCall(PetscFree2(Js, hs));
589   *C = minrate;
590   PetscCall(VecDestroy(&dx));
591   PetscCall(VecDestroy(&xhat));
592   PetscCall(VecDestroy(&g));
593   PetscFunctionReturn(PETSC_SUCCESS);
594 }
595 
596 int main(int argc, char **argv)
597 {
598   UserCtx ctx;
599   Tao     tao;
600   Vec     x;
601   Mat     H;
602 
603   PetscFunctionBeginUser;
604   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
605   PetscCall(PetscNew(&ctx));
606   PetscCall(ConfigureContext(ctx));
607   /* Define two functions that could pass as objectives to TaoSetObjective(): one
608    * for the misfit component, and one for the regularization component */
609   /* ObjectiveMisfit() and ObjectiveRegularization() */
610 
611   /* Define a single function that calls both components adds them together: the complete objective,
612    * in the absence of a Tao implementation that handles separability */
613   /* ObjectiveComplete() */
614   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
615   PetscCall(TaoSetType(tao, TAONM));
616   PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void *)ctx));
617   PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void *)ctx));
618   PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H));
619   PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void *)ctx));
620   PetscCall(MatCreateVecs(ctx->F, NULL, &x));
621   PetscCall(VecSet(x, 0.));
622   PetscCall(TaoSetSolution(tao, x));
623   PetscCall(TaoSetFromOptions(tao));
624   if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx, x));
625   else PetscCall(TaoSolve(tao));
626   /* examine solution */
627   PetscCall(VecViewFromOptions(x, NULL, "-view_sol"));
628   if (ctx->taylor) {
629     PetscReal rate;
630     PetscCall(TaylorTest(ctx, tao, x, &rate));
631   }
632   if (ctx->soft) PetscCall(TaoSoftThreshold(x, 0., 0., x));
633   PetscCall(MatDestroy(&H));
634   PetscCall(TaoDestroy(&tao));
635   PetscCall(VecDestroy(&x));
636   PetscCall(DestroyContext(&ctx));
637   PetscCall(PetscFinalize());
638   return 0;
639 }
640 
641 /*TEST
642 
643   build:
644     requires: !complex
645 
646   test:
647     suffix: 0
648     args:
649 
650   test:
651     suffix: l1_1
652     args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1
653 
654   test:
655     suffix: hessian_1
656     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls
657 
658   test:
659     suffix: hessian_2
660     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls
661 
662   test:
663     suffix: nm_1
664     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50
665 
666   test:
667     suffix: nm_2
668     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50
669 
670   test:
671     suffix: lmvm_1
672     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40
673 
674   test:
675     suffix: lmvm_2
676     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15
677 
678   test:
679     suffix: soft_threshold_admm_1
680     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm
681 
682   test:
683     suffix: hessian_admm_1
684     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls
685 
686   test:
687     suffix: hessian_admm_2
688     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls
689 
690   test:
691     suffix: nm_admm_1
692     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm
693 
694   test:
695     suffix: nm_admm_2
696     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7
697 
698   test:
699     suffix: lmvm_admm_1
700     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
701 
702   test:
703     suffix: lmvm_admm_2
704     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
705 
706   test:
707     suffix: soft
708     args: -taylor 0 -soft 1
709     output_file: output/empty.out
710 
711 TEST*/
712