1 #include <petsctao.h> 2 /* 3 Description: ADMM tomography reconstruction example . 4 0.5*||Ax-b||^2 + lambda*g(x) 5 Reference: BRGN Tomography Example 6 */ 7 8 static char help[] = "Finds the ADMM solution to the under constraint linear model Ax = b, with regularizer. \n\ 9 A is a M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n\ 10 We first split the operator into 0.5*||Ax-b||^2, f(x), and lambda*||x||_1, g(z), where lambda is user specified weight. \n\ 11 g(z) could be either ||z||_1, or ||z||_2^2. Default closed form solution for NORM1 would be soft-threshold, which is \n\ 12 natively supported in admm.c with -tao_admm_regularizer_type soft-threshold. Or user can use regular TAO solver for \n\ 13 either NORM1 or NORM2 or TAOSHELL, with -reg {1,2,3} \n\ 14 Then, we augment both f and g, and solve it via ADMM. \n\ 15 D is the M*N transform matrix so that D*x is sparse. \n"; 16 17 typedef struct { 18 PetscInt M, N, K, reg; 19 PetscReal lambda, eps, mumin; 20 Mat A, ATA, H, Hx, D, Hz, DTD, HF; 21 Vec c, xlb, xub, x, b, workM, workN, workN2, workN3, xGT; /* observation b, ground truth xGT, the lower bound and upper bound of x*/ 22 } AppCtx; 23 24 /*------------------------------------------------------------*/ 25 26 PetscErrorCode NullJacobian(Tao tao, Vec X, Mat J, Mat Jpre, void *ptr) { 27 PetscFunctionBegin; 28 PetscFunctionReturn(0); 29 } 30 31 /*------------------------------------------------------------*/ 32 33 static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao) { 34 PetscReal lambda, mu; 35 AppCtx *user; 36 Vec out, work, y, x; 37 Tao admm_tao, misfit; 38 39 PetscFunctionBegin; 40 user = NULL; 41 mu = 0; 42 PetscCall(TaoGetADMMParentTao(tao, &admm_tao)); 43 PetscCall(TaoADMMGetMisfitSubsolver(admm_tao, &misfit)); 44 PetscCall(TaoADMMGetSpectralPenalty(admm_tao, &mu)); 45 PetscCall(TaoShellGetContext(tao, &user)); 46 47 lambda = user->lambda; 48 work = user->workN; 49 PetscCall(TaoGetSolution(tao, &out)); 50 PetscCall(TaoGetSolution(misfit, &x)); 51 PetscCall(TaoADMMGetDualVector(admm_tao, &y)); 52 53 /* Dx + y/mu */ 54 PetscCall(MatMult(user->D, x, work)); 55 PetscCall(VecAXPY(work, 1 / mu, y)); 56 57 /* soft thresholding */ 58 PetscCall(TaoSoftThreshold(work, -lambda / mu, lambda / mu, out)); 59 PetscFunctionReturn(0); 60 } 61 62 /*------------------------------------------------------------*/ 63 64 PetscErrorCode MisfitObjectiveAndGradient(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr) { 65 AppCtx *user = (AppCtx *)ptr; 66 67 PetscFunctionBegin; 68 /* Objective 0.5*||Ax-b||_2^2 */ 69 PetscCall(MatMult(user->A, X, user->workM)); 70 PetscCall(VecAXPY(user->workM, -1, user->b)); 71 PetscCall(VecDot(user->workM, user->workM, f)); 72 *f *= 0.5; 73 /* Gradient. ATAx-ATb */ 74 PetscCall(MatMult(user->ATA, X, user->workN)); 75 PetscCall(MatMultTranspose(user->A, user->b, user->workN2)); 76 PetscCall(VecWAXPY(g, -1., user->workN2, user->workN)); 77 PetscFunctionReturn(0); 78 } 79 80 /*------------------------------------------------------------*/ 81 82 PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr) { 83 AppCtx *user = (AppCtx *)ptr; 84 85 PetscFunctionBegin; 86 /* compute regularizer objective 87 * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */ 88 PetscCall(VecCopy(X, user->workN2)); 89 PetscCall(VecPow(user->workN2, 2.)); 90 PetscCall(VecShift(user->workN2, user->eps * user->eps)); 91 PetscCall(VecSqrtAbs(user->workN2)); 92 PetscCall(VecCopy(user->workN2, user->workN3)); 93 PetscCall(VecShift(user->workN2, -user->eps)); 94 PetscCall(VecSum(user->workN2, f_reg)); 95 *f_reg *= user->lambda; 96 /* compute regularizer gradient = lambda*x */ 97 PetscCall(VecPointwiseDivide(G_reg, X, user->workN3)); 98 PetscCall(VecScale(G_reg, user->lambda)); 99 PetscFunctionReturn(0); 100 } 101 102 /*------------------------------------------------------------*/ 103 104 PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr) { 105 AppCtx *user = (AppCtx *)ptr; 106 PetscReal temp; 107 108 PetscFunctionBegin; 109 /* compute regularizer objective = lambda*|z|_2^2 */ 110 PetscCall(VecDot(X, X, &temp)); 111 *f_reg = 0.5 * user->lambda * temp; 112 /* compute regularizer gradient = lambda*z */ 113 PetscCall(VecCopy(X, G_reg)); 114 PetscCall(VecScale(G_reg, user->lambda)); 115 PetscFunctionReturn(0); 116 } 117 118 /*------------------------------------------------------------*/ 119 120 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr) { 121 PetscFunctionBegin; 122 PetscFunctionReturn(0); 123 } 124 125 /*------------------------------------------------------------*/ 126 127 static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr) { 128 AppCtx *user = (AppCtx *)ptr; 129 130 PetscFunctionBegin; 131 PetscCall(MatMult(user->D, x, user->workN)); 132 PetscCall(VecPow(user->workN2, 2.)); 133 PetscCall(VecShift(user->workN2, user->eps * user->eps)); 134 PetscCall(VecSqrtAbs(user->workN2)); 135 PetscCall(VecShift(user->workN2, -user->eps)); 136 PetscCall(VecReciprocal(user->workN2)); 137 PetscCall(VecScale(user->workN2, user->eps * user->eps)); 138 PetscCall(MatDiagonalSet(H, user->workN2, INSERT_VALUES)); 139 PetscFunctionReturn(0); 140 } 141 142 /*------------------------------------------------------------*/ 143 144 PetscErrorCode FullObjGrad(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr) { 145 AppCtx *user = (AppCtx *)ptr; 146 PetscReal f_reg; 147 148 PetscFunctionBegin; 149 /* Objective 0.5*||Ax-b||_2^2 + lambda*||x||_2^2*/ 150 PetscCall(MatMult(user->A, X, user->workM)); 151 PetscCall(VecAXPY(user->workM, -1, user->b)); 152 PetscCall(VecDot(user->workM, user->workM, f)); 153 PetscCall(VecNorm(X, NORM_2, &f_reg)); 154 *f *= 0.5; 155 *f += user->lambda * f_reg * f_reg; 156 /* Gradient. ATAx-ATb + 2*lambda*x */ 157 PetscCall(MatMult(user->ATA, X, user->workN)); 158 PetscCall(MatMultTranspose(user->A, user->b, user->workN2)); 159 PetscCall(VecWAXPY(g, -1., user->workN2, user->workN)); 160 PetscCall(VecAXPY(g, 2 * user->lambda, X)); 161 PetscFunctionReturn(0); 162 } 163 /*------------------------------------------------------------*/ 164 165 static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr) { 166 PetscFunctionBegin; 167 PetscFunctionReturn(0); 168 } 169 /*------------------------------------------------------------*/ 170 171 PetscErrorCode InitializeUserData(AppCtx *user) { 172 char dataFile[] = "tomographyData_A_b_xGT"; /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */ 173 PetscViewer fd; /* used to load data from file */ 174 PetscInt k, n; 175 PetscScalar v; 176 177 PetscFunctionBegin; 178 /* Load the A matrix, b vector, and xGT vector from a binary file. */ 179 PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD, dataFile, FILE_MODE_READ, &fd)); 180 PetscCall(MatCreate(PETSC_COMM_WORLD, &user->A)); 181 PetscCall(MatSetType(user->A, MATAIJ)); 182 PetscCall(MatLoad(user->A, fd)); 183 PetscCall(VecCreate(PETSC_COMM_WORLD, &user->b)); 184 PetscCall(VecLoad(user->b, fd)); 185 PetscCall(VecCreate(PETSC_COMM_WORLD, &user->xGT)); 186 PetscCall(VecLoad(user->xGT, fd)); 187 PetscCall(PetscViewerDestroy(&fd)); 188 189 PetscCall(MatGetSize(user->A, &user->M, &user->N)); 190 191 PetscCall(MatCreate(PETSC_COMM_WORLD, &user->D)); 192 PetscCall(MatSetSizes(user->D, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N)); 193 PetscCall(MatSetFromOptions(user->D)); 194 PetscCall(MatSetUp(user->D)); 195 for (k = 0; k < user->N; k++) { 196 v = 1.0; 197 n = k + 1; 198 if (k < user->N - 1) { PetscCall(MatSetValues(user->D, 1, &k, 1, &n, &v, INSERT_VALUES)); } 199 v = -1.0; 200 PetscCall(MatSetValues(user->D, 1, &k, 1, &k, &v, INSERT_VALUES)); 201 } 202 PetscCall(MatAssemblyBegin(user->D, MAT_FINAL_ASSEMBLY)); 203 PetscCall(MatAssemblyEnd(user->D, MAT_FINAL_ASSEMBLY)); 204 205 PetscCall(MatTransposeMatMult(user->D, user->D, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &user->DTD)); 206 207 PetscCall(MatCreate(PETSC_COMM_WORLD, &user->Hz)); 208 PetscCall(MatSetSizes(user->Hz, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N)); 209 PetscCall(MatSetFromOptions(user->Hz)); 210 PetscCall(MatSetUp(user->Hz)); 211 PetscCall(MatAssemblyBegin(user->Hz, MAT_FINAL_ASSEMBLY)); 212 PetscCall(MatAssemblyEnd(user->Hz, MAT_FINAL_ASSEMBLY)); 213 214 PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->x))); 215 PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->workM))); 216 PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->workN))); 217 PetscCall(VecCreate(PETSC_COMM_WORLD, &(user->workN2))); 218 PetscCall(VecSetSizes(user->x, PETSC_DECIDE, user->N)); 219 PetscCall(VecSetSizes(user->workM, PETSC_DECIDE, user->M)); 220 PetscCall(VecSetSizes(user->workN, PETSC_DECIDE, user->N)); 221 PetscCall(VecSetSizes(user->workN2, PETSC_DECIDE, user->N)); 222 PetscCall(VecSetFromOptions(user->x)); 223 PetscCall(VecSetFromOptions(user->workM)); 224 PetscCall(VecSetFromOptions(user->workN)); 225 PetscCall(VecSetFromOptions(user->workN2)); 226 227 PetscCall(VecDuplicate(user->workN, &(user->workN3))); 228 PetscCall(VecDuplicate(user->x, &(user->xlb))); 229 PetscCall(VecDuplicate(user->x, &(user->xub))); 230 PetscCall(VecDuplicate(user->x, &(user->c))); 231 PetscCall(VecSet(user->xlb, 0.0)); 232 PetscCall(VecSet(user->c, 0.0)); 233 PetscCall(VecSet(user->xub, PETSC_INFINITY)); 234 235 PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->ATA))); 236 PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->Hx))); 237 PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->HF))); 238 239 PetscCall(MatAssemblyBegin(user->ATA, MAT_FINAL_ASSEMBLY)); 240 PetscCall(MatAssemblyEnd(user->ATA, MAT_FINAL_ASSEMBLY)); 241 PetscCall(MatAssemblyBegin(user->Hx, MAT_FINAL_ASSEMBLY)); 242 PetscCall(MatAssemblyEnd(user->Hx, MAT_FINAL_ASSEMBLY)); 243 PetscCall(MatAssemblyBegin(user->HF, MAT_FINAL_ASSEMBLY)); 244 PetscCall(MatAssemblyEnd(user->HF, MAT_FINAL_ASSEMBLY)); 245 246 user->lambda = 1.e-8; 247 user->eps = 1.e-3; 248 user->reg = 2; 249 user->mumin = 5.e-6; 250 251 PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c"); 252 PetscCall(PetscOptionsInt("-reg", "Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &(user->reg), NULL)); 253 PetscCall(PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &(user->lambda), NULL)); 254 PetscCall(PetscOptionsReal("-eps", "L1 norm epsilon padding", "tomographyADMM.c", user->eps, &(user->eps), NULL)); 255 PetscCall(PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &(user->mumin), NULL)); 256 PetscOptionsEnd(); 257 PetscFunctionReturn(0); 258 } 259 260 /*------------------------------------------------------------*/ 261 262 PetscErrorCode DestroyContext(AppCtx *user) { 263 PetscFunctionBegin; 264 PetscCall(MatDestroy(&user->A)); 265 PetscCall(MatDestroy(&user->ATA)); 266 PetscCall(MatDestroy(&user->Hx)); 267 PetscCall(MatDestroy(&user->Hz)); 268 PetscCall(MatDestroy(&user->HF)); 269 PetscCall(MatDestroy(&user->D)); 270 PetscCall(MatDestroy(&user->DTD)); 271 PetscCall(VecDestroy(&user->xGT)); 272 PetscCall(VecDestroy(&user->xlb)); 273 PetscCall(VecDestroy(&user->xub)); 274 PetscCall(VecDestroy(&user->b)); 275 PetscCall(VecDestroy(&user->x)); 276 PetscCall(VecDestroy(&user->c)); 277 PetscCall(VecDestroy(&user->workN3)); 278 PetscCall(VecDestroy(&user->workN2)); 279 PetscCall(VecDestroy(&user->workN)); 280 PetscCall(VecDestroy(&user->workM)); 281 PetscFunctionReturn(0); 282 } 283 284 /*------------------------------------------------------------*/ 285 286 int main(int argc, char **argv) { 287 Tao tao, misfit, reg; 288 PetscReal v1, v2; 289 AppCtx *user; 290 PetscViewer fd; 291 char resultFile[] = "tomographyResult_x"; 292 293 PetscFunctionBeginUser; 294 PetscCall(PetscInitialize(&argc, &argv, (char *)0, help)); 295 PetscCall(PetscNew(&user)); 296 PetscCall(InitializeUserData(user)); 297 298 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao)); 299 PetscCall(TaoSetType(tao, TAOADMM)); 300 PetscCall(TaoSetSolution(tao, user->x)); 301 /* f(x) + g(x) for parent tao */ 302 PetscCall(TaoADMMSetSpectralPenalty(tao, 1.)); 303 PetscCall(TaoSetObjectiveAndGradient(tao, NULL, FullObjGrad, (void *)user)); 304 PetscCall(MatShift(user->HF, user->lambda)); 305 PetscCall(TaoSetHessian(tao, user->HF, user->HF, HessianFull, (void *)user)); 306 307 /* f(x) for misfit tao */ 308 PetscCall(TaoADMMSetMisfitObjectiveAndGradientRoutine(tao, MisfitObjectiveAndGradient, (void *)user)); 309 PetscCall(TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void *)user)); 310 PetscCall(TaoADMMSetMisfitHessianChangeStatus(tao, PETSC_FALSE)); 311 PetscCall(TaoADMMSetMisfitConstraintJacobian(tao, user->D, user->D, NullJacobian, (void *)user)); 312 313 /* g(x) for regularizer tao */ 314 if (user->reg == 1) { 315 PetscCall(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient1, (void *)user)); 316 PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void *)user)); 317 PetscCall(TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE)); 318 } else if (user->reg == 2) { 319 PetscCall(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient2, (void *)user)); 320 PetscCall(MatShift(user->Hz, 1)); 321 PetscCall(MatScale(user->Hz, user->lambda)); 322 PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void *)user)); 323 PetscCall(TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE)); 324 } else PetscCheck(user->reg == 3, PETSC_COMM_WORLD, PETSC_ERR_ARG_UNKNOWN_TYPE, "Incorrect Reg type"); /* TaoShell case */ 325 326 /* Set type for the misfit solver */ 327 PetscCall(TaoADMMGetMisfitSubsolver(tao, &misfit)); 328 PetscCall(TaoADMMGetRegularizationSubsolver(tao, ®)); 329 PetscCall(TaoSetType(misfit, TAONLS)); 330 if (user->reg == 3) { 331 PetscCall(TaoSetType(reg, TAOSHELL)); 332 PetscCall(TaoShellSetContext(reg, (void *)user)); 333 PetscCall(TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold)); 334 } else { 335 PetscCall(TaoSetType(reg, TAONLS)); 336 } 337 PetscCall(TaoSetVariableBounds(misfit, user->xlb, user->xub)); 338 339 /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */ 340 PetscCall(TaoADMMSetRegularizerCoefficient(tao, user->lambda)); 341 PetscCall(TaoADMMSetRegularizerConstraintJacobian(tao, NULL, NULL, NullJacobian, (void *)user)); 342 PetscCall(TaoADMMSetMinimumSpectralPenalty(tao, user->mumin)); 343 344 PetscCall(TaoADMMSetConstraintVectorRHS(tao, user->c)); 345 PetscCall(TaoSetFromOptions(tao)); 346 PetscCall(TaoSolve(tao)); 347 348 /* Save x (reconstruction of object) vector to a binary file, which maybe read from Matlab and convert to a 2D image for comparison. */ 349 PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD, resultFile, FILE_MODE_WRITE, &fd)); 350 PetscCall(VecView(user->x, fd)); 351 PetscCall(PetscViewerDestroy(&fd)); 352 353 /* compute the error */ 354 PetscCall(VecAXPY(user->x, -1, user->xGT)); 355 PetscCall(VecNorm(user->x, NORM_2, &v1)); 356 PetscCall(VecNorm(user->xGT, NORM_2, &v2)); 357 PetscCall(PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1 / v2))); 358 359 /* Free TAO data structures */ 360 PetscCall(TaoDestroy(&tao)); 361 PetscCall(DestroyContext(user)); 362 PetscCall(PetscFree(user)); 363 PetscCall(PetscFinalize()); 364 return 0; 365 } 366 367 /*TEST 368 369 build: 370 requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES) 371 372 test: 373 suffix: 1 374 localrunfiles: tomographyData_A_b_xGT 375 args: -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc 376 377 test: 378 suffix: 2 379 localrunfiles: tomographyData_A_b_xGT 380 args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor 381 382 test: 383 suffix: 3 384 localrunfiles: tomographyData_A_b_xGT 385 args: -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor 386 387 test: 388 suffix: 4 389 localrunfiles: tomographyData_A_b_xGT 390 args: -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -misfit_tao_monitor -misfit_tao_nls_pc_type icc 391 392 test: 393 suffix: 5 394 localrunfiles: tomographyData_A_b_xGT 395 args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc 396 397 test: 398 suffix: 6 399 localrunfiles: tomographyData_A_b_xGT 400 args: -reg 3 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc 401 402 TEST*/ 403