1 #include <petsctao.h> 2 /* 3 Description: ADMM tomography reconstruction example . 4 0.5*||Ax-b||^2 + lambda*g(x) 5 Reference: BRGN Tomography Example 6 */ 7 8 static char help[] = "Finds the ADMM solution to the under constraint linear model Ax = b, with regularizer. \n\ 9 A is a M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n\ 10 We first split the operator into 0.5*||Ax-b||^2, f(x), and lambda*||x||_1, g(z), where lambda is user specified weight. \n\ 11 g(z) could be either ||z||_1, or ||z||_2^2. Default closed form solution for NORM1 would be soft-threshold, which is \n\ 12 natively supported in admm.c with -tao_admm_regularizer_type soft-threshold. Or user can use regular TAO solver for \n\ 13 either NORM1 or NORM2 or TAOSHELL, with -reg {1,2,3} \n\ 14 Then, we augment both f and g, and solve it via ADMM. \n\ 15 D is the M*N transform matrix so that D*x is sparse. \n"; 16 17 typedef struct { 18 PetscInt M,N,K,reg; 19 PetscReal lambda,eps,mumin; 20 Mat A,ATA,H,Hx,D,Hz,DTD,HF; 21 Vec c,xlb,xub,x,b,workM,workN,workN2,workN3,xGT; /* observation b, ground truth xGT, the lower bound and upper bound of x*/ 22 } AppCtx; 23 24 /*------------------------------------------------------------*/ 25 26 PetscErrorCode NullJacobian(Tao tao,Vec X,Mat J,Mat Jpre,void *ptr) 27 { 28 PetscFunctionBegin; 29 PetscFunctionReturn(0); 30 } 31 32 /*------------------------------------------------------------*/ 33 34 static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao) 35 { 36 PetscReal lambda, mu; 37 AppCtx *user; 38 Vec out,work,y,x; 39 Tao admm_tao,misfit; 40 41 PetscFunctionBegin; 42 user = NULL; 43 mu = 0; 44 PetscCall(TaoGetADMMParentTao(tao,&admm_tao)); 45 PetscCall(TaoADMMGetMisfitSubsolver(admm_tao, &misfit)); 46 PetscCall(TaoADMMGetSpectralPenalty(admm_tao,&mu)); 47 PetscCall(TaoShellGetContext(tao,&user)); 48 49 lambda = user->lambda; 50 work = user->workN; 51 PetscCall(TaoGetSolution(tao, &out)); 52 PetscCall(TaoGetSolution(misfit, &x)); 53 PetscCall(TaoADMMGetDualVector(admm_tao, &y)); 54 55 /* Dx + y/mu */ 56 PetscCall(MatMult(user->D,x,work)); 57 PetscCall(VecAXPY(work,1/mu,y)); 58 59 /* soft thresholding */ 60 PetscCall(TaoSoftThreshold(work, -lambda/mu, lambda/mu, out)); 61 PetscFunctionReturn(0); 62 } 63 64 /*------------------------------------------------------------*/ 65 66 PetscErrorCode MisfitObjectiveAndGradient(Tao tao,Vec X,PetscReal *f,Vec g,void *ptr) 67 { 68 AppCtx *user = (AppCtx*)ptr; 69 70 PetscFunctionBegin; 71 /* Objective 0.5*||Ax-b||_2^2 */ 72 PetscCall(MatMult(user->A,X,user->workM)); 73 PetscCall(VecAXPY(user->workM,-1,user->b)); 74 PetscCall(VecDot(user->workM,user->workM,f)); 75 *f *= 0.5; 76 /* Gradient. ATAx-ATb */ 77 PetscCall(MatMult(user->ATA,X,user->workN)); 78 PetscCall(MatMultTranspose(user->A,user->b,user->workN2)); 79 PetscCall(VecWAXPY(g,-1.,user->workN2,user->workN)); 80 PetscFunctionReturn(0); 81 } 82 83 /*------------------------------------------------------------*/ 84 85 PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao,Vec X,PetscReal *f_reg,Vec G_reg,void *ptr) 86 { 87 AppCtx *user = (AppCtx*)ptr; 88 89 PetscFunctionBegin; 90 /* compute regularizer objective 91 * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */ 92 PetscCall(VecCopy(X,user->workN2)); 93 PetscCall(VecPow(user->workN2,2.)); 94 PetscCall(VecShift(user->workN2,user->eps*user->eps)); 95 PetscCall(VecSqrtAbs(user->workN2)); 96 PetscCall(VecCopy(user->workN2, user->workN3)); 97 PetscCall(VecShift(user->workN2,-user->eps)); 98 PetscCall(VecSum(user->workN2,f_reg)); 99 *f_reg *= user->lambda; 100 /* compute regularizer gradient = lambda*x */ 101 PetscCall(VecPointwiseDivide(G_reg,X,user->workN3)); 102 PetscCall(VecScale(G_reg,user->lambda)); 103 PetscFunctionReturn(0); 104 } 105 106 /*------------------------------------------------------------*/ 107 108 PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao,Vec X,PetscReal *f_reg,Vec G_reg,void *ptr) 109 { 110 AppCtx *user = (AppCtx*)ptr; 111 PetscReal temp; 112 113 PetscFunctionBegin; 114 /* compute regularizer objective = lambda*|z|_2^2 */ 115 PetscCall(VecDot(X,X,&temp)); 116 *f_reg = 0.5*user->lambda*temp; 117 /* compute regularizer gradient = lambda*z */ 118 PetscCall(VecCopy(X,G_reg)); 119 PetscCall(VecScale(G_reg,user->lambda)); 120 PetscFunctionReturn(0); 121 } 122 123 /*------------------------------------------------------------*/ 124 125 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr) 126 { 127 PetscFunctionBegin; 128 PetscFunctionReturn(0); 129 } 130 131 /*------------------------------------------------------------*/ 132 133 static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr) 134 { 135 AppCtx *user = (AppCtx*)ptr; 136 137 PetscFunctionBegin; 138 PetscCall(MatMult(user->D,x,user->workN)); 139 PetscCall(VecPow(user->workN2,2.)); 140 PetscCall(VecShift(user->workN2,user->eps*user->eps)); 141 PetscCall(VecSqrtAbs(user->workN2)); 142 PetscCall(VecShift(user->workN2,-user->eps)); 143 PetscCall(VecReciprocal(user->workN2)); 144 PetscCall(VecScale(user->workN2,user->eps*user->eps)); 145 PetscCall(MatDiagonalSet(H,user->workN2,INSERT_VALUES)); 146 PetscFunctionReturn(0); 147 } 148 149 /*------------------------------------------------------------*/ 150 151 PetscErrorCode FullObjGrad(Tao tao,Vec X,PetscReal *f,Vec g,void *ptr) 152 { 153 AppCtx *user = (AppCtx*)ptr; 154 PetscReal f_reg; 155 156 PetscFunctionBegin; 157 /* Objective 0.5*||Ax-b||_2^2 + lambda*||x||_2^2*/ 158 PetscCall(MatMult(user->A,X,user->workM)); 159 PetscCall(VecAXPY(user->workM,-1,user->b)); 160 PetscCall(VecDot(user->workM,user->workM,f)); 161 PetscCall(VecNorm(X,NORM_2,&f_reg)); 162 *f *= 0.5; 163 *f += user->lambda*f_reg*f_reg; 164 /* Gradient. ATAx-ATb + 2*lambda*x */ 165 PetscCall(MatMult(user->ATA,X,user->workN)); 166 PetscCall(MatMultTranspose(user->A,user->b,user->workN2)); 167 PetscCall(VecWAXPY(g,-1.,user->workN2,user->workN)); 168 PetscCall(VecAXPY(g,2*user->lambda,X)); 169 PetscFunctionReturn(0); 170 } 171 /*------------------------------------------------------------*/ 172 173 static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr) 174 { 175 PetscFunctionBegin; 176 PetscFunctionReturn(0); 177 } 178 /*------------------------------------------------------------*/ 179 180 PetscErrorCode InitializeUserData(AppCtx *user) 181 { 182 char dataFile[] = "tomographyData_A_b_xGT"; /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */ 183 PetscViewer fd; /* used to load data from file */ 184 PetscInt k,n; 185 PetscScalar v; 186 187 PetscFunctionBegin; 188 /* Load the A matrix, b vector, and xGT vector from a binary file. */ 189 PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD,dataFile,FILE_MODE_READ,&fd)); 190 PetscCall(MatCreate(PETSC_COMM_WORLD,&user->A)); 191 PetscCall(MatSetType(user->A,MATAIJ)); 192 PetscCall(MatLoad(user->A,fd)); 193 PetscCall(VecCreate(PETSC_COMM_WORLD,&user->b)); 194 PetscCall(VecLoad(user->b,fd)); 195 PetscCall(VecCreate(PETSC_COMM_WORLD,&user->xGT)); 196 PetscCall(VecLoad(user->xGT,fd)); 197 PetscCall(PetscViewerDestroy(&fd)); 198 199 PetscCall(MatGetSize(user->A,&user->M,&user->N)); 200 201 PetscCall(MatCreate(PETSC_COMM_WORLD,&user->D)); 202 PetscCall(MatSetSizes(user->D,PETSC_DECIDE,PETSC_DECIDE,user->N,user->N)); 203 PetscCall(MatSetFromOptions(user->D)); 204 PetscCall(MatSetUp(user->D)); 205 for (k=0; k<user->N; k++) { 206 v = 1.0; 207 n = k+1; 208 if (k< user->N -1) { 209 PetscCall(MatSetValues(user->D,1,&k,1,&n,&v,INSERT_VALUES)); 210 } 211 v = -1.0; 212 PetscCall(MatSetValues(user->D,1,&k,1,&k,&v,INSERT_VALUES)); 213 } 214 PetscCall(MatAssemblyBegin(user->D,MAT_FINAL_ASSEMBLY)); 215 PetscCall(MatAssemblyEnd(user->D,MAT_FINAL_ASSEMBLY)); 216 217 PetscCall(MatTransposeMatMult(user->D,user->D,MAT_INITIAL_MATRIX,PETSC_DEFAULT,&user->DTD)); 218 219 PetscCall(MatCreate(PETSC_COMM_WORLD,&user->Hz)); 220 PetscCall(MatSetSizes(user->Hz,PETSC_DECIDE,PETSC_DECIDE,user->N,user->N)); 221 PetscCall(MatSetFromOptions(user->Hz)); 222 PetscCall(MatSetUp(user->Hz)); 223 PetscCall(MatAssemblyBegin(user->Hz,MAT_FINAL_ASSEMBLY)); 224 PetscCall(MatAssemblyEnd(user->Hz,MAT_FINAL_ASSEMBLY)); 225 226 PetscCall(VecCreate(PETSC_COMM_WORLD,&(user->x))); 227 PetscCall(VecCreate(PETSC_COMM_WORLD,&(user->workM))); 228 PetscCall(VecCreate(PETSC_COMM_WORLD,&(user->workN))); 229 PetscCall(VecCreate(PETSC_COMM_WORLD,&(user->workN2))); 230 PetscCall(VecSetSizes(user->x,PETSC_DECIDE,user->N)); 231 PetscCall(VecSetSizes(user->workM,PETSC_DECIDE,user->M)); 232 PetscCall(VecSetSizes(user->workN,PETSC_DECIDE,user->N)); 233 PetscCall(VecSetSizes(user->workN2,PETSC_DECIDE,user->N)); 234 PetscCall(VecSetFromOptions(user->x)); 235 PetscCall(VecSetFromOptions(user->workM)); 236 PetscCall(VecSetFromOptions(user->workN)); 237 PetscCall(VecSetFromOptions(user->workN2)); 238 239 PetscCall(VecDuplicate(user->workN,&(user->workN3))); 240 PetscCall(VecDuplicate(user->x,&(user->xlb))); 241 PetscCall(VecDuplicate(user->x,&(user->xub))); 242 PetscCall(VecDuplicate(user->x,&(user->c))); 243 PetscCall(VecSet(user->xlb,0.0)); 244 PetscCall(VecSet(user->c,0.0)); 245 PetscCall(VecSet(user->xub,PETSC_INFINITY)); 246 247 PetscCall(MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->ATA))); 248 PetscCall(MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->Hx))); 249 PetscCall(MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->HF))); 250 251 PetscCall(MatAssemblyBegin(user->ATA,MAT_FINAL_ASSEMBLY)); 252 PetscCall(MatAssemblyEnd(user->ATA,MAT_FINAL_ASSEMBLY)); 253 PetscCall(MatAssemblyBegin(user->Hx,MAT_FINAL_ASSEMBLY)); 254 PetscCall(MatAssemblyEnd(user->Hx,MAT_FINAL_ASSEMBLY)); 255 PetscCall(MatAssemblyBegin(user->HF,MAT_FINAL_ASSEMBLY)); 256 PetscCall(MatAssemblyEnd(user->HF,MAT_FINAL_ASSEMBLY)); 257 258 user->lambda = 1.e-8; 259 user->eps = 1.e-3; 260 user->reg = 2; 261 user->mumin = 5.e-6; 262 263 PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c"); 264 PetscCall(PetscOptionsInt("-reg","Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &(user->reg), NULL)); 265 PetscCall(PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &(user->lambda), NULL)); 266 PetscCall(PetscOptionsReal("-eps", "L1 norm epsilon padding", "tomographyADMM.c", user->eps, &(user->eps), NULL)); 267 PetscCall(PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &(user->mumin), NULL)); 268 PetscOptionsEnd(); 269 PetscFunctionReturn(0); 270 } 271 272 /*------------------------------------------------------------*/ 273 274 PetscErrorCode DestroyContext(AppCtx *user) 275 { 276 PetscFunctionBegin; 277 PetscCall(MatDestroy(&user->A)); 278 PetscCall(MatDestroy(&user->ATA)); 279 PetscCall(MatDestroy(&user->Hx)); 280 PetscCall(MatDestroy(&user->Hz)); 281 PetscCall(MatDestroy(&user->HF)); 282 PetscCall(MatDestroy(&user->D)); 283 PetscCall(MatDestroy(&user->DTD)); 284 PetscCall(VecDestroy(&user->xGT)); 285 PetscCall(VecDestroy(&user->xlb)); 286 PetscCall(VecDestroy(&user->xub)); 287 PetscCall(VecDestroy(&user->b)); 288 PetscCall(VecDestroy(&user->x)); 289 PetscCall(VecDestroy(&user->c)); 290 PetscCall(VecDestroy(&user->workN3)); 291 PetscCall(VecDestroy(&user->workN2)); 292 PetscCall(VecDestroy(&user->workN)); 293 PetscCall(VecDestroy(&user->workM)); 294 PetscFunctionReturn(0); 295 } 296 297 /*------------------------------------------------------------*/ 298 299 int main(int argc,char **argv) 300 { 301 Tao tao,misfit,reg; 302 PetscReal v1,v2; 303 AppCtx* user; 304 PetscViewer fd; 305 char resultFile[] = "tomographyResult_x"; 306 307 PetscFunctionBeginUser; 308 PetscCall(PetscInitialize(&argc,&argv,(char*)0,help)); 309 PetscCall(PetscNew(&user)); 310 PetscCall(InitializeUserData(user)); 311 312 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao)); 313 PetscCall(TaoSetType(tao, TAOADMM)); 314 PetscCall(TaoSetSolution(tao, user->x)); 315 /* f(x) + g(x) for parent tao */ 316 PetscCall(TaoADMMSetSpectralPenalty(tao,1.)); 317 PetscCall(TaoSetObjectiveAndGradient(tao,NULL, FullObjGrad, (void*)user)); 318 PetscCall(MatShift(user->HF,user->lambda)); 319 PetscCall(TaoSetHessian(tao, user->HF, user->HF, HessianFull, (void*)user)); 320 321 /* f(x) for misfit tao */ 322 PetscCall(TaoADMMSetMisfitObjectiveAndGradientRoutine(tao, MisfitObjectiveAndGradient, (void*)user)); 323 PetscCall(TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void*)user)); 324 PetscCall(TaoADMMSetMisfitHessianChangeStatus(tao,PETSC_FALSE)); 325 PetscCall(TaoADMMSetMisfitConstraintJacobian(tao,user->D,user->D,NullJacobian,(void*)user)); 326 327 /* g(x) for regularizer tao */ 328 if (user->reg == 1) { 329 PetscCall(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient1, (void*)user)); 330 PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void*)user)); 331 PetscCall(TaoADMMSetRegHessianChangeStatus(tao,PETSC_TRUE)); 332 } else if (user->reg == 2) { 333 PetscCall(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient2, (void*)user)); 334 PetscCall(MatShift(user->Hz,1)); 335 PetscCall(MatScale(user->Hz,user->lambda)); 336 PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void*)user)); 337 PetscCall(TaoADMMSetRegHessianChangeStatus(tao,PETSC_TRUE)); 338 } else PetscCheck(user->reg == 3,PETSC_COMM_WORLD, PETSC_ERR_ARG_UNKNOWN_TYPE, "Incorrect Reg type"); /* TaoShell case */ 339 340 /* Set type for the misfit solver */ 341 PetscCall(TaoADMMGetMisfitSubsolver(tao, &misfit)); 342 PetscCall(TaoADMMGetRegularizationSubsolver(tao, ®)); 343 PetscCall(TaoSetType(misfit,TAONLS)); 344 if (user->reg == 3) { 345 PetscCall(TaoSetType(reg,TAOSHELL)); 346 PetscCall(TaoShellSetContext(reg, (void*) user)); 347 PetscCall(TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold)); 348 } else { 349 PetscCall(TaoSetType(reg,TAONLS)); 350 } 351 PetscCall(TaoSetVariableBounds(misfit,user->xlb,user->xub)); 352 353 /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */ 354 PetscCall(TaoADMMSetRegularizerCoefficient(tao, user->lambda)); 355 PetscCall(TaoADMMSetRegularizerConstraintJacobian(tao,NULL,NULL,NullJacobian,(void*)user)); 356 PetscCall(TaoADMMSetMinimumSpectralPenalty(tao,user->mumin)); 357 358 PetscCall(TaoADMMSetConstraintVectorRHS(tao,user->c)); 359 PetscCall(TaoSetFromOptions(tao)); 360 PetscCall(TaoSolve(tao)); 361 362 /* Save x (reconstruction of object) vector to a binary file, which maybe read from Matlab and convert to a 2D image for comparison. */ 363 PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD,resultFile,FILE_MODE_WRITE,&fd)); 364 PetscCall(VecView(user->x,fd)); 365 PetscCall(PetscViewerDestroy(&fd)); 366 367 /* compute the error */ 368 PetscCall(VecAXPY(user->x,-1,user->xGT)); 369 PetscCall(VecNorm(user->x,NORM_2,&v1)); 370 PetscCall(VecNorm(user->xGT,NORM_2,&v2)); 371 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1/v2))); 372 373 /* Free TAO data structures */ 374 PetscCall(TaoDestroy(&tao)); 375 PetscCall(DestroyContext(user)); 376 PetscCall(PetscFree(user)); 377 PetscCall(PetscFinalize()); 378 return 0; 379 } 380 381 /*TEST 382 383 build: 384 requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES) 385 386 test: 387 suffix: 1 388 localrunfiles: tomographyData_A_b_xGT 389 args: -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc 390 391 test: 392 suffix: 2 393 localrunfiles: tomographyData_A_b_xGT 394 args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor 395 396 test: 397 suffix: 3 398 localrunfiles: tomographyData_A_b_xGT 399 args: -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor 400 401 test: 402 suffix: 4 403 localrunfiles: tomographyData_A_b_xGT 404 args: -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -misfit_tao_monitor -misfit_tao_nls_pc_type icc 405 406 test: 407 suffix: 5 408 localrunfiles: tomographyData_A_b_xGT 409 args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc 410 411 test: 412 suffix: 6 413 localrunfiles: tomographyData_A_b_xGT 414 args: -reg 3 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc 415 416 TEST*/ 417