/* Help text handed to PetscInitialize() in main() */
static char help[] = "Simple example to test separable objective optimizers.\n";
2
3 #include <petsc.h>
4 #include <petsctao.h>
5 #include <petscvec.h>
6 #include <petscmath.h>
7
/* Number of scratch vectors kept in UserCtx.workLeft (allocated in SetupWorkspace()) */
#define NWORKLEFT 4
/* Number of scratch vectors kept in UserCtx.workRight; TaoSolveADMM() uses slots up to index 11 */
#define NWORKRIGHT 12
10
/* Application context passed (as void *) to every objective / gradient / Hessian callback in this file */
typedef struct _UserCtx {
  PetscInt    m;                     /* The row dimension of F */
  PetscInt    n;                     /* The column dimension of F */
  PetscInt    matops;                /* Matrix format. 0 for stencil, 1 for random */
  PetscInt    iter;                  /* Number of iterations for ADMM */
  PetscReal   hStart;                /* Starting point for Taylor test */
  PetscReal   hFactor;               /* Taylor test step factor */
  PetscReal   hMin;                  /* Taylor test end goal */
  PetscReal   alpha;                 /* regularization constant applied to || x ||_p */
  PetscReal   eps;                   /* small constant for approximating gradient of || x ||_1 */
  PetscReal   mu;                    /* the augmented Lagrangian term in ADMM */
  PetscReal   abstol;                /* Absolute stopping criterion for ADMM (-abstol) */
  PetscReal   reltol;                /* Relative stopping criterion for ADMM (-reltol) */
  Mat         F;                     /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */
  Mat         W;                     /* Workspace matrix. F^T F, built with MatTransposeMatMult() */
  Mat         Hm;                    /* Hessian of the misfit term (used by the ADMM x-update solver) */
  Mat         Hr;                    /* Hessian of the regularization term (used by the ADMM z-update solver) */
  Vec         d;                     /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */
  Vec         workLeft[NWORKLEFT];   /* Workspace for temporary vec */
  Vec         workRight[NWORKRIGHT]; /* Workspace for temporary vec */
  NormType    p;                     /* Norm used by the regularization term (-p); the derivative callbacks handle NORM_1 and NORM_2 only */
  PetscRandom rctx;                  /* Random context used for d, the random F and the Taylor test direction */
  PetscBool   soft;                  /* Flag for testing the soft threshold no-op case */
  PetscBool   taylor;                /* Flag to determine whether to run Taylor test or not */
  PetscBool   use_admm;              /* Flag to determine whether to solve with TaoSolveADMM() instead of TaoSolve() */
} *UserCtx;
37
CreateRHS(UserCtx ctx)38 static PetscErrorCode CreateRHS(UserCtx ctx)
39 {
40 PetscFunctionBegin;
41 /* build the rhs d in ctx */
42 PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->d));
43 PetscCall(VecSetSizes(ctx->d, PETSC_DECIDE, ctx->m));
44 PetscCall(VecSetFromOptions(ctx->d));
45 PetscCall(VecSetRandom(ctx->d, ctx->rctx));
46 PetscFunctionReturn(PETSC_SUCCESS);
47 }
48
CreateMatrix(UserCtx ctx)49 static PetscErrorCode CreateMatrix(UserCtx ctx)
50 {
51 PetscInt Istart, Iend, i, j, Ii, gridN, I_n, I_s, I_e, I_w;
52 PetscLogStage stage;
53
54 PetscFunctionBegin;
55 /* build the matrix F in ctx */
56 PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->F));
57 PetscCall(MatSetSizes(ctx->F, PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n));
58 PetscCall(MatSetType(ctx->F, MATAIJ)); /* TODO: Decide specific SetType other than dummy*/
59 PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/
60 PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL));
61 PetscCall(MatSetUp(ctx->F));
62 PetscCall(MatGetOwnershipRange(ctx->F, &Istart, &Iend));
63 PetscCall(PetscLogStageRegister("Assembly", &stage));
64 PetscCall(PetscLogStagePush(stage));
65
66 /* Set matrix elements in 2-D five point stencil format. */
67 if (!ctx->matops) {
68 PetscCheck(ctx->m == ctx->n, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square");
69 gridN = (PetscInt)PetscSqrtReal((PetscReal)ctx->m);
70 PetscCheck(gridN * gridN == ctx->m, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square");
71 for (Ii = Istart; Ii < Iend; Ii++) {
72 i = Ii / gridN;
73 j = Ii % gridN;
74 I_n = i * gridN + j + 1;
75 if (j + 1 >= gridN) I_n = -1;
76 I_s = i * gridN + j - 1;
77 if (j - 1 < 0) I_s = -1;
78 I_e = (i + 1) * gridN + j;
79 if (i + 1 >= gridN) I_e = -1;
80 I_w = (i - 1) * gridN + j;
81 if (i - 1 < 0) I_w = -1;
82 PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES));
83 PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES));
84 PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES));
85 PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES));
86 PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES));
87 }
88 } else PetscCall(MatSetRandom(ctx->F, ctx->rctx));
89 PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY));
90 PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY));
91 PetscCall(PetscLogStagePop());
92 /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */
93 if (!ctx->matops) PetscCall(MatSetOption(ctx->F, MAT_SYMMETRIC, PETSC_TRUE));
94 PetscCall(MatTransposeMatMult(ctx->F, ctx->F, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &ctx->W));
95 /* Setup Hessian Workspace in same shape as W */
96 PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hm));
97 PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hr));
98 PetscFunctionReturn(PETSC_SUCCESS);
99 }
100
SetupWorkspace(UserCtx ctx)101 static PetscErrorCode SetupWorkspace(UserCtx ctx)
102 {
103 PetscInt i;
104
105 PetscFunctionBegin;
106 PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0]));
107 for (i = 1; i < NWORKLEFT; i++) PetscCall(VecDuplicate(ctx->workLeft[0], &ctx->workLeft[i]));
108 for (i = 1; i < NWORKRIGHT; i++) PetscCall(VecDuplicate(ctx->workRight[0], &ctx->workRight[i]));
109 PetscFunctionReturn(PETSC_SUCCESS);
110 }
111
ConfigureContext(UserCtx ctx)112 static PetscErrorCode ConfigureContext(UserCtx ctx)
113 {
114 PetscFunctionBegin;
115 ctx->m = 16;
116 ctx->n = 16;
117 ctx->eps = 1.e-3;
118 ctx->abstol = 1.e-4;
119 ctx->reltol = 1.e-2;
120 ctx->hStart = 1.;
121 ctx->hMin = 1.e-3;
122 ctx->hFactor = 0.5;
123 ctx->alpha = 1.;
124 ctx->mu = 1.0;
125 ctx->matops = 0;
126 ctx->iter = 10;
127 ctx->p = NORM_2;
128 ctx->soft = PETSC_FALSE;
129 ctx->taylor = PETSC_TRUE;
130 ctx->use_admm = PETSC_FALSE;
131 PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "ex4.c");
132 PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &ctx->m, NULL));
133 PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &ctx->n, NULL));
134 PetscCall(PetscOptionsInt("-matrix_format", "Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &ctx->matops, NULL));
135 PetscCall(PetscOptionsInt("-iter", "Iteration number ADMM", "ex4.c", ctx->iter, &ctx->iter, NULL));
136 PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &ctx->alpha, NULL));
137 PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &ctx->eps, NULL));
138 PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &ctx->mu, NULL));
139 PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &ctx->hStart, NULL));
140 PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &ctx->hFactor, NULL));
141 PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &ctx->hMin, NULL));
142 PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &ctx->abstol, NULL));
143 PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &ctx->reltol, NULL));
144 PetscCall(PetscOptionsBool("-taylor", "Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &ctx->taylor, NULL));
145 PetscCall(PetscOptionsBool("-soft", "Flag for testing soft threshold no-op case. Default is false.", "ex4.c", ctx->soft, &ctx->soft, NULL));
146 PetscCall(PetscOptionsBool("-use_admm", "Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &ctx->use_admm, NULL));
147 PetscCall(PetscOptionsEnum("-p", "Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *)&ctx->p, NULL));
148 PetscOptionsEnd();
149 /* Creating random ctx */
150 PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &ctx->rctx));
151 PetscCall(PetscRandomSetFromOptions(ctx->rctx));
152 PetscCall(CreateMatrix(ctx));
153 PetscCall(CreateRHS(ctx));
154 PetscCall(SetupWorkspace(ctx));
155 PetscFunctionReturn(PETSC_SUCCESS);
156 }
157
DestroyContext(UserCtx * ctx)158 static PetscErrorCode DestroyContext(UserCtx *ctx)
159 {
160 PetscInt i;
161
162 PetscFunctionBegin;
163 PetscCall(MatDestroy(&(*ctx)->F));
164 PetscCall(MatDestroy(&(*ctx)->W));
165 PetscCall(MatDestroy(&(*ctx)->Hm));
166 PetscCall(MatDestroy(&(*ctx)->Hr));
167 PetscCall(VecDestroy(&(*ctx)->d));
168 for (i = 0; i < NWORKLEFT; i++) PetscCall(VecDestroy(&(*ctx)->workLeft[i]));
169 for (i = 0; i < NWORKRIGHT; i++) PetscCall(VecDestroy(&(*ctx)->workRight[i]));
170 PetscCall(PetscRandomDestroy(&(*ctx)->rctx));
171 PetscCall(PetscFree(*ctx));
172 PetscFunctionReturn(PETSC_SUCCESS);
173 }
174
175 /* compute (1/2) * ||F x - d||^2 */
ObjectiveMisfit(Tao tao,Vec x,PetscReal * J,void * _ctx)176 static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx)
177 {
178 UserCtx ctx = (UserCtx)_ctx;
179 Vec y;
180
181 PetscFunctionBegin;
182 y = ctx->workLeft[0];
183 PetscCall(MatMult(ctx->F, x, y));
184 PetscCall(VecAXPY(y, -1., ctx->d));
185 PetscCall(VecDot(y, y, J));
186 *J *= 0.5;
187 PetscFunctionReturn(PETSC_SUCCESS);
188 }
189
190 /* compute V = FTFx - FTd */
GradientMisfit(Tao tao,Vec x,Vec V,void * _ctx)191 static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx)
192 {
193 UserCtx ctx = (UserCtx)_ctx;
194 Vec FTFx, FTd;
195
196 PetscFunctionBegin;
197 /* work1 is A^T Ax, work2 is Ab, W is A^T A*/
198 FTFx = ctx->workRight[0];
199 FTd = ctx->workRight[1];
200 PetscCall(MatMult(ctx->W, x, FTFx));
201 PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd));
202 PetscCall(VecWAXPY(V, -1., FTd, FTFx));
203 PetscFunctionReturn(PETSC_SUCCESS);
204 }
205
206 /* returns FTF */
HessianMisfit(Tao tao,Vec x,Mat H,Mat Hpre,void * _ctx)207 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
208 {
209 UserCtx ctx = (UserCtx)_ctx;
210
211 PetscFunctionBegin;
212 if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
213 if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN));
214 PetscFunctionReturn(PETSC_SUCCESS);
215 }
216
217 /* computes augment Lagrangian objective (with scaled dual):
218 * 0.5 * ||F x - d||^2 + 0.5 * mu ||x - z + u||^2 */
ObjectiveMisfitADMM(Tao tao,Vec x,PetscReal * J,void * _ctx)219 static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx)
220 {
221 UserCtx ctx = (UserCtx)_ctx;
222 PetscReal mu, workNorm, misfit;
223 Vec z, u, temp;
224
225 PetscFunctionBegin;
226 mu = ctx->mu;
227 z = ctx->workRight[5];
228 u = ctx->workRight[6];
229 temp = ctx->workRight[10];
230 /* misfit = f(x) */
231 PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx));
232 PetscCall(VecCopy(x, temp));
233 /* temp = x - z + u */
234 PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
235 /* workNorm = ||x - z + u||^2 */
236 PetscCall(VecDot(temp, temp, &workNorm));
237 /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */
238 *J = misfit + 0.5 * mu * workNorm;
239 PetscFunctionReturn(PETSC_SUCCESS);
240 }
241
242 /* computes FTFx - FTd mu*(x - z + u) */
GradientMisfitADMM(Tao tao,Vec x,Vec V,void * _ctx)243 static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx)
244 {
245 UserCtx ctx = (UserCtx)_ctx;
246 PetscReal mu;
247 Vec z, u, temp;
248
249 PetscFunctionBegin;
250 mu = ctx->mu;
251 z = ctx->workRight[5];
252 u = ctx->workRight[6];
253 temp = ctx->workRight[10];
254 PetscCall(GradientMisfit(tao, x, V, _ctx));
255 PetscCall(VecCopy(x, temp));
256 /* temp = x - z + u */
257 PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
258 /* V = FTFx - FTd mu*(x - z + u) */
259 PetscCall(VecAXPY(V, mu, temp));
260 PetscFunctionReturn(PETSC_SUCCESS);
261 }
262
263 /* returns FTF + diag(mu) */
HessianMisfitADMM(Tao tao,Vec x,Mat H,Mat Hpre,void * _ctx)264 static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
265 {
266 UserCtx ctx = (UserCtx)_ctx;
267
268 PetscFunctionBegin;
269 PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
270 PetscCall(MatShift(H, ctx->mu));
271 if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
272 PetscFunctionReturn(PETSC_SUCCESS);
273 }
274
275 /* computes || x ||_p (mult by 0.5 in case of NORM_2) */
ObjectiveRegularization(Tao tao,Vec x,PetscReal * J,void * _ctx)276 static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx)
277 {
278 UserCtx ctx = (UserCtx)_ctx;
279 PetscReal norm;
280
281 PetscFunctionBegin;
282 *J = 0;
283 PetscCall(VecNorm(x, ctx->p, &norm));
284 if (ctx->p == NORM_2) norm = 0.5 * norm * norm;
285 *J = ctx->alpha * norm;
286 PetscFunctionReturn(PETSC_SUCCESS);
287 }
288
289 /* NORM_2 Case: return x
290 * NORM_1 Case: x/(|x| + eps)
291 * Else: TODO */
GradientRegularization(Tao tao,Vec x,Vec V,void * _ctx)292 static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx)
293 {
294 UserCtx ctx = (UserCtx)_ctx;
295 PetscReal eps = ctx->eps;
296
297 PetscFunctionBegin;
298 if (ctx->p == NORM_2) {
299 PetscCall(VecCopy(x, V));
300 } else if (ctx->p == NORM_1) {
301 PetscCall(VecCopy(x, ctx->workRight[1]));
302 PetscCall(VecAbs(ctx->workRight[1]));
303 PetscCall(VecShift(ctx->workRight[1], eps));
304 PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1]));
305 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
306 PetscFunctionReturn(PETSC_SUCCESS);
307 }
308
309 /* NORM_2 Case: returns diag(mu)
310 * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) */
HessianRegularization(Tao tao,Vec x,Mat H,Mat Hpre,void * _ctx)311 static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
312 {
313 UserCtx ctx = (UserCtx)_ctx;
314 PetscReal eps = ctx->eps;
315 Vec copy1, copy2, copy3;
316
317 PetscFunctionBegin;
318 if (ctx->p == NORM_2) {
319 /* Identity matrix scaled by mu */
320 PetscCall(MatZeroEntries(H));
321 PetscCall(MatShift(H, ctx->mu));
322 if (Hpre != H) {
323 PetscCall(MatZeroEntries(Hpre));
324 PetscCall(MatShift(Hpre, ctx->mu));
325 }
326 } else if (ctx->p == NORM_1) {
327 /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */
328 copy1 = ctx->workRight[1];
329 copy2 = ctx->workRight[2];
330 copy3 = ctx->workRight[3];
331 /* copy1 : 1/sqrt(x_i^2 + eps) */
332 PetscCall(VecCopy(x, copy1));
333 PetscCall(VecPow(copy1, 2));
334 PetscCall(VecShift(copy1, eps));
335 PetscCall(VecSqrtAbs(copy1));
336 PetscCall(VecReciprocal(copy1));
337 /* copy2: x_i^2.*/
338 PetscCall(VecCopy(x, copy2));
339 PetscCall(VecPow(copy2, 2));
340 /* copy3: abs(x_i^2 + eps) */
341 PetscCall(VecCopy(x, copy3));
342 PetscCall(VecPow(copy3, 2));
343 PetscCall(VecShift(copy3, eps));
344 PetscCall(VecAbs(copy3));
345 /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */
346 PetscCall(VecPointwiseDivide(copy2, copy2, copy3));
347 PetscCall(VecScale(copy2, -1.));
348 PetscCall(VecShift(copy2, 1.));
349 PetscCall(VecAXPY(copy1, 1., copy2));
350 PetscCall(VecScale(copy1, ctx->mu));
351 PetscCall(MatZeroEntries(H));
352 PetscCall(MatDiagonalSet(H, copy1, INSERT_VALUES));
353 if (Hpre != H) {
354 PetscCall(MatZeroEntries(Hpre));
355 PetscCall(MatDiagonalSet(Hpre, copy1, INSERT_VALUES));
356 }
357 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
358 PetscFunctionReturn(PETSC_SUCCESS);
359 }
360
361 /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2
362 * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */
ObjectiveRegularizationADMM(Tao tao,Vec z,PetscReal * J,void * _ctx)363 static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx)
364 {
365 UserCtx ctx = (UserCtx)_ctx;
366 PetscReal mu, workNorm, reg;
367 Vec x, u, temp;
368
369 PetscFunctionBegin;
370 mu = ctx->mu;
371 x = ctx->workRight[4];
372 u = ctx->workRight[6];
373 temp = ctx->workRight[10];
374 PetscCall(ObjectiveRegularization(tao, z, ®, _ctx));
375 PetscCall(VecCopy(z, temp));
376 /* temp = x + u -z */
377 PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
378 /* workNorm = ||x + u - z ||^2 */
379 PetscCall(VecDot(temp, temp, &workNorm));
380 *J = reg + 0.5 * mu * workNorm;
381 PetscFunctionReturn(PETSC_SUCCESS);
382 }
383
384 /* NORM_2 Case: x - mu*(x + u - z)
385 * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z)
386 * Else: TODO */
GradientRegularizationADMM(Tao tao,Vec z,Vec V,void * _ctx)387 static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx)
388 {
389 UserCtx ctx = (UserCtx)_ctx;
390 PetscReal mu;
391 Vec x, u, temp;
392
393 PetscFunctionBegin;
394 mu = ctx->mu;
395 x = ctx->workRight[4];
396 u = ctx->workRight[6];
397 temp = ctx->workRight[10];
398 PetscCall(GradientRegularization(tao, z, V, _ctx));
399 PetscCall(VecCopy(z, temp));
400 /* temp = x + u -z */
401 PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
402 PetscCall(VecAXPY(V, -mu, temp));
403 PetscFunctionReturn(PETSC_SUCCESS);
404 }
405
406 /* NORM_2 Case: returns diag(mu)
407 * NORM_1 Case: FTF + diag(mu) */
HessianRegularizationADMM(Tao tao,Vec x,Mat H,Mat Hpre,void * _ctx)408 static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
409 {
410 UserCtx ctx = (UserCtx)_ctx;
411
412 PetscFunctionBegin;
413 if (ctx->p == NORM_2) {
414 /* Identity matrix scaled by mu */
415 PetscCall(MatZeroEntries(H));
416 PetscCall(MatShift(H, ctx->mu));
417 if (Hpre != H) {
418 PetscCall(MatZeroEntries(Hpre));
419 PetscCall(MatShift(Hpre, ctx->mu));
420 }
421 } else if (ctx->p == NORM_1) {
422 PetscCall(HessianMisfit(tao, x, H, Hpre, (void *)ctx));
423 PetscCall(MatShift(H, ctx->mu));
424 if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu));
425 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
426 PetscFunctionReturn(PETSC_SUCCESS);
427 }
428
429 /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p
430 * NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */
ObjectiveComplete(Tao tao,Vec x,PetscReal * J,PetscCtx ctx)431 static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, PetscCtx ctx)
432 {
433 PetscReal Jm, Jr;
434
435 PetscFunctionBegin;
436 PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx));
437 PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx));
438 *J = Jm + Jr;
439 PetscFunctionReturn(PETSC_SUCCESS);
440 }
441
442 /* NORM_2 Case: FTFx - FTd + x
443 * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */
GradientComplete(Tao tao,Vec x,Vec V,PetscCtx ctx)444 static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, PetscCtx ctx)
445 {
446 UserCtx cntx = (UserCtx)ctx;
447
448 PetscFunctionBegin;
449 PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx));
450 PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx));
451 PetscCall(VecWAXPY(V, 1, cntx->workRight[2], cntx->workRight[3]));
452 PetscFunctionReturn(PETSC_SUCCESS);
453 }
454
455 /* NORM_2 Case: diag(mu) + FTF
456 * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF */
HessianComplete(Tao tao,Vec x,Mat H,Mat Hpre,PetscCtx ctx)457 static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, PetscCtx ctx)
458 {
459 Mat tempH;
460
461 PetscFunctionBegin;
462 PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH));
463 PetscCall(HessianMisfit(tao, x, H, H, ctx));
464 PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx));
465 PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN));
466 if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
467 PetscCall(MatDestroy(&tempH));
468 PetscFunctionReturn(PETSC_SUCCESS);
469 }
470
TaoSolveADMM(UserCtx ctx,Vec x)471 static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x)
472 {
473 PetscInt i;
474 PetscReal u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm;
475 Tao tao1, tao2;
476 Vec xk, z, u, diff, zold, zdiff, temp;
477 PetscReal mu;
478
479 PetscFunctionBegin;
480 xk = ctx->workRight[4];
481 z = ctx->workRight[5];
482 u = ctx->workRight[6];
483 diff = ctx->workRight[7];
484 zold = ctx->workRight[8];
485 zdiff = ctx->workRight[9];
486 temp = ctx->workRight[11];
487 mu = ctx->mu;
488 PetscCall(VecSet(u, 0.));
489 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1));
490 PetscCall(TaoSetType(tao1, TAONLS));
491 PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void *)ctx));
492 PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void *)ctx));
493 PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void *)ctx));
494 PetscCall(VecSet(xk, 0.));
495 PetscCall(TaoSetSolution(tao1, xk));
496 PetscCall(TaoSetOptionsPrefix(tao1, "misfit_"));
497 PetscCall(TaoSetFromOptions(tao1));
498 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2));
499 if (ctx->p == NORM_2) {
500 PetscCall(TaoSetType(tao2, TAONLS));
501 PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void *)ctx));
502 PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void *)ctx));
503 PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void *)ctx));
504 }
505 PetscCall(VecSet(z, 0.));
506 PetscCall(TaoSetSolution(tao2, z));
507 PetscCall(TaoSetOptionsPrefix(tao2, "reg_"));
508 PetscCall(TaoSetFromOptions(tao2));
509
510 for (i = 0; i < ctx->iter; i++) {
511 PetscCall(VecCopy(z, zold));
512 PetscCall(TaoSolve(tao1)); /* Updates xk */
513 if (ctx->p == NORM_1) {
514 PetscCall(VecWAXPY(temp, 1., xk, u));
515 PetscCall(TaoSoftThreshold(temp, -ctx->alpha / mu, ctx->alpha / mu, z));
516 } else {
517 PetscCall(TaoSolve(tao2)); /* Update zk */
518 }
519 /* u = u + xk -z */
520 PetscCall(VecAXPBYPCZ(u, 1., -1., 1., xk, z));
521 /* r_norm : norm(x-z) */
522 PetscCall(VecWAXPY(diff, -1., z, xk));
523 PetscCall(VecNorm(diff, NORM_2, &r_norm));
524 /* s_norm : norm(-mu(z-zold)) */
525 PetscCall(VecWAXPY(zdiff, -1., zold, z));
526 PetscCall(VecNorm(zdiff, NORM_2, &s_norm));
527 s_norm = s_norm * mu;
528 /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/
529 PetscCall(VecNorm(xk, NORM_2, &x_norm));
530 PetscCall(VecNorm(z, NORM_2, &z_norm));
531 primal = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * PetscMax(x_norm, z_norm);
532 /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/
533 PetscCall(VecNorm(u, NORM_2, &u_norm));
534 dual = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * u_norm * mu;
535 PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1), "Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double)r_norm, (double)s_norm));
536 if (r_norm < primal && s_norm < dual) break;
537 }
538 PetscCall(VecCopy(xk, x));
539 PetscCall(TaoDestroy(&tao1));
540 PetscCall(TaoDestroy(&tao2));
541 PetscFunctionReturn(PETSC_SUCCESS);
542 }
543
544 /* Second order Taylor remainder convergence test */
TaylorTest(UserCtx ctx,Tao tao,Vec x,PetscReal * C)545 static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C)
546 {
547 PetscReal h, J, temp;
548 PetscInt i, j;
549 PetscInt numValues;
550 PetscReal Jx, Jxhat_comp, Jxhat_pred;
551 PetscReal *Js, *hs;
552 PetscReal gdotdx;
553 PetscReal minrate = PETSC_MAX_REAL;
554 MPI_Comm comm = PetscObjectComm((PetscObject)x);
555 Vec g, dx, xhat;
556
557 PetscFunctionBegin;
558 PetscCall(VecDuplicate(x, &g));
559 PetscCall(VecDuplicate(x, &xhat));
560 /* choose a perturbation direction */
561 PetscCall(VecDuplicate(x, &dx));
562 PetscCall(VecSetRandom(dx, ctx->rctx));
563 /* evaluate objective at x: J(x) */
564 PetscCall(TaoComputeObjective(tao, x, &Jx));
565 /* evaluate gradient at x, save in vector g */
566 PetscCall(TaoComputeGradient(tao, x, g));
567 PetscCall(VecDot(g, dx, &gdotdx));
568
569 for (numValues = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor) numValues++;
570 PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs));
571 for (i = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor, i++) {
572 PetscCall(VecWAXPY(xhat, h, dx, x));
573 PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp));
574 /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */
575 Jxhat_pred = Jx + h * gdotdx;
576 /* Vector to dJdm scalar? Dot?*/
577 J = PetscAbsReal(Jxhat_comp - Jxhat_pred);
578 PetscCall(PetscPrintf(comm, "J(xhat): %g, predicted: %g, diff %g\n", (double)Jxhat_comp, (double)Jxhat_pred, (double)J));
579 Js[i] = J;
580 hs[i] = h;
581 }
582 for (j = 1; j < numValues; j++) {
583 temp = PetscLogReal(Js[j] / Js[j - 1]) / PetscLogReal(hs[j] / hs[j - 1]);
584 PetscCall(PetscPrintf(comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j - 1, (double)temp));
585 minrate = PetscMin(minrate, temp);
586 }
587 /* If O is not ~2, then the test is wrong */
588 PetscCall(PetscFree2(Js, hs));
589 *C = minrate;
590 PetscCall(VecDestroy(&dx));
591 PetscCall(VecDestroy(&xhat));
592 PetscCall(VecDestroy(&g));
593 PetscFunctionReturn(PETSC_SUCCESS);
594 }
595
main(int argc,char ** argv)596 int main(int argc, char **argv)
597 {
598 UserCtx ctx;
599 Tao tao;
600 Vec x;
601 Mat H;
602
603 PetscFunctionBeginUser;
604 PetscCall(PetscInitialize(&argc, &argv, NULL, help));
605 PetscCall(PetscNew(&ctx));
606 PetscCall(ConfigureContext(ctx));
607 /* Define two functions that could pass as objectives to TaoSetObjective(): one
608 * for the misfit component, and one for the regularization component */
609 /* ObjectiveMisfit() and ObjectiveRegularization() */
610
611 /* Define a single function that calls both components adds them together: the complete objective,
612 * in the absence of a Tao implementation that handles separability */
613 /* ObjectiveComplete() */
614 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
615 PetscCall(TaoSetType(tao, TAONM));
616 PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void *)ctx));
617 PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void *)ctx));
618 PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H));
619 PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void *)ctx));
620 PetscCall(MatCreateVecs(ctx->F, NULL, &x));
621 PetscCall(VecSet(x, 0.));
622 PetscCall(TaoSetSolution(tao, x));
623 PetscCall(TaoSetFromOptions(tao));
624 if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx, x));
625 else PetscCall(TaoSolve(tao));
626 /* examine solution */
627 PetscCall(VecViewFromOptions(x, NULL, "-view_sol"));
628 if (ctx->taylor) {
629 PetscReal rate;
630 PetscCall(TaylorTest(ctx, tao, x, &rate));
631 }
632 if (ctx->soft) PetscCall(TaoSoftThreshold(x, 0., 0., x));
633 PetscCall(MatDestroy(&H));
634 PetscCall(TaoDestroy(&tao));
635 PetscCall(VecDestroy(&x));
636 PetscCall(DestroyContext(&ctx));
637 PetscCall(PetscFinalize());
638 return 0;
639 }
640
641 /*TEST
642
643 build:
644 requires: !complex
645
646 test:
647 suffix: 0
648 args:
649
650 test:
651 suffix: l1_1
652 args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1
653
654 test:
655 suffix: hessian_1
656 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls
657
658 test:
659 suffix: hessian_2
660 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls
661
662 test:
663 suffix: nm_1
664 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50
665
666 test:
667 suffix: nm_2
668 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50
669
670 test:
671 suffix: lmvm_1
672 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40
673
674 test:
675 suffix: lmvm_2
676 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15
677
678 test:
679 suffix: soft_threshold_admm_1
680 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm
681
682 test:
683 suffix: hessian_admm_1
684 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls
685
686 test:
687 suffix: hessian_admm_2
688 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls
689
690 test:
691 suffix: nm_admm_1
692 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm
693
694 test:
695 suffix: nm_admm_2
696 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7
697
698 test:
699 suffix: lmvm_admm_1
700 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
701
702 test:
703 suffix: lmvm_admm_2
704 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
705
706 test:
707 suffix: soft
708 args: -taylor 0 -soft 1
709 output_file: output/empty.out
710
711 TEST*/
712