xref: /petsc/src/tao/constrained/tutorials/tomographyADMM.c (revision b122ec5aa1bd4469eb4e0673542fb7de3f411254)
1 #include <petsctao.h>
2 /*
3 Description:   ADMM tomography reconstruction example .
4                0.5*||Ax-b||^2 + lambda*g(x)
5 Reference:     BRGN Tomography Example
6 */
7 
/* Help text printed by -help; it describes the split problem solved below. */
static char help[] = "Finds the ADMM solution to the underconstrained linear model Ax = b, with regularizer. \n\
                      A is a M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n\
                      We first split the operator into 0.5*||Ax-b||^2, f(x), and lambda*||z||_1, g(z), with z = D*x, where lambda is user specified weight. \n\
                      g(z) could be either ||z||_1, or ||z||_2^2. Default closed form solution for NORM1 would be soft-threshold, which is \n\
                      natively supported in admm.c with -tao_admm_regularizer_type regularizer_soft_thresh. Or user can use regular TAO solver for \n\
                      either NORM1 or NORM2 or TAOSHELL, with -reg {1,2,3} \n\
                      Then, we augment both f and g, and solve it via ADMM. \n\
                      D is the N*N transform matrix so that D*x is sparse. \n";
16 
17 typedef struct {
18   PetscInt  M,N,K,reg;
19   PetscReal lambda,eps,mumin;
20   Mat       A,ATA,H,Hx,D,Hz,DTD,HF;
21   Vec       c,xlb,xub,x,b,workM,workN,workN2,workN3,xGT;    /* observation b, ground truth xGT, the lower bound and upper bound of x*/
22 } AppCtx;
23 
24 /*------------------------------------------------------------*/
25 
26 PetscErrorCode NullJacobian(Tao tao,Vec X,Mat J,Mat Jpre,void *ptr)
27 {
28   PetscFunctionBegin;
29   PetscFunctionReturn(0);
30 }
31 
32 /*------------------------------------------------------------*/
33 
34 static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao)
35 {
36   PetscReal      lambda, mu;
37   AppCtx         *user;
38   Vec            out,work,y,x;
39   Tao            admm_tao,misfit;
40 
41   PetscFunctionBegin;
42   user = NULL;
43   mu   = 0;
44   CHKERRQ(TaoGetADMMParentTao(tao,&admm_tao));
45   CHKERRQ(TaoADMMGetMisfitSubsolver(admm_tao, &misfit));
46   CHKERRQ(TaoADMMGetSpectralPenalty(admm_tao,&mu));
47   CHKERRQ(TaoShellGetContext(tao,&user));
48 
49   lambda = user->lambda;
50   work   = user->workN;
51   CHKERRQ(TaoGetSolution(tao, &out));
52   CHKERRQ(TaoGetSolution(misfit, &x));
53   CHKERRQ(TaoADMMGetDualVector(admm_tao, &y));
54 
55   /* Dx + y/mu */
56   CHKERRQ(MatMult(user->D,x,work));
57   CHKERRQ(VecAXPY(work,1/mu,y));
58 
59   /* soft thresholding */
60   CHKERRQ(TaoSoftThreshold(work, -lambda/mu, lambda/mu, out));
61   PetscFunctionReturn(0);
62 }
63 
64 /*------------------------------------------------------------*/
65 
66 PetscErrorCode MisfitObjectiveAndGradient(Tao tao,Vec X,PetscReal *f,Vec g,void *ptr)
67 {
68   AppCtx         *user = (AppCtx*)ptr;
69 
70   PetscFunctionBegin;
71   /* Objective  0.5*||Ax-b||_2^2 */
72   CHKERRQ(MatMult(user->A,X,user->workM));
73   CHKERRQ(VecAXPY(user->workM,-1,user->b));
74   CHKERRQ(VecDot(user->workM,user->workM,f));
75   *f  *= 0.5;
76   /* Gradient. ATAx-ATb */
77   CHKERRQ(MatMult(user->ATA,X,user->workN));
78   CHKERRQ(MatMultTranspose(user->A,user->b,user->workN2));
79   CHKERRQ(VecWAXPY(g,-1.,user->workN2,user->workN));
80   PetscFunctionReturn(0);
81 }
82 
83 /*------------------------------------------------------------*/
84 
85 PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao,Vec X,PetscReal *f_reg,Vec G_reg,void *ptr)
86 {
87   AppCtx         *user = (AppCtx*)ptr;
88 
89   PetscFunctionBegin;
90   /* compute regularizer objective
91    * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */
92   CHKERRQ(VecCopy(X,user->workN2));
93   CHKERRQ(VecPow(user->workN2,2.));
94   CHKERRQ(VecShift(user->workN2,user->eps*user->eps));
95   CHKERRQ(VecSqrtAbs(user->workN2));
96   CHKERRQ(VecCopy(user->workN2, user->workN3));
97   CHKERRQ(VecShift(user->workN2,-user->eps));
98   CHKERRQ(VecSum(user->workN2,f_reg));
99   *f_reg *= user->lambda;
100   /* compute regularizer gradient = lambda*x */
101   CHKERRQ(VecPointwiseDivide(G_reg,X,user->workN3));
102   CHKERRQ(VecScale(G_reg,user->lambda));
103   PetscFunctionReturn(0);
104 }
105 
106 /*------------------------------------------------------------*/
107 
108 PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao,Vec X,PetscReal *f_reg,Vec G_reg,void *ptr)
109 {
110   AppCtx         *user = (AppCtx*)ptr;
111   PetscReal      temp;
112 
113   PetscFunctionBegin;
114   /* compute regularizer objective = lambda*|z|_2^2 */
115   CHKERRQ(VecDot(X,X,&temp));
116   *f_reg = 0.5*user->lambda*temp;
117   /* compute regularizer gradient = lambda*z */
118   CHKERRQ(VecCopy(X,G_reg));
119   CHKERRQ(VecScale(G_reg,user->lambda));
120   PetscFunctionReturn(0);
121 }
122 
123 /*------------------------------------------------------------*/
124 
125 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
126 {
127   PetscFunctionBegin;
128   PetscFunctionReturn(0);
129 }
130 
131 /*------------------------------------------------------------*/
132 
133 static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
134 {
135   AppCtx         *user = (AppCtx*)ptr;
136 
137   PetscFunctionBegin;
138   CHKERRQ(MatMult(user->D,x,user->workN));
139   CHKERRQ(VecPow(user->workN2,2.));
140   CHKERRQ(VecShift(user->workN2,user->eps*user->eps));
141   CHKERRQ(VecSqrtAbs(user->workN2));
142   CHKERRQ(VecShift(user->workN2,-user->eps));
143   CHKERRQ(VecReciprocal(user->workN2));
144   CHKERRQ(VecScale(user->workN2,user->eps*user->eps));
145   CHKERRQ(MatDiagonalSet(H,user->workN2,INSERT_VALUES));
146   PetscFunctionReturn(0);
147 }
148 
149 /*------------------------------------------------------------*/
150 
151 PetscErrorCode FullObjGrad(Tao tao,Vec X,PetscReal *f,Vec g,void *ptr)
152 {
153   AppCtx         *user = (AppCtx*)ptr;
154   PetscReal      f_reg;
155 
156   PetscFunctionBegin;
157   /* Objective  0.5*||Ax-b||_2^2 + lambda*||x||_2^2*/
158   CHKERRQ(MatMult(user->A,X,user->workM));
159   CHKERRQ(VecAXPY(user->workM,-1,user->b));
160   CHKERRQ(VecDot(user->workM,user->workM,f));
161   CHKERRQ(VecNorm(X,NORM_2,&f_reg));
162   *f  *= 0.5;
163   *f  += user->lambda*f_reg*f_reg;
164   /* Gradient. ATAx-ATb + 2*lambda*x */
165   CHKERRQ(MatMult(user->ATA,X,user->workN));
166   CHKERRQ(MatMultTranspose(user->A,user->b,user->workN2));
167   CHKERRQ(VecWAXPY(g,-1.,user->workN2,user->workN));
168   CHKERRQ(VecAXPY(g,2*user->lambda,X));
169   PetscFunctionReturn(0);
170 }
171 /*------------------------------------------------------------*/
172 
173 static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
174 {
175   PetscFunctionBegin;
176   PetscFunctionReturn(0);
177 }
178 /*------------------------------------------------------------*/
179 
180 PetscErrorCode InitializeUserData(AppCtx *user)
181 {
182   char           dataFile[] = "tomographyData_A_b_xGT";   /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */
183   PetscViewer    fd;   /* used to load data from file */
184   PetscErrorCode ierr;
185   PetscInt       k,n;
186   PetscScalar    v;
187   PetscFunctionBegin;
188 
189   /* Load the A matrix, b vector, and xGT vector from a binary file. */
190   CHKERRQ(PetscViewerBinaryOpen(PETSC_COMM_WORLD,dataFile,FILE_MODE_READ,&fd));
191   CHKERRQ(MatCreate(PETSC_COMM_WORLD,&user->A));
192   CHKERRQ(MatSetType(user->A,MATAIJ));
193   CHKERRQ(MatLoad(user->A,fd));
194   CHKERRQ(VecCreate(PETSC_COMM_WORLD,&user->b));
195   CHKERRQ(VecLoad(user->b,fd));
196   CHKERRQ(VecCreate(PETSC_COMM_WORLD,&user->xGT));
197   CHKERRQ(VecLoad(user->xGT,fd));
198   CHKERRQ(PetscViewerDestroy(&fd));
199 
200   CHKERRQ(MatGetSize(user->A,&user->M,&user->N));
201 
202   CHKERRQ(MatCreate(PETSC_COMM_WORLD,&user->D));
203   CHKERRQ(MatSetSizes(user->D,PETSC_DECIDE,PETSC_DECIDE,user->N,user->N));
204   CHKERRQ(MatSetFromOptions(user->D));
205   CHKERRQ(MatSetUp(user->D));
206   for (k=0; k<user->N; k++) {
207     v = 1.0;
208     n = k+1;
209     if (k< user->N -1) {
210       CHKERRQ(MatSetValues(user->D,1,&k,1,&n,&v,INSERT_VALUES));
211     }
212     v    = -1.0;
213     CHKERRQ(MatSetValues(user->D,1,&k,1,&k,&v,INSERT_VALUES));
214   }
215   CHKERRQ(MatAssemblyBegin(user->D,MAT_FINAL_ASSEMBLY));
216   CHKERRQ(MatAssemblyEnd(user->D,MAT_FINAL_ASSEMBLY));
217 
218   CHKERRQ(MatTransposeMatMult(user->D,user->D,MAT_INITIAL_MATRIX,PETSC_DEFAULT,&user->DTD));
219 
220   CHKERRQ(MatCreate(PETSC_COMM_WORLD,&user->Hz));
221   CHKERRQ(MatSetSizes(user->Hz,PETSC_DECIDE,PETSC_DECIDE,user->N,user->N));
222   CHKERRQ(MatSetFromOptions(user->Hz));
223   CHKERRQ(MatSetUp(user->Hz));
224   CHKERRQ(MatAssemblyBegin(user->Hz,MAT_FINAL_ASSEMBLY));
225   CHKERRQ(MatAssemblyEnd(user->Hz,MAT_FINAL_ASSEMBLY));
226 
227   CHKERRQ(VecCreate(PETSC_COMM_WORLD,&(user->x)));
228   CHKERRQ(VecCreate(PETSC_COMM_WORLD,&(user->workM)));
229   CHKERRQ(VecCreate(PETSC_COMM_WORLD,&(user->workN)));
230   CHKERRQ(VecCreate(PETSC_COMM_WORLD,&(user->workN2)));
231   CHKERRQ(VecSetSizes(user->x,PETSC_DECIDE,user->N));
232   CHKERRQ(VecSetSizes(user->workM,PETSC_DECIDE,user->M));
233   CHKERRQ(VecSetSizes(user->workN,PETSC_DECIDE,user->N));
234   CHKERRQ(VecSetSizes(user->workN2,PETSC_DECIDE,user->N));
235   CHKERRQ(VecSetFromOptions(user->x));
236   CHKERRQ(VecSetFromOptions(user->workM));
237   CHKERRQ(VecSetFromOptions(user->workN));
238   CHKERRQ(VecSetFromOptions(user->workN2));
239 
240   CHKERRQ(VecDuplicate(user->workN,&(user->workN3)));
241   CHKERRQ(VecDuplicate(user->x,&(user->xlb)));
242   CHKERRQ(VecDuplicate(user->x,&(user->xub)));
243   CHKERRQ(VecDuplicate(user->x,&(user->c)));
244   CHKERRQ(VecSet(user->xlb,0.0));
245   CHKERRQ(VecSet(user->c,0.0));
246   CHKERRQ(VecSet(user->xub,PETSC_INFINITY));
247 
248   CHKERRQ(MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->ATA)));
249   CHKERRQ(MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->Hx)));
250   CHKERRQ(MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->HF)));
251 
252   CHKERRQ(MatAssemblyBegin(user->ATA,MAT_FINAL_ASSEMBLY));
253   CHKERRQ(MatAssemblyEnd(user->ATA,MAT_FINAL_ASSEMBLY));
254   CHKERRQ(MatAssemblyBegin(user->Hx,MAT_FINAL_ASSEMBLY));
255   CHKERRQ(MatAssemblyEnd(user->Hx,MAT_FINAL_ASSEMBLY));
256   CHKERRQ(MatAssemblyBegin(user->HF,MAT_FINAL_ASSEMBLY));
257   CHKERRQ(MatAssemblyEnd(user->HF,MAT_FINAL_ASSEMBLY));
258 
259   user->lambda = 1.e-8;
260   user->eps    = 1.e-3;
261   user->reg    = 2;
262   user->mumin  = 5.e-6;
263 
264   ierr = PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c");CHKERRQ(ierr);
265   CHKERRQ(PetscOptionsInt("-reg","Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &(user->reg), NULL));
266   CHKERRQ(PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &(user->lambda), NULL));
267   CHKERRQ(PetscOptionsReal("-eps", "L1 norm epsilon padding", "tomographyADMM.c", user->eps, &(user->eps), NULL));
268   CHKERRQ(PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &(user->mumin), NULL));
269   ierr = PetscOptionsEnd();CHKERRQ(ierr);
270   PetscFunctionReturn(0);
271 }
272 
273 /*------------------------------------------------------------*/
274 
275 PetscErrorCode DestroyContext(AppCtx *user)
276 {
277   PetscFunctionBegin;
278   CHKERRQ(MatDestroy(&user->A));
279   CHKERRQ(MatDestroy(&user->ATA));
280   CHKERRQ(MatDestroy(&user->Hx));
281   CHKERRQ(MatDestroy(&user->Hz));
282   CHKERRQ(MatDestroy(&user->HF));
283   CHKERRQ(MatDestroy(&user->D));
284   CHKERRQ(MatDestroy(&user->DTD));
285   CHKERRQ(VecDestroy(&user->xGT));
286   CHKERRQ(VecDestroy(&user->xlb));
287   CHKERRQ(VecDestroy(&user->xub));
288   CHKERRQ(VecDestroy(&user->b));
289   CHKERRQ(VecDestroy(&user->x));
290   CHKERRQ(VecDestroy(&user->c));
291   CHKERRQ(VecDestroy(&user->workN3));
292   CHKERRQ(VecDestroy(&user->workN2));
293   CHKERRQ(VecDestroy(&user->workN));
294   CHKERRQ(VecDestroy(&user->workM));
295   PetscFunctionReturn(0);
296 }
297 
298 /*------------------------------------------------------------*/
299 
300 int main(int argc,char **argv)
301 {
302   Tao            tao,misfit,reg;
303   PetscReal      v1,v2;
304   AppCtx*        user;
305   PetscViewer    fd;
306   char           resultFile[] = "tomographyResult_x";
307 
308   CHKERRQ(PetscInitialize(&argc,&argv,(char*)0,help));
309   CHKERRQ(PetscNew(&user));
310   CHKERRQ(InitializeUserData(user));
311 
312   CHKERRQ(TaoCreate(PETSC_COMM_WORLD, &tao));
313   CHKERRQ(TaoSetType(tao, TAOADMM));
314   CHKERRQ(TaoSetSolution(tao, user->x));
315   /* f(x) + g(x) for parent tao */
316   CHKERRQ(TaoADMMSetSpectralPenalty(tao,1.));
317   CHKERRQ(TaoSetObjectiveAndGradient(tao,NULL, FullObjGrad, (void*)user));
318   CHKERRQ(MatShift(user->HF,user->lambda));
319   CHKERRQ(TaoSetHessian(tao, user->HF, user->HF, HessianFull, (void*)user));
320 
321   /* f(x) for misfit tao */
322   CHKERRQ(TaoADMMSetMisfitObjectiveAndGradientRoutine(tao, MisfitObjectiveAndGradient, (void*)user));
323   CHKERRQ(TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void*)user));
324   CHKERRQ(TaoADMMSetMisfitHessianChangeStatus(tao,PETSC_FALSE));
325   CHKERRQ(TaoADMMSetMisfitConstraintJacobian(tao,user->D,user->D,NullJacobian,(void*)user));
326 
327   /* g(x) for regularizer tao */
328   if (user->reg == 1) {
329     CHKERRQ(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient1, (void*)user));
330     CHKERRQ(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void*)user));
331     CHKERRQ(TaoADMMSetRegHessianChangeStatus(tao,PETSC_TRUE));
332   } else if (user->reg == 2) {
333     CHKERRQ(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient2, (void*)user));
334     CHKERRQ(MatShift(user->Hz,1));
335     CHKERRQ(MatScale(user->Hz,user->lambda));
336     CHKERRQ(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void*)user));
337     CHKERRQ(TaoADMMSetRegHessianChangeStatus(tao,PETSC_TRUE));
338   } else PetscCheck(user->reg == 3,PETSC_COMM_WORLD, PETSC_ERR_ARG_UNKNOWN_TYPE, "Incorrect Reg type"); /* TaoShell case */
339 
340   /* Set type for the misfit solver */
341   CHKERRQ(TaoADMMGetMisfitSubsolver(tao, &misfit));
342   CHKERRQ(TaoADMMGetRegularizationSubsolver(tao, &reg));
343   CHKERRQ(TaoSetType(misfit,TAONLS));
344   if (user->reg == 3) {
345     CHKERRQ(TaoSetType(reg,TAOSHELL));
346     CHKERRQ(TaoShellSetContext(reg, (void*) user));
347     CHKERRQ(TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold));
348   } else {
349     CHKERRQ(TaoSetType(reg,TAONLS));
350   }
351   CHKERRQ(TaoSetVariableBounds(misfit,user->xlb,user->xub));
352 
353   /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */
354   CHKERRQ(TaoADMMSetRegularizerCoefficient(tao, user->lambda));
355   CHKERRQ(TaoADMMSetRegularizerConstraintJacobian(tao,NULL,NULL,NullJacobian,(void*)user));
356   CHKERRQ(TaoADMMSetMinimumSpectralPenalty(tao,user->mumin));
357 
358   CHKERRQ(TaoADMMSetConstraintVectorRHS(tao,user->c));
359   CHKERRQ(TaoSetFromOptions(tao));
360   CHKERRQ(TaoSolve(tao));
361 
362   /* Save x (reconstruction of object) vector to a binary file, which maybe read from Matlab and convert to a 2D image for comparison. */
363   CHKERRQ(PetscViewerBinaryOpen(PETSC_COMM_WORLD,resultFile,FILE_MODE_WRITE,&fd));
364   CHKERRQ(VecView(user->x,fd));
365   CHKERRQ(PetscViewerDestroy(&fd));
366 
367   /* compute the error */
368   CHKERRQ(VecAXPY(user->x,-1,user->xGT));
369   CHKERRQ(VecNorm(user->x,NORM_2,&v1));
370   CHKERRQ(VecNorm(user->xGT,NORM_2,&v2));
371   CHKERRQ(PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1/v2)));
372 
373   /* Free TAO data structures */
374   CHKERRQ(TaoDestroy(&tao));
375   CHKERRQ(DestroyContext(user));
376   CHKERRQ(PetscFree(user));
377   CHKERRQ(PetscFinalize());
378   return 0;
379 }
380 
381 /*TEST
382 
383    build:
384       requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)
385 
386    test:
387       suffix: 1
388       localrunfiles: tomographyData_A_b_xGT
389       args:  -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc
390 
391    test:
392       suffix: 2
393       localrunfiles: tomographyData_A_b_xGT
394       args:  -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8  -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor
395 
396    test:
397       suffix: 3
398       localrunfiles: tomographyData_A_b_xGT
399       args:  -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor
400 
401    test:
402       suffix: 4
403       localrunfiles: tomographyData_A_b_xGT
404       args:  -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -misfit_tao_monitor -misfit_tao_nls_pc_type icc
405 
406    test:
407       suffix: 5
408       localrunfiles: tomographyData_A_b_xGT
409       args:  -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
410 
411    test:
412       suffix: 6
413       localrunfiles: tomographyData_A_b_xGT
414       args:  -reg 3 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
415 
416 TEST*/
417