/* One-line description printed by PetscInitialize() for -help */
static char help[] = "Simple example to test separable "
                     "objective optimizers.\n";
2c4762a1bSJed Brown
3c4762a1bSJed Brown #include <petsc.h>
4c4762a1bSJed Brown #include <petsctao.h>
5c4762a1bSJed Brown #include <petscvec.h>
6c4762a1bSJed Brown #include <petscmath.h>
7c4762a1bSJed Brown
/* Number of work vectors held in the workLeft and workRight arrays of the user context */
#define NWORKLEFT  4
#define NWORKRIGHT 12
10c4762a1bSJed Brown
119371c9d4SSatish Balay typedef struct _UserCtx {
12c4762a1bSJed Brown PetscInt m; /* The row dimension of F */
13c4762a1bSJed Brown PetscInt n; /* The column dimension of F */
14c4762a1bSJed Brown PetscInt matops; /* Matrix format. 0 for stencil, 1 for random */
15be87f6c0SPierre Jolivet PetscInt iter; /* Number of iterations for ADMM */
16c4762a1bSJed Brown PetscReal hStart; /* Starting point for Taylor test */
17c4762a1bSJed Brown PetscReal hFactor; /* Taylor test step factor */
18c4762a1bSJed Brown PetscReal hMin; /* Taylor test end goal */
19c4762a1bSJed Brown PetscReal alpha; /* regularization constant applied to || x ||_p */
20c4762a1bSJed Brown PetscReal eps; /* small constant for approximating gradient of || x ||_1 */
21c4762a1bSJed Brown PetscReal mu; /* the augmented Lagrangian term in ADMM */
22c4762a1bSJed Brown PetscReal abstol;
23c4762a1bSJed Brown PetscReal reltol;
24c4762a1bSJed Brown Mat F; /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */
25c4762a1bSJed Brown Mat W; /* Workspace matrix. ATA */
26c4762a1bSJed Brown Mat Hm; /* Hessian Misfit*/
27c4762a1bSJed Brown Mat Hr; /* Hessian Reg*/
28c4762a1bSJed Brown Vec d; /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */
29c4762a1bSJed Brown Vec workLeft[NWORKLEFT]; /* Workspace for temporary vec */
30c4762a1bSJed Brown Vec workRight[NWORKRIGHT]; /* Workspace for temporary vec */
31c4762a1bSJed Brown NormType p;
32c4762a1bSJed Brown PetscRandom rctx;
3384430a0dSHansol Suh PetscBool soft;
34c4762a1bSJed Brown PetscBool taylor; /* Flag to determine whether to run Taylor test or not */
35c4762a1bSJed Brown PetscBool use_admm; /* Flag to determine whether to run Taylor test or not */
36c4762a1bSJed Brown } *UserCtx;
37c4762a1bSJed Brown
CreateRHS(UserCtx ctx)38d71ae5a4SJacob Faibussowitsch static PetscErrorCode CreateRHS(UserCtx ctx)
39d71ae5a4SJacob Faibussowitsch {
40c4762a1bSJed Brown PetscFunctionBegin;
41c4762a1bSJed Brown /* build the rhs d in ctx */
42f4f49eeaSPierre Jolivet PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->d));
439566063dSJacob Faibussowitsch PetscCall(VecSetSizes(ctx->d, PETSC_DECIDE, ctx->m));
449566063dSJacob Faibussowitsch PetscCall(VecSetFromOptions(ctx->d));
459566063dSJacob Faibussowitsch PetscCall(VecSetRandom(ctx->d, ctx->rctx));
463ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
47c4762a1bSJed Brown }
48c4762a1bSJed Brown
CreateMatrix(UserCtx ctx)49d71ae5a4SJacob Faibussowitsch static PetscErrorCode CreateMatrix(UserCtx ctx)
50d71ae5a4SJacob Faibussowitsch {
51c4762a1bSJed Brown PetscInt Istart, Iend, i, j, Ii, gridN, I_n, I_s, I_e, I_w;
52c4762a1bSJed Brown PetscLogStage stage;
53c4762a1bSJed Brown
54c4762a1bSJed Brown PetscFunctionBegin;
55c4762a1bSJed Brown /* build the matrix F in ctx */
56f4f49eeaSPierre Jolivet PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->F));
579566063dSJacob Faibussowitsch PetscCall(MatSetSizes(ctx->F, PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n));
589566063dSJacob Faibussowitsch PetscCall(MatSetType(ctx->F, MATAIJ)); /* TODO: Decide specific SetType other than dummy*/
599566063dSJacob Faibussowitsch PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/
609566063dSJacob Faibussowitsch PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL));
619566063dSJacob Faibussowitsch PetscCall(MatSetUp(ctx->F));
629566063dSJacob Faibussowitsch PetscCall(MatGetOwnershipRange(ctx->F, &Istart, &Iend));
639566063dSJacob Faibussowitsch PetscCall(PetscLogStageRegister("Assembly", &stage));
649566063dSJacob Faibussowitsch PetscCall(PetscLogStagePush(stage));
65c4762a1bSJed Brown
663c859ba3SBarry Smith /* Set matrix elements in 2-D five point stencil format. */
67f4f49eeaSPierre Jolivet if (!ctx->matops) {
683c859ba3SBarry Smith PetscCheck(ctx->m == ctx->n, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square");
69c4762a1bSJed Brown gridN = (PetscInt)PetscSqrtReal((PetscReal)ctx->m);
703c859ba3SBarry Smith PetscCheck(gridN * gridN == ctx->m, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square");
71c4762a1bSJed Brown for (Ii = Istart; Ii < Iend; Ii++) {
729371c9d4SSatish Balay i = Ii / gridN;
739371c9d4SSatish Balay j = Ii % gridN;
74c4762a1bSJed Brown I_n = i * gridN + j + 1;
75c4762a1bSJed Brown if (j + 1 >= gridN) I_n = -1;
76c4762a1bSJed Brown I_s = i * gridN + j - 1;
77c4762a1bSJed Brown if (j - 1 < 0) I_s = -1;
78c4762a1bSJed Brown I_e = (i + 1) * gridN + j;
79c4762a1bSJed Brown if (i + 1 >= gridN) I_e = -1;
80c4762a1bSJed Brown I_w = (i - 1) * gridN + j;
81c4762a1bSJed Brown if (i - 1 < 0) I_w = -1;
829566063dSJacob Faibussowitsch PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES));
839566063dSJacob Faibussowitsch PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES));
849566063dSJacob Faibussowitsch PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES));
859566063dSJacob Faibussowitsch PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES));
869566063dSJacob Faibussowitsch PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES));
87c4762a1bSJed Brown }
889566063dSJacob Faibussowitsch } else PetscCall(MatSetRandom(ctx->F, ctx->rctx));
899566063dSJacob Faibussowitsch PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY));
909566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY));
919566063dSJacob Faibussowitsch PetscCall(PetscLogStagePop());
92c4762a1bSJed Brown /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */
93f4f49eeaSPierre Jolivet if (!ctx->matops) PetscCall(MatSetOption(ctx->F, MAT_SYMMETRIC, PETSC_TRUE));
94fb842aefSJose E. Roman PetscCall(MatTransposeMatMult(ctx->F, ctx->F, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &ctx->W));
95c4762a1bSJed Brown /* Setup Hessian Workspace in same shape as W */
96f4f49eeaSPierre Jolivet PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hm));
97f4f49eeaSPierre Jolivet PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hr));
983ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
99c4762a1bSJed Brown }
100c4762a1bSJed Brown
SetupWorkspace(UserCtx ctx)101d71ae5a4SJacob Faibussowitsch static PetscErrorCode SetupWorkspace(UserCtx ctx)
102d71ae5a4SJacob Faibussowitsch {
103c4762a1bSJed Brown PetscInt i;
104c4762a1bSJed Brown
105c4762a1bSJed Brown PetscFunctionBegin;
1069566063dSJacob Faibussowitsch PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0]));
107f4f49eeaSPierre Jolivet for (i = 1; i < NWORKLEFT; i++) PetscCall(VecDuplicate(ctx->workLeft[0], &ctx->workLeft[i]));
108f4f49eeaSPierre Jolivet for (i = 1; i < NWORKRIGHT; i++) PetscCall(VecDuplicate(ctx->workRight[0], &ctx->workRight[i]));
1093ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
110c4762a1bSJed Brown }
111c4762a1bSJed Brown
ConfigureContext(UserCtx ctx)112d71ae5a4SJacob Faibussowitsch static PetscErrorCode ConfigureContext(UserCtx ctx)
113d71ae5a4SJacob Faibussowitsch {
114c4762a1bSJed Brown PetscFunctionBegin;
115c4762a1bSJed Brown ctx->m = 16;
116c4762a1bSJed Brown ctx->n = 16;
117c4762a1bSJed Brown ctx->eps = 1.e-3;
118c4762a1bSJed Brown ctx->abstol = 1.e-4;
119c4762a1bSJed Brown ctx->reltol = 1.e-2;
120c4762a1bSJed Brown ctx->hStart = 1.;
121c4762a1bSJed Brown ctx->hMin = 1.e-3;
122c4762a1bSJed Brown ctx->hFactor = 0.5;
123c4762a1bSJed Brown ctx->alpha = 1.;
124c4762a1bSJed Brown ctx->mu = 1.0;
125c4762a1bSJed Brown ctx->matops = 0;
126c4762a1bSJed Brown ctx->iter = 10;
127c4762a1bSJed Brown ctx->p = NORM_2;
12884430a0dSHansol Suh ctx->soft = PETSC_FALSE;
129c4762a1bSJed Brown ctx->taylor = PETSC_TRUE;
130c4762a1bSJed Brown ctx->use_admm = PETSC_FALSE;
131d0609cedSBarry Smith PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "ex4.c");
132f4f49eeaSPierre Jolivet PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &ctx->m, NULL));
133f4f49eeaSPierre Jolivet PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &ctx->n, NULL));
134f4f49eeaSPierre Jolivet PetscCall(PetscOptionsInt("-matrix_format", "Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &ctx->matops, NULL));
135f4f49eeaSPierre Jolivet PetscCall(PetscOptionsInt("-iter", "Iteration number ADMM", "ex4.c", ctx->iter, &ctx->iter, NULL));
136f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &ctx->alpha, NULL));
137f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &ctx->eps, NULL));
138f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &ctx->mu, NULL));
139f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &ctx->hStart, NULL));
140f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &ctx->hFactor, NULL));
141f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &ctx->hMin, NULL));
142f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &ctx->abstol, NULL));
143f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &ctx->reltol, NULL));
144f4f49eeaSPierre Jolivet PetscCall(PetscOptionsBool("-taylor", "Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &ctx->taylor, NULL));
145f4f49eeaSPierre Jolivet PetscCall(PetscOptionsBool("-soft", "Flag for testing soft threshold no-op case. Default is false.", "ex4.c", ctx->soft, &ctx->soft, NULL));
146f4f49eeaSPierre Jolivet PetscCall(PetscOptionsBool("-use_admm", "Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &ctx->use_admm, NULL));
147f4f49eeaSPierre Jolivet PetscCall(PetscOptionsEnum("-p", "Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *)&ctx->p, NULL));
148d0609cedSBarry Smith PetscOptionsEnd();
149c4762a1bSJed Brown /* Creating random ctx */
150f4f49eeaSPierre Jolivet PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &ctx->rctx));
1519566063dSJacob Faibussowitsch PetscCall(PetscRandomSetFromOptions(ctx->rctx));
1529566063dSJacob Faibussowitsch PetscCall(CreateMatrix(ctx));
1539566063dSJacob Faibussowitsch PetscCall(CreateRHS(ctx));
1549566063dSJacob Faibussowitsch PetscCall(SetupWorkspace(ctx));
1553ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
156c4762a1bSJed Brown }
157c4762a1bSJed Brown
DestroyContext(UserCtx * ctx)158d71ae5a4SJacob Faibussowitsch static PetscErrorCode DestroyContext(UserCtx *ctx)
159d71ae5a4SJacob Faibussowitsch {
160c4762a1bSJed Brown PetscInt i;
161c4762a1bSJed Brown
162c4762a1bSJed Brown PetscFunctionBegin;
16357508eceSPierre Jolivet PetscCall(MatDestroy(&(*ctx)->F));
16457508eceSPierre Jolivet PetscCall(MatDestroy(&(*ctx)->W));
16557508eceSPierre Jolivet PetscCall(MatDestroy(&(*ctx)->Hm));
16657508eceSPierre Jolivet PetscCall(MatDestroy(&(*ctx)->Hr));
16757508eceSPierre Jolivet PetscCall(VecDestroy(&(*ctx)->d));
16857508eceSPierre Jolivet for (i = 0; i < NWORKLEFT; i++) PetscCall(VecDestroy(&(*ctx)->workLeft[i]));
16957508eceSPierre Jolivet for (i = 0; i < NWORKRIGHT; i++) PetscCall(VecDestroy(&(*ctx)->workRight[i]));
17057508eceSPierre Jolivet PetscCall(PetscRandomDestroy(&(*ctx)->rctx));
1719566063dSJacob Faibussowitsch PetscCall(PetscFree(*ctx));
1723ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
173c4762a1bSJed Brown }
174c4762a1bSJed Brown
175c4762a1bSJed Brown /* compute (1/2) * ||F x - d||^2 */
ObjectiveMisfit(Tao tao,Vec x,PetscReal * J,void * _ctx)176d71ae5a4SJacob Faibussowitsch static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx)
177d71ae5a4SJacob Faibussowitsch {
178c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
179c4762a1bSJed Brown Vec y;
180c4762a1bSJed Brown
181c4762a1bSJed Brown PetscFunctionBegin;
182c4762a1bSJed Brown y = ctx->workLeft[0];
1839566063dSJacob Faibussowitsch PetscCall(MatMult(ctx->F, x, y));
1849566063dSJacob Faibussowitsch PetscCall(VecAXPY(y, -1., ctx->d));
1859566063dSJacob Faibussowitsch PetscCall(VecDot(y, y, J));
186c4762a1bSJed Brown *J *= 0.5;
1873ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
188c4762a1bSJed Brown }
189c4762a1bSJed Brown
190c4762a1bSJed Brown /* compute V = FTFx - FTd */
GradientMisfit(Tao tao,Vec x,Vec V,void * _ctx)191d71ae5a4SJacob Faibussowitsch static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx)
192d71ae5a4SJacob Faibussowitsch {
193c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
194c4762a1bSJed Brown Vec FTFx, FTd;
195c4762a1bSJed Brown
196c4762a1bSJed Brown PetscFunctionBegin;
197c4762a1bSJed Brown /* work1 is A^T Ax, work2 is Ab, W is A^T A*/
198c4762a1bSJed Brown FTFx = ctx->workRight[0];
199c4762a1bSJed Brown FTd = ctx->workRight[1];
2009566063dSJacob Faibussowitsch PetscCall(MatMult(ctx->W, x, FTFx));
2019566063dSJacob Faibussowitsch PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd));
2029566063dSJacob Faibussowitsch PetscCall(VecWAXPY(V, -1., FTd, FTFx));
2033ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
204c4762a1bSJed Brown }
205c4762a1bSJed Brown
206c4762a1bSJed Brown /* returns FTF */
HessianMisfit(Tao tao,Vec x,Mat H,Mat Hpre,void * _ctx)207d71ae5a4SJacob Faibussowitsch static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
208d71ae5a4SJacob Faibussowitsch {
209c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
210c4762a1bSJed Brown
211c4762a1bSJed Brown PetscFunctionBegin;
2129566063dSJacob Faibussowitsch if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
2139566063dSJacob Faibussowitsch if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN));
2143ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
215c4762a1bSJed Brown }
216c4762a1bSJed Brown
217c4762a1bSJed Brown /* computes augment Lagrangian objective (with scaled dual):
218c4762a1bSJed Brown * 0.5 * ||F x - d||^2 + 0.5 * mu ||x - z + u||^2 */
ObjectiveMisfitADMM(Tao tao,Vec x,PetscReal * J,void * _ctx)219d71ae5a4SJacob Faibussowitsch static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx)
220d71ae5a4SJacob Faibussowitsch {
221c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
222c4762a1bSJed Brown PetscReal mu, workNorm, misfit;
223c4762a1bSJed Brown Vec z, u, temp;
224c4762a1bSJed Brown
225c4762a1bSJed Brown PetscFunctionBegin;
226c4762a1bSJed Brown mu = ctx->mu;
227c4762a1bSJed Brown z = ctx->workRight[5];
228c4762a1bSJed Brown u = ctx->workRight[6];
229c4762a1bSJed Brown temp = ctx->workRight[10];
230c4762a1bSJed Brown /* misfit = f(x) */
2319566063dSJacob Faibussowitsch PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx));
2329566063dSJacob Faibussowitsch PetscCall(VecCopy(x, temp));
233c4762a1bSJed Brown /* temp = x - z + u */
2349566063dSJacob Faibussowitsch PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
235c4762a1bSJed Brown /* workNorm = ||x - z + u||^2 */
2369566063dSJacob Faibussowitsch PetscCall(VecDot(temp, temp, &workNorm));
237c4762a1bSJed Brown /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */
238c4762a1bSJed Brown *J = misfit + 0.5 * mu * workNorm;
2393ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
240c4762a1bSJed Brown }
241c4762a1bSJed Brown
242c4762a1bSJed Brown /* computes FTFx - FTd mu*(x - z + u) */
GradientMisfitADMM(Tao tao,Vec x,Vec V,void * _ctx)243d71ae5a4SJacob Faibussowitsch static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx)
244d71ae5a4SJacob Faibussowitsch {
245c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
246c4762a1bSJed Brown PetscReal mu;
247c4762a1bSJed Brown Vec z, u, temp;
248c4762a1bSJed Brown
249c4762a1bSJed Brown PetscFunctionBegin;
250c4762a1bSJed Brown mu = ctx->mu;
251c4762a1bSJed Brown z = ctx->workRight[5];
252c4762a1bSJed Brown u = ctx->workRight[6];
253c4762a1bSJed Brown temp = ctx->workRight[10];
2549566063dSJacob Faibussowitsch PetscCall(GradientMisfit(tao, x, V, _ctx));
2559566063dSJacob Faibussowitsch PetscCall(VecCopy(x, temp));
256c4762a1bSJed Brown /* temp = x - z + u */
2579566063dSJacob Faibussowitsch PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
258c4762a1bSJed Brown /* V = FTFx - FTd mu*(x - z + u) */
2599566063dSJacob Faibussowitsch PetscCall(VecAXPY(V, mu, temp));
2603ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
261c4762a1bSJed Brown }
262c4762a1bSJed Brown
263c4762a1bSJed Brown /* returns FTF + diag(mu) */
HessianMisfitADMM(Tao tao,Vec x,Mat H,Mat Hpre,void * _ctx)264d71ae5a4SJacob Faibussowitsch static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
265d71ae5a4SJacob Faibussowitsch {
266c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
267c4762a1bSJed Brown
268c4762a1bSJed Brown PetscFunctionBegin;
2699566063dSJacob Faibussowitsch PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
2709566063dSJacob Faibussowitsch PetscCall(MatShift(H, ctx->mu));
27148a46eb9SPierre Jolivet if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
2723ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
273c4762a1bSJed Brown }
274c4762a1bSJed Brown
275c4762a1bSJed Brown /* computes || x ||_p (mult by 0.5 in case of NORM_2) */
ObjectiveRegularization(Tao tao,Vec x,PetscReal * J,void * _ctx)276d71ae5a4SJacob Faibussowitsch static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx)
277d71ae5a4SJacob Faibussowitsch {
278c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
279c4762a1bSJed Brown PetscReal norm;
280c4762a1bSJed Brown
281c4762a1bSJed Brown PetscFunctionBegin;
282c4762a1bSJed Brown *J = 0;
2839566063dSJacob Faibussowitsch PetscCall(VecNorm(x, ctx->p, &norm));
284c4762a1bSJed Brown if (ctx->p == NORM_2) norm = 0.5 * norm * norm;
285c4762a1bSJed Brown *J = ctx->alpha * norm;
2863ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
287c4762a1bSJed Brown }
288c4762a1bSJed Brown
289c4762a1bSJed Brown /* NORM_2 Case: return x
290c4762a1bSJed Brown * NORM_1 Case: x/(|x| + eps)
291c4762a1bSJed Brown * Else: TODO */
GradientRegularization(Tao tao,Vec x,Vec V,void * _ctx)292d71ae5a4SJacob Faibussowitsch static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx)
293d71ae5a4SJacob Faibussowitsch {
294c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
295c4762a1bSJed Brown PetscReal eps = ctx->eps;
296c4762a1bSJed Brown
297c4762a1bSJed Brown PetscFunctionBegin;
298c4762a1bSJed Brown if (ctx->p == NORM_2) {
2999566063dSJacob Faibussowitsch PetscCall(VecCopy(x, V));
300c4762a1bSJed Brown } else if (ctx->p == NORM_1) {
3019566063dSJacob Faibussowitsch PetscCall(VecCopy(x, ctx->workRight[1]));
3029566063dSJacob Faibussowitsch PetscCall(VecAbs(ctx->workRight[1]));
3039566063dSJacob Faibussowitsch PetscCall(VecShift(ctx->workRight[1], eps));
3049566063dSJacob Faibussowitsch PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1]));
305c4762a1bSJed Brown } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
3063ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
307c4762a1bSJed Brown }
308c4762a1bSJed Brown
309c4762a1bSJed Brown /* NORM_2 Case: returns diag(mu)
310c4762a1bSJed Brown * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) */
HessianRegularization(Tao tao,Vec x,Mat H,Mat Hpre,void * _ctx)311d71ae5a4SJacob Faibussowitsch static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
312d71ae5a4SJacob Faibussowitsch {
313c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
314c4762a1bSJed Brown PetscReal eps = ctx->eps;
315c4762a1bSJed Brown Vec copy1, copy2, copy3;
316c4762a1bSJed Brown
317c4762a1bSJed Brown PetscFunctionBegin;
318c4762a1bSJed Brown if (ctx->p == NORM_2) {
319c4762a1bSJed Brown /* Identity matrix scaled by mu */
3209566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(H));
3219566063dSJacob Faibussowitsch PetscCall(MatShift(H, ctx->mu));
322c4762a1bSJed Brown if (Hpre != H) {
3239566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(Hpre));
3249566063dSJacob Faibussowitsch PetscCall(MatShift(Hpre, ctx->mu));
325c4762a1bSJed Brown }
326c4762a1bSJed Brown } else if (ctx->p == NORM_1) {
327c4762a1bSJed Brown /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */
328c4762a1bSJed Brown copy1 = ctx->workRight[1];
329c4762a1bSJed Brown copy2 = ctx->workRight[2];
330c4762a1bSJed Brown copy3 = ctx->workRight[3];
331c4762a1bSJed Brown /* copy1 : 1/sqrt(x_i^2 + eps) */
3329566063dSJacob Faibussowitsch PetscCall(VecCopy(x, copy1));
3339566063dSJacob Faibussowitsch PetscCall(VecPow(copy1, 2));
3349566063dSJacob Faibussowitsch PetscCall(VecShift(copy1, eps));
3359566063dSJacob Faibussowitsch PetscCall(VecSqrtAbs(copy1));
3369566063dSJacob Faibussowitsch PetscCall(VecReciprocal(copy1));
337c4762a1bSJed Brown /* copy2: x_i^2.*/
3389566063dSJacob Faibussowitsch PetscCall(VecCopy(x, copy2));
3399566063dSJacob Faibussowitsch PetscCall(VecPow(copy2, 2));
340c4762a1bSJed Brown /* copy3: abs(x_i^2 + eps) */
3419566063dSJacob Faibussowitsch PetscCall(VecCopy(x, copy3));
3429566063dSJacob Faibussowitsch PetscCall(VecPow(copy3, 2));
3439566063dSJacob Faibussowitsch PetscCall(VecShift(copy3, eps));
3449566063dSJacob Faibussowitsch PetscCall(VecAbs(copy3));
345c4762a1bSJed Brown /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */
3469566063dSJacob Faibussowitsch PetscCall(VecPointwiseDivide(copy2, copy2, copy3));
3479566063dSJacob Faibussowitsch PetscCall(VecScale(copy2, -1.));
3489566063dSJacob Faibussowitsch PetscCall(VecShift(copy2, 1.));
3499566063dSJacob Faibussowitsch PetscCall(VecAXPY(copy1, 1., copy2));
3509566063dSJacob Faibussowitsch PetscCall(VecScale(copy1, ctx->mu));
3519566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(H));
3529566063dSJacob Faibussowitsch PetscCall(MatDiagonalSet(H, copy1, INSERT_VALUES));
353c4762a1bSJed Brown if (Hpre != H) {
3549566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(Hpre));
3559566063dSJacob Faibussowitsch PetscCall(MatDiagonalSet(Hpre, copy1, INSERT_VALUES));
356c4762a1bSJed Brown }
357c4762a1bSJed Brown } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
3583ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
359c4762a1bSJed Brown }
360c4762a1bSJed Brown
361c4762a1bSJed Brown /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2
362c4762a1bSJed Brown * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */
ObjectiveRegularizationADMM(Tao tao,Vec z,PetscReal * J,void * _ctx)363d71ae5a4SJacob Faibussowitsch static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx)
364d71ae5a4SJacob Faibussowitsch {
365c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
366c4762a1bSJed Brown PetscReal mu, workNorm, reg;
367c4762a1bSJed Brown Vec x, u, temp;
368c4762a1bSJed Brown
369c4762a1bSJed Brown PetscFunctionBegin;
370c4762a1bSJed Brown mu = ctx->mu;
371c4762a1bSJed Brown x = ctx->workRight[4];
372c4762a1bSJed Brown u = ctx->workRight[6];
373c4762a1bSJed Brown temp = ctx->workRight[10];
3749566063dSJacob Faibussowitsch PetscCall(ObjectiveRegularization(tao, z, ®, _ctx));
3759566063dSJacob Faibussowitsch PetscCall(VecCopy(z, temp));
376c4762a1bSJed Brown /* temp = x + u -z */
3779566063dSJacob Faibussowitsch PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
378c4762a1bSJed Brown /* workNorm = ||x + u - z ||^2 */
3799566063dSJacob Faibussowitsch PetscCall(VecDot(temp, temp, &workNorm));
380c4762a1bSJed Brown *J = reg + 0.5 * mu * workNorm;
3813ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
382c4762a1bSJed Brown }
383c4762a1bSJed Brown
384c4762a1bSJed Brown /* NORM_2 Case: x - mu*(x + u - z)
385c4762a1bSJed Brown * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z)
386c4762a1bSJed Brown * Else: TODO */
GradientRegularizationADMM(Tao tao,Vec z,Vec V,void * _ctx)387d71ae5a4SJacob Faibussowitsch static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx)
388d71ae5a4SJacob Faibussowitsch {
389c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
390c4762a1bSJed Brown PetscReal mu;
391c4762a1bSJed Brown Vec x, u, temp;
392c4762a1bSJed Brown
393c4762a1bSJed Brown PetscFunctionBegin;
394c4762a1bSJed Brown mu = ctx->mu;
395c4762a1bSJed Brown x = ctx->workRight[4];
396c4762a1bSJed Brown u = ctx->workRight[6];
397c4762a1bSJed Brown temp = ctx->workRight[10];
3989566063dSJacob Faibussowitsch PetscCall(GradientRegularization(tao, z, V, _ctx));
3999566063dSJacob Faibussowitsch PetscCall(VecCopy(z, temp));
400c4762a1bSJed Brown /* temp = x + u -z */
4019566063dSJacob Faibussowitsch PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
4029566063dSJacob Faibussowitsch PetscCall(VecAXPY(V, -mu, temp));
4033ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
404c4762a1bSJed Brown }
405c4762a1bSJed Brown
406c4762a1bSJed Brown /* NORM_2 Case: returns diag(mu)
407c4762a1bSJed Brown * NORM_1 Case: FTF + diag(mu) */
HessianRegularizationADMM(Tao tao,Vec x,Mat H,Mat Hpre,void * _ctx)408d71ae5a4SJacob Faibussowitsch static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
409d71ae5a4SJacob Faibussowitsch {
410c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx;
411c4762a1bSJed Brown
412c4762a1bSJed Brown PetscFunctionBegin;
413c4762a1bSJed Brown if (ctx->p == NORM_2) {
414c4762a1bSJed Brown /* Identity matrix scaled by mu */
4159566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(H));
4169566063dSJacob Faibussowitsch PetscCall(MatShift(H, ctx->mu));
417c4762a1bSJed Brown if (Hpre != H) {
4189566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(Hpre));
4199566063dSJacob Faibussowitsch PetscCall(MatShift(Hpre, ctx->mu));
420c4762a1bSJed Brown }
421c4762a1bSJed Brown } else if (ctx->p == NORM_1) {
4229566063dSJacob Faibussowitsch PetscCall(HessianMisfit(tao, x, H, Hpre, (void *)ctx));
4239566063dSJacob Faibussowitsch PetscCall(MatShift(H, ctx->mu));
4249566063dSJacob Faibussowitsch if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu));
425c4762a1bSJed Brown } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
4263ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
427c4762a1bSJed Brown }
428c4762a1bSJed Brown
429c4762a1bSJed Brown /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p
430c4762a1bSJed Brown * NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */
ObjectiveComplete(Tao tao,Vec x,PetscReal * J,PetscCtx ctx)431*2a8381b2SBarry Smith static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, PetscCtx ctx)
432d71ae5a4SJacob Faibussowitsch {
433c4762a1bSJed Brown PetscReal Jm, Jr;
434c4762a1bSJed Brown
435c4762a1bSJed Brown PetscFunctionBegin;
4369566063dSJacob Faibussowitsch PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx));
4379566063dSJacob Faibussowitsch PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx));
438c4762a1bSJed Brown *J = Jm + Jr;
4393ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
440c4762a1bSJed Brown }
441c4762a1bSJed Brown
442c4762a1bSJed Brown /* NORM_2 Case: FTFx - FTd + x
443c4762a1bSJed Brown * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */
GradientComplete(Tao tao,Vec x,Vec V,PetscCtx ctx)444*2a8381b2SBarry Smith static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, PetscCtx ctx)
445d71ae5a4SJacob Faibussowitsch {
446c4762a1bSJed Brown UserCtx cntx = (UserCtx)ctx;
447c4762a1bSJed Brown
448c4762a1bSJed Brown PetscFunctionBegin;
4499566063dSJacob Faibussowitsch PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx));
4509566063dSJacob Faibussowitsch PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx));
4519566063dSJacob Faibussowitsch PetscCall(VecWAXPY(V, 1, cntx->workRight[2], cntx->workRight[3]));
4523ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
453c4762a1bSJed Brown }
454c4762a1bSJed Brown
455c4762a1bSJed Brown /* NORM_2 Case: diag(mu) + FTF
456c4762a1bSJed Brown * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF */
HessianComplete(Tao tao,Vec x,Mat H,Mat Hpre,PetscCtx ctx)457*2a8381b2SBarry Smith static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, PetscCtx ctx)
458d71ae5a4SJacob Faibussowitsch {
459c4762a1bSJed Brown Mat tempH;
460c4762a1bSJed Brown
461c4762a1bSJed Brown PetscFunctionBegin;
4629566063dSJacob Faibussowitsch PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH));
4639566063dSJacob Faibussowitsch PetscCall(HessianMisfit(tao, x, H, H, ctx));
4649566063dSJacob Faibussowitsch PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx));
4659566063dSJacob Faibussowitsch PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN));
46648a46eb9SPierre Jolivet if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
4679566063dSJacob Faibussowitsch PetscCall(MatDestroy(&tempH));
4683ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
469c4762a1bSJed Brown }
470c4762a1bSJed Brown
TaoSolveADMM(UserCtx ctx,Vec x)471d71ae5a4SJacob Faibussowitsch static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x)
472d71ae5a4SJacob Faibussowitsch {
473c4762a1bSJed Brown PetscInt i;
474c4762a1bSJed Brown PetscReal u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm;
475c4762a1bSJed Brown Tao tao1, tao2;
476c4762a1bSJed Brown Vec xk, z, u, diff, zold, zdiff, temp;
477c4762a1bSJed Brown PetscReal mu;
478c4762a1bSJed Brown
479c4762a1bSJed Brown PetscFunctionBegin;
480c4762a1bSJed Brown xk = ctx->workRight[4];
481c4762a1bSJed Brown z = ctx->workRight[5];
482c4762a1bSJed Brown u = ctx->workRight[6];
483c4762a1bSJed Brown diff = ctx->workRight[7];
484c4762a1bSJed Brown zold = ctx->workRight[8];
485c4762a1bSJed Brown zdiff = ctx->workRight[9];
486c4762a1bSJed Brown temp = ctx->workRight[11];
487c4762a1bSJed Brown mu = ctx->mu;
4889566063dSJacob Faibussowitsch PetscCall(VecSet(u, 0.));
4899566063dSJacob Faibussowitsch PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1));
4909566063dSJacob Faibussowitsch PetscCall(TaoSetType(tao1, TAONLS));
4919566063dSJacob Faibussowitsch PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void *)ctx));
4929566063dSJacob Faibussowitsch PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void *)ctx));
4939566063dSJacob Faibussowitsch PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void *)ctx));
4949566063dSJacob Faibussowitsch PetscCall(VecSet(xk, 0.));
4959566063dSJacob Faibussowitsch PetscCall(TaoSetSolution(tao1, xk));
4969566063dSJacob Faibussowitsch PetscCall(TaoSetOptionsPrefix(tao1, "misfit_"));
4979566063dSJacob Faibussowitsch PetscCall(TaoSetFromOptions(tao1));
4989566063dSJacob Faibussowitsch PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2));
499c4762a1bSJed Brown if (ctx->p == NORM_2) {
5009566063dSJacob Faibussowitsch PetscCall(TaoSetType(tao2, TAONLS));
5019566063dSJacob Faibussowitsch PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void *)ctx));
5029566063dSJacob Faibussowitsch PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void *)ctx));
5039566063dSJacob Faibussowitsch PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void *)ctx));
504c4762a1bSJed Brown }
5059566063dSJacob Faibussowitsch PetscCall(VecSet(z, 0.));
5069566063dSJacob Faibussowitsch PetscCall(TaoSetSolution(tao2, z));
5079566063dSJacob Faibussowitsch PetscCall(TaoSetOptionsPrefix(tao2, "reg_"));
5089566063dSJacob Faibussowitsch PetscCall(TaoSetFromOptions(tao2));
509c4762a1bSJed Brown
510c4762a1bSJed Brown for (i = 0; i < ctx->iter; i++) {
5119566063dSJacob Faibussowitsch PetscCall(VecCopy(z, zold));
5129566063dSJacob Faibussowitsch PetscCall(TaoSolve(tao1)); /* Updates xk */
513c4762a1bSJed Brown if (ctx->p == NORM_1) {
5149566063dSJacob Faibussowitsch PetscCall(VecWAXPY(temp, 1., xk, u));
5159566063dSJacob Faibussowitsch PetscCall(TaoSoftThreshold(temp, -ctx->alpha / mu, ctx->alpha / mu, z));
516c4762a1bSJed Brown } else {
5179566063dSJacob Faibussowitsch PetscCall(TaoSolve(tao2)); /* Update zk */
518c4762a1bSJed Brown }
519c4762a1bSJed Brown /* u = u + xk -z */
5209566063dSJacob Faibussowitsch PetscCall(VecAXPBYPCZ(u, 1., -1., 1., xk, z));
521c4762a1bSJed Brown /* r_norm : norm(x-z) */
5229566063dSJacob Faibussowitsch PetscCall(VecWAXPY(diff, -1., z, xk));
5239566063dSJacob Faibussowitsch PetscCall(VecNorm(diff, NORM_2, &r_norm));
524c4762a1bSJed Brown /* s_norm : norm(-mu(z-zold)) */
5259566063dSJacob Faibussowitsch PetscCall(VecWAXPY(zdiff, -1., zold, z));
5269566063dSJacob Faibussowitsch PetscCall(VecNorm(zdiff, NORM_2, &s_norm));
527c4762a1bSJed Brown s_norm = s_norm * mu;
528c4762a1bSJed Brown /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/
5299566063dSJacob Faibussowitsch PetscCall(VecNorm(xk, NORM_2, &x_norm));
5309566063dSJacob Faibussowitsch PetscCall(VecNorm(z, NORM_2, &z_norm));
531c4762a1bSJed Brown primal = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * PetscMax(x_norm, z_norm);
532c4762a1bSJed Brown /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/
5339566063dSJacob Faibussowitsch PetscCall(VecNorm(u, NORM_2, &u_norm));
534c4762a1bSJed Brown dual = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * u_norm * mu;
53563a3b9bcSJacob Faibussowitsch PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1), "Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double)r_norm, (double)s_norm));
536c4762a1bSJed Brown if (r_norm < primal && s_norm < dual) break;
537c4762a1bSJed Brown }
5389566063dSJacob Faibussowitsch PetscCall(VecCopy(xk, x));
5399566063dSJacob Faibussowitsch PetscCall(TaoDestroy(&tao1));
5409566063dSJacob Faibussowitsch PetscCall(TaoDestroy(&tao2));
5413ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
542c4762a1bSJed Brown }
543c4762a1bSJed Brown
544c4762a1bSJed Brown /* Second order Taylor remainder convergence test */
TaylorTest(UserCtx ctx,Tao tao,Vec x,PetscReal * C)545d71ae5a4SJacob Faibussowitsch static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C)
546d71ae5a4SJacob Faibussowitsch {
547c4762a1bSJed Brown PetscReal h, J, temp;
548c4762a1bSJed Brown PetscInt i, j;
549c4762a1bSJed Brown PetscInt numValues;
550c4762a1bSJed Brown PetscReal Jx, Jxhat_comp, Jxhat_pred;
551c4762a1bSJed Brown PetscReal *Js, *hs;
552c4762a1bSJed Brown PetscReal gdotdx;
553c4762a1bSJed Brown PetscReal minrate = PETSC_MAX_REAL;
554c4762a1bSJed Brown MPI_Comm comm = PetscObjectComm((PetscObject)x);
555c4762a1bSJed Brown Vec g, dx, xhat;
556c4762a1bSJed Brown
557c4762a1bSJed Brown PetscFunctionBegin;
5589566063dSJacob Faibussowitsch PetscCall(VecDuplicate(x, &g));
5599566063dSJacob Faibussowitsch PetscCall(VecDuplicate(x, &xhat));
560c4762a1bSJed Brown /* choose a perturbation direction */
5619566063dSJacob Faibussowitsch PetscCall(VecDuplicate(x, &dx));
5629566063dSJacob Faibussowitsch PetscCall(VecSetRandom(dx, ctx->rctx));
563c4762a1bSJed Brown /* evaluate objective at x: J(x) */
5649566063dSJacob Faibussowitsch PetscCall(TaoComputeObjective(tao, x, &Jx));
565c4762a1bSJed Brown /* evaluate gradient at x, save in vector g */
5669566063dSJacob Faibussowitsch PetscCall(TaoComputeGradient(tao, x, g));
5679566063dSJacob Faibussowitsch PetscCall(VecDot(g, dx, &gdotdx));
568c4762a1bSJed Brown
569c4762a1bSJed Brown for (numValues = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor) numValues++;
5709566063dSJacob Faibussowitsch PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs));
571c4762a1bSJed Brown for (i = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor, i++) {
5729566063dSJacob Faibussowitsch PetscCall(VecWAXPY(xhat, h, dx, x));
5739566063dSJacob Faibussowitsch PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp));
574c4762a1bSJed Brown /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */
575c4762a1bSJed Brown Jxhat_pred = Jx + h * gdotdx;
576c4762a1bSJed Brown /* Vector to dJdm scalar? Dot?*/
577c4762a1bSJed Brown J = PetscAbsReal(Jxhat_comp - Jxhat_pred);
5789566063dSJacob Faibussowitsch PetscCall(PetscPrintf(comm, "J(xhat): %g, predicted: %g, diff %g\n", (double)Jxhat_comp, (double)Jxhat_pred, (double)J));
579c4762a1bSJed Brown Js[i] = J;
580c4762a1bSJed Brown hs[i] = h;
581c4762a1bSJed Brown }
582c4762a1bSJed Brown for (j = 1; j < numValues; j++) {
583c4762a1bSJed Brown temp = PetscLogReal(Js[j] / Js[j - 1]) / PetscLogReal(hs[j] / hs[j - 1]);
58463a3b9bcSJacob Faibussowitsch PetscCall(PetscPrintf(comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j - 1, (double)temp));
585c4762a1bSJed Brown minrate = PetscMin(minrate, temp);
586c4762a1bSJed Brown }
587c4762a1bSJed Brown /* If O is not ~2, then the test is wrong */
5889566063dSJacob Faibussowitsch PetscCall(PetscFree2(Js, hs));
589c4762a1bSJed Brown *C = minrate;
5909566063dSJacob Faibussowitsch PetscCall(VecDestroy(&dx));
5919566063dSJacob Faibussowitsch PetscCall(VecDestroy(&xhat));
5929566063dSJacob Faibussowitsch PetscCall(VecDestroy(&g));
5933ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
594c4762a1bSJed Brown }
595c4762a1bSJed Brown
main(int argc,char ** argv)596d71ae5a4SJacob Faibussowitsch int main(int argc, char **argv)
597d71ae5a4SJacob Faibussowitsch {
598c4762a1bSJed Brown UserCtx ctx;
599c4762a1bSJed Brown Tao tao;
600c4762a1bSJed Brown Vec x;
601c4762a1bSJed Brown Mat H;
602c4762a1bSJed Brown
603327415f7SBarry Smith PetscFunctionBeginUser;
6049566063dSJacob Faibussowitsch PetscCall(PetscInitialize(&argc, &argv, NULL, help));
6059566063dSJacob Faibussowitsch PetscCall(PetscNew(&ctx));
6069566063dSJacob Faibussowitsch PetscCall(ConfigureContext(ctx));
607a82e8c82SStefano Zampini /* Define two functions that could pass as objectives to TaoSetObjective(): one
608c4762a1bSJed Brown * for the misfit component, and one for the regularization component */
609c4762a1bSJed Brown /* ObjectiveMisfit() and ObjectiveRegularization() */
610c4762a1bSJed Brown
611c4762a1bSJed Brown /* Define a single function that calls both components adds them together: the complete objective,
612c4762a1bSJed Brown * in the absence of a Tao implementation that handles separability */
613c4762a1bSJed Brown /* ObjectiveComplete() */
6149566063dSJacob Faibussowitsch PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
6159566063dSJacob Faibussowitsch PetscCall(TaoSetType(tao, TAONM));
6169566063dSJacob Faibussowitsch PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void *)ctx));
6179566063dSJacob Faibussowitsch PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void *)ctx));
6189566063dSJacob Faibussowitsch PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H));
6199566063dSJacob Faibussowitsch PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void *)ctx));
6209566063dSJacob Faibussowitsch PetscCall(MatCreateVecs(ctx->F, NULL, &x));
6219566063dSJacob Faibussowitsch PetscCall(VecSet(x, 0.));
6229566063dSJacob Faibussowitsch PetscCall(TaoSetSolution(tao, x));
6239566063dSJacob Faibussowitsch PetscCall(TaoSetFromOptions(tao));
6241baa6e33SBarry Smith if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx, x));
6251baa6e33SBarry Smith else PetscCall(TaoSolve(tao));
626c4762a1bSJed Brown /* examine solution */
6279566063dSJacob Faibussowitsch PetscCall(VecViewFromOptions(x, NULL, "-view_sol"));
628c4762a1bSJed Brown if (ctx->taylor) {
629c4762a1bSJed Brown PetscReal rate;
6309566063dSJacob Faibussowitsch PetscCall(TaylorTest(ctx, tao, x, &rate));
631c4762a1bSJed Brown }
6323a7d0413SPierre Jolivet if (ctx->soft) PetscCall(TaoSoftThreshold(x, 0., 0., x));
6339566063dSJacob Faibussowitsch PetscCall(MatDestroy(&H));
6349566063dSJacob Faibussowitsch PetscCall(TaoDestroy(&tao));
6359566063dSJacob Faibussowitsch PetscCall(VecDestroy(&x));
6369566063dSJacob Faibussowitsch PetscCall(DestroyContext(&ctx));
6379566063dSJacob Faibussowitsch PetscCall(PetscFinalize());
638b122ec5aSJacob Faibussowitsch return 0;
639c4762a1bSJed Brown }
640c4762a1bSJed Brown
641c4762a1bSJed Brown /*TEST
642c4762a1bSJed Brown
643c4762a1bSJed Brown build:
644c4762a1bSJed Brown requires: !complex
645c4762a1bSJed Brown
646c4762a1bSJed Brown test:
647c4762a1bSJed Brown suffix: 0
648c4762a1bSJed Brown args:
649c4762a1bSJed Brown
650c4762a1bSJed Brown test:
651c4762a1bSJed Brown suffix: l1_1
652c4762a1bSJed Brown args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1
653c4762a1bSJed Brown
654c4762a1bSJed Brown test:
655c4762a1bSJed Brown suffix: hessian_1
656c5f5e425SStefano Zampini args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls
657c4762a1bSJed Brown
658c4762a1bSJed Brown test:
659c4762a1bSJed Brown suffix: hessian_2
660c5f5e425SStefano Zampini args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls
661c4762a1bSJed Brown
662c4762a1bSJed Brown test:
663c4762a1bSJed Brown suffix: nm_1
664c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50
665c4762a1bSJed Brown
666c4762a1bSJed Brown test:
667c4762a1bSJed Brown suffix: nm_2
668c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50
669c4762a1bSJed Brown
670c4762a1bSJed Brown test:
671c4762a1bSJed Brown suffix: lmvm_1
672c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40
673c4762a1bSJed Brown
674c4762a1bSJed Brown test:
675c4762a1bSJed Brown suffix: lmvm_2
676c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15
677c4762a1bSJed Brown
678c4762a1bSJed Brown test:
679c4762a1bSJed Brown suffix: soft_threshold_admm_1
680c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm
681c4762a1bSJed Brown
682c4762a1bSJed Brown test:
683c4762a1bSJed Brown suffix: hessian_admm_1
684c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls
685c4762a1bSJed Brown
686c4762a1bSJed Brown test:
687c4762a1bSJed Brown suffix: hessian_admm_2
688c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls
689c4762a1bSJed Brown
690c4762a1bSJed Brown test:
691c4762a1bSJed Brown suffix: nm_admm_1
692c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm
693c4762a1bSJed Brown
694c4762a1bSJed Brown test:
695c4762a1bSJed Brown suffix: nm_admm_2
696c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7
697c4762a1bSJed Brown
698c4762a1bSJed Brown test:
699c4762a1bSJed Brown suffix: lmvm_admm_1
700c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
701c4762a1bSJed Brown
702c4762a1bSJed Brown test:
703c4762a1bSJed Brown suffix: lmvm_admm_2
704c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
705c4762a1bSJed Brown
70684430a0dSHansol Suh test:
70784430a0dSHansol Suh suffix: soft
70884430a0dSHansol Suh args: -taylor 0 -soft 1
709e0008caeSPierre Jolivet output_file: output/empty.out
71084430a0dSHansol Suh
711c4762a1bSJed Brown TEST*/
712