xref: /petsc/src/tao/constrained/tutorials/tomographyADMM.c (revision 9371c9d470a9602b6d10a8bf50c9b2280a79e45a)
1 #include <petsctao.h>
/*
Description:   ADMM tomography reconstruction example.
Objective:     0.5*||Ax-b||^2 + lambda*g(x)
Reference:     BRGN Tomography Example
*/
7 
8 static char help[] = "Finds the ADMM solution to the under constraint linear model Ax = b, with regularizer. \n\
9                       A is a M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n\
10                       We first split the operator into 0.5*||Ax-b||^2, f(x), and lambda*||x||_1, g(z), where lambda is user specified weight. \n\
11                       g(z) could be either ||z||_1, or ||z||_2^2. Default closed form solution for NORM1 would be soft-threshold, which is \n\
12                       natively supported in admm.c with -tao_admm_regularizer_type soft-threshold. Or user can use regular TAO solver for  \n\
13                       either NORM1 or NORM2 or TAOSHELL, with -reg {1,2,3} \n\
14                       Then, we augment both f and g, and solve it via ADMM. \n\
15                       D is the M*N transform matrix so that D*x is sparse. \n";
16 
17 typedef struct {
18   PetscInt  M, N, K, reg;
19   PetscReal lambda, eps, mumin;
20   Mat       A, ATA, H, Hx, D, Hz, DTD, HF;
21   Vec       c, xlb, xub, x, b, workM, workN, workN2, workN3, xGT; /* observation b, ground truth xGT, the lower bound and upper bound of x*/
22 } AppCtx;
23 
24 /*------------------------------------------------------------*/
25 
26 PetscErrorCode NullJacobian(Tao tao, Vec X, Mat J, Mat Jpre, void *ptr) {
27   PetscFunctionBegin;
28   PetscFunctionReturn(0);
29 }
30 
31 /*------------------------------------------------------------*/
32 
33 static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao) {
34   PetscReal lambda, mu;
35   AppCtx   *user;
36   Vec       out, work, y, x;
37   Tao       admm_tao, misfit;
38 
39   PetscFunctionBegin;
40   user = NULL;
41   mu   = 0;
42   PetscCall(TaoGetADMMParentTao(tao, &admm_tao));
43   PetscCall(TaoADMMGetMisfitSubsolver(admm_tao, &misfit));
44   PetscCall(TaoADMMGetSpectralPenalty(admm_tao, &mu));
45   PetscCall(TaoShellGetContext(tao, &user));
46 
47   lambda = user->lambda;
48   work   = user->workN;
49   PetscCall(TaoGetSolution(tao, &out));
50   PetscCall(TaoGetSolution(misfit, &x));
51   PetscCall(TaoADMMGetDualVector(admm_tao, &y));
52 
53   /* Dx + y/mu */
54   PetscCall(MatMult(user->D, x, work));
55   PetscCall(VecAXPY(work, 1 / mu, y));
56 
57   /* soft thresholding */
58   PetscCall(TaoSoftThreshold(work, -lambda / mu, lambda / mu, out));
59   PetscFunctionReturn(0);
60 }
61 
62 /*------------------------------------------------------------*/
63 
64 PetscErrorCode MisfitObjectiveAndGradient(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr) {
65   AppCtx *user = (AppCtx *)ptr;
66 
67   PetscFunctionBegin;
68   /* Objective  0.5*||Ax-b||_2^2 */
69   PetscCall(MatMult(user->A, X, user->workM));
70   PetscCall(VecAXPY(user->workM, -1, user->b));
71   PetscCall(VecDot(user->workM, user->workM, f));
72   *f *= 0.5;
73   /* Gradient. ATAx-ATb */
74   PetscCall(MatMult(user->ATA, X, user->workN));
75   PetscCall(MatMultTranspose(user->A, user->b, user->workN2));
76   PetscCall(VecWAXPY(g, -1., user->workN2, user->workN));
77   PetscFunctionReturn(0);
78 }
79 
80 /*------------------------------------------------------------*/
81 
82 PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr) {
83   AppCtx *user = (AppCtx *)ptr;
84 
85   PetscFunctionBegin;
86   /* compute regularizer objective
87    * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */
88   PetscCall(VecCopy(X, user->workN2));
89   PetscCall(VecPow(user->workN2, 2.));
90   PetscCall(VecShift(user->workN2, user->eps * user->eps));
91   PetscCall(VecSqrtAbs(user->workN2));
92   PetscCall(VecCopy(user->workN2, user->workN3));
93   PetscCall(VecShift(user->workN2, -user->eps));
94   PetscCall(VecSum(user->workN2, f_reg));
95   *f_reg *= user->lambda;
96   /* compute regularizer gradient = lambda*x */
97   PetscCall(VecPointwiseDivide(G_reg, X, user->workN3));
98   PetscCall(VecScale(G_reg, user->lambda));
99   PetscFunctionReturn(0);
100 }
101 
102 /*------------------------------------------------------------*/
103 
104 PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr) {
105   AppCtx   *user = (AppCtx *)ptr;
106   PetscReal temp;
107 
108   PetscFunctionBegin;
109   /* compute regularizer objective = lambda*|z|_2^2 */
110   PetscCall(VecDot(X, X, &temp));
111   *f_reg = 0.5 * user->lambda * temp;
112   /* compute regularizer gradient = lambda*z */
113   PetscCall(VecCopy(X, G_reg));
114   PetscCall(VecScale(G_reg, user->lambda));
115   PetscFunctionReturn(0);
116 }
117 
118 /*------------------------------------------------------------*/
119 
120 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr) {
121   PetscFunctionBegin;
122   PetscFunctionReturn(0);
123 }
124 
125 /*------------------------------------------------------------*/
126 
127 static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr) {
128   AppCtx *user = (AppCtx *)ptr;
129 
130   PetscFunctionBegin;
131   PetscCall(MatMult(user->D, x, user->workN));
132   PetscCall(VecPow(user->workN2, 2.));
133   PetscCall(VecShift(user->workN2, user->eps * user->eps));
134   PetscCall(VecSqrtAbs(user->workN2));
135   PetscCall(VecShift(user->workN2, -user->eps));
136   PetscCall(VecReciprocal(user->workN2));
137   PetscCall(VecScale(user->workN2, user->eps * user->eps));
138   PetscCall(MatDiagonalSet(H, user->workN2, INSERT_VALUES));
139   PetscFunctionReturn(0);
140 }
141 
142 /*------------------------------------------------------------*/
143 
144 PetscErrorCode FullObjGrad(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr) {
145   AppCtx   *user = (AppCtx *)ptr;
146   PetscReal f_reg;
147 
148   PetscFunctionBegin;
149   /* Objective  0.5*||Ax-b||_2^2 + lambda*||x||_2^2*/
150   PetscCall(MatMult(user->A, X, user->workM));
151   PetscCall(VecAXPY(user->workM, -1, user->b));
152   PetscCall(VecDot(user->workM, user->workM, f));
153   PetscCall(VecNorm(X, NORM_2, &f_reg));
154   *f *= 0.5;
155   *f += user->lambda * f_reg * f_reg;
156   /* Gradient. ATAx-ATb + 2*lambda*x */
157   PetscCall(MatMult(user->ATA, X, user->workN));
158   PetscCall(MatMultTranspose(user->A, user->b, user->workN2));
159   PetscCall(VecWAXPY(g, -1., user->workN2, user->workN));
160   PetscCall(VecAXPY(g, 2 * user->lambda, X));
161   PetscFunctionReturn(0);
162 }
163 /*------------------------------------------------------------*/
164 
165 static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr) {
166   PetscFunctionBegin;
167   PetscFunctionReturn(0);
168 }
169 /*------------------------------------------------------------*/
170 
171 PetscErrorCode InitializeUserData(AppCtx *user) {
172   char        dataFile[] = "tomographyData_A_b_xGT"; /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */
173   PetscViewer fd;                                    /* used to load data from file */
174   PetscInt    k, n;
175   PetscScalar v;
176 
177   PetscFunctionBegin;
178   /* Load the A matrix, b vector, and xGT vector from a binary file. */
179   PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD, dataFile, FILE_MODE_READ, &fd));
180   PetscCall(MatCreate(PETSC_COMM_WORLD, &user->A));
181   PetscCall(MatSetType(user->A, MATAIJ));
182   PetscCall(MatLoad(user->A, fd));
183   PetscCall(VecCreate(PETSC_COMM_WORLD, &user->b));
184   PetscCall(VecLoad(user->b, fd));
185   PetscCall(VecCreate(PETSC_COMM_WORLD, &user->xGT));
186   PetscCall(VecLoad(user->xGT, fd));
187   PetscCall(PetscViewerDestroy(&fd));
188 
189   PetscCall(MatGetSize(user->A, &user->M, &user->N));
190 
191   PetscCall(MatCreate(PETSC_COMM_WORLD, &user->D));
192   PetscCall(MatSetSizes(user->D, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N));
193   PetscCall(MatSetFromOptions(user->D));
194   PetscCall(MatSetUp(user->D));
195   for (k = 0; k < user->N; k++) {
196     v = 1.0;
197     n = k + 1;
198     if (k < user->N - 1) { PetscCall(MatSetValues(user->D, 1, &k, 1, &n, &v, INSERT_VALUES)); }
199     v = -1.0;
200     PetscCall(MatSetValues(user->D, 1, &k, 1, &k, &v, INSERT_VALUES));
201   }
202   PetscCall(MatAssemblyBegin(user->D, MAT_FINAL_ASSEMBLY));
203   PetscCall(MatAssemblyEnd(user->D, MAT_FINAL_ASSEMBLY));
204 
205   PetscCall(MatTransposeMatMult(user->D, user->D, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &user->DTD));
206 
207   PetscCall(MatCreate(PETSC_COMM_WORLD, &user->Hz));
208   PetscCall(MatSetSizes(user->Hz, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N));
209   PetscCall(MatSetFromOptions(user->Hz));
210   PetscCall(MatSetUp(user->Hz));
211   PetscCall(MatAssemblyBegin(user->Hz, MAT_FINAL_ASSEMBLY));
212   PetscCall(MatAssemblyEnd(user->Hz, MAT_FINAL_ASSEMBLY));
213 
214   PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->x)));
215   PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->workM)));
216   PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->workN)));
217   PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->workN2)));
218   PetscCall(VecSetSizes(user->x, PETSC_DECIDE, user->N));
219   PetscCall(VecSetSizes(user->workM, PETSC_DECIDE, user->M));
220   PetscCall(VecSetSizes(user->workN, PETSC_DECIDE, user->N));
221   PetscCall(VecSetSizes(user->workN2, PETSC_DECIDE, user->N));
222   PetscCall(VecSetFromOptions(user->x));
223   PetscCall(VecSetFromOptions(user->workM));
224   PetscCall(VecSetFromOptions(user->workN));
225   PetscCall(VecSetFromOptions(user->workN2));
226 
227   PetscCall(VecDuplicate(user->workN, &(user->workN3)));
228   PetscCall(VecDuplicate(user->x, &(user->xlb)));
229   PetscCall(VecDuplicate(user->x, &(user->xub)));
230   PetscCall(VecDuplicate(user->x, &(user->c)));
231   PetscCall(VecSet(user->xlb, 0.0));
232   PetscCall(VecSet(user->c, 0.0));
233   PetscCall(VecSet(user->xub, PETSC_INFINITY));
234 
235   PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->ATA)));
236   PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->Hx)));
237   PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->HF)));
238 
239   PetscCall(MatAssemblyBegin(user->ATA, MAT_FINAL_ASSEMBLY));
240   PetscCall(MatAssemblyEnd(user->ATA, MAT_FINAL_ASSEMBLY));
241   PetscCall(MatAssemblyBegin(user->Hx, MAT_FINAL_ASSEMBLY));
242   PetscCall(MatAssemblyEnd(user->Hx, MAT_FINAL_ASSEMBLY));
243   PetscCall(MatAssemblyBegin(user->HF, MAT_FINAL_ASSEMBLY));
244   PetscCall(MatAssemblyEnd(user->HF, MAT_FINAL_ASSEMBLY));
245 
246   user->lambda = 1.e-8;
247   user->eps    = 1.e-3;
248   user->reg    = 2;
249   user->mumin  = 5.e-6;
250 
251   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c");
252   PetscCall(PetscOptionsInt("-reg", "Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &(user->reg), NULL));
253   PetscCall(PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &(user->lambda), NULL));
254   PetscCall(PetscOptionsReal("-eps", "L1 norm epsilon padding", "tomographyADMM.c", user->eps, &(user->eps), NULL));
255   PetscCall(PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &(user->mumin), NULL));
256   PetscOptionsEnd();
257   PetscFunctionReturn(0);
258 }
259 
260 /*------------------------------------------------------------*/
261 
262 PetscErrorCode DestroyContext(AppCtx *user) {
263   PetscFunctionBegin;
264   PetscCall(MatDestroy(&user->A));
265   PetscCall(MatDestroy(&user->ATA));
266   PetscCall(MatDestroy(&user->Hx));
267   PetscCall(MatDestroy(&user->Hz));
268   PetscCall(MatDestroy(&user->HF));
269   PetscCall(MatDestroy(&user->D));
270   PetscCall(MatDestroy(&user->DTD));
271   PetscCall(VecDestroy(&user->xGT));
272   PetscCall(VecDestroy(&user->xlb));
273   PetscCall(VecDestroy(&user->xub));
274   PetscCall(VecDestroy(&user->b));
275   PetscCall(VecDestroy(&user->x));
276   PetscCall(VecDestroy(&user->c));
277   PetscCall(VecDestroy(&user->workN3));
278   PetscCall(VecDestroy(&user->workN2));
279   PetscCall(VecDestroy(&user->workN));
280   PetscCall(VecDestroy(&user->workM));
281   PetscFunctionReturn(0);
282 }
283 
284 /*------------------------------------------------------------*/
285 
286 int main(int argc, char **argv) {
287   Tao         tao, misfit, reg;
288   PetscReal   v1, v2;
289   AppCtx     *user;
290   PetscViewer fd;
291   char        resultFile[] = "tomographyResult_x";
292 
293   PetscFunctionBeginUser;
294   PetscCall(PetscInitialize(&argc, &argv, (char *)0, help));
295   PetscCall(PetscNew(&user));
296   PetscCall(InitializeUserData(user));
297 
298   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
299   PetscCall(TaoSetType(tao, TAOADMM));
300   PetscCall(TaoSetSolution(tao, user->x));
301   /* f(x) + g(x) for parent tao */
302   PetscCall(TaoADMMSetSpectralPenalty(tao, 1.));
303   PetscCall(TaoSetObjectiveAndGradient(tao, NULL, FullObjGrad, (void *)user));
304   PetscCall(MatShift(user->HF, user->lambda));
305   PetscCall(TaoSetHessian(tao, user->HF, user->HF, HessianFull, (void *)user));
306 
307   /* f(x) for misfit tao */
308   PetscCall(TaoADMMSetMisfitObjectiveAndGradientRoutine(tao, MisfitObjectiveAndGradient, (void *)user));
309   PetscCall(TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void *)user));
310   PetscCall(TaoADMMSetMisfitHessianChangeStatus(tao, PETSC_FALSE));
311   PetscCall(TaoADMMSetMisfitConstraintJacobian(tao, user->D, user->D, NullJacobian, (void *)user));
312 
313   /* g(x) for regularizer tao */
314   if (user->reg == 1) {
315     PetscCall(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient1, (void *)user));
316     PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void *)user));
317     PetscCall(TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE));
318   } else if (user->reg == 2) {
319     PetscCall(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient2, (void *)user));
320     PetscCall(MatShift(user->Hz, 1));
321     PetscCall(MatScale(user->Hz, user->lambda));
322     PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void *)user));
323     PetscCall(TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE));
324   } else PetscCheck(user->reg == 3, PETSC_COMM_WORLD, PETSC_ERR_ARG_UNKNOWN_TYPE, "Incorrect Reg type"); /* TaoShell case */
325 
326   /* Set type for the misfit solver */
327   PetscCall(TaoADMMGetMisfitSubsolver(tao, &misfit));
328   PetscCall(TaoADMMGetRegularizationSubsolver(tao, &reg));
329   PetscCall(TaoSetType(misfit, TAONLS));
330   if (user->reg == 3) {
331     PetscCall(TaoSetType(reg, TAOSHELL));
332     PetscCall(TaoShellSetContext(reg, (void *)user));
333     PetscCall(TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold));
334   } else {
335     PetscCall(TaoSetType(reg, TAONLS));
336   }
337   PetscCall(TaoSetVariableBounds(misfit, user->xlb, user->xub));
338 
339   /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */
340   PetscCall(TaoADMMSetRegularizerCoefficient(tao, user->lambda));
341   PetscCall(TaoADMMSetRegularizerConstraintJacobian(tao, NULL, NULL, NullJacobian, (void *)user));
342   PetscCall(TaoADMMSetMinimumSpectralPenalty(tao, user->mumin));
343 
344   PetscCall(TaoADMMSetConstraintVectorRHS(tao, user->c));
345   PetscCall(TaoSetFromOptions(tao));
346   PetscCall(TaoSolve(tao));
347 
348   /* Save x (reconstruction of object) vector to a binary file, which maybe read from Matlab and convert to a 2D image for comparison. */
349   PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD, resultFile, FILE_MODE_WRITE, &fd));
350   PetscCall(VecView(user->x, fd));
351   PetscCall(PetscViewerDestroy(&fd));
352 
353   /* compute the error */
354   PetscCall(VecAXPY(user->x, -1, user->xGT));
355   PetscCall(VecNorm(user->x, NORM_2, &v1));
356   PetscCall(VecNorm(user->xGT, NORM_2, &v2));
357   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1 / v2)));
358 
359   /* Free TAO data structures */
360   PetscCall(TaoDestroy(&tao));
361   PetscCall(DestroyContext(user));
362   PetscCall(PetscFree(user));
363   PetscCall(PetscFinalize());
364   return 0;
365 }
366 
367 /*TEST
368 
369    build:
370       requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)
371 
372    test:
373       suffix: 1
374       localrunfiles: tomographyData_A_b_xGT
375       args:  -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc
376 
377    test:
378       suffix: 2
379       localrunfiles: tomographyData_A_b_xGT
380       args:  -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8  -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor
381 
382    test:
383       suffix: 3
384       localrunfiles: tomographyData_A_b_xGT
385       args:  -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor
386 
387    test:
388       suffix: 4
389       localrunfiles: tomographyData_A_b_xGT
390       args:  -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -misfit_tao_monitor -misfit_tao_nls_pc_type icc
391 
392    test:
393       suffix: 5
394       localrunfiles: tomographyData_A_b_xGT
395       args:  -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
396 
397    test:
398       suffix: 6
399       localrunfiles: tomographyData_A_b_xGT
400       args:  -reg 3 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
401 
402 TEST*/
403