1 static char help[] = "Simple example to test separable objective optimizers.\n"; 2 3 #include <petsc.h> 4 #include <petsctao.h> 5 #include <petscvec.h> 6 #include <petscmath.h> 7 8 #define NWORKLEFT 4 9 #define NWORKRIGHT 12 10 11 typedef struct _UserCtx 12 { 13 PetscInt m; /* The row dimension of F */ 14 PetscInt n; /* The column dimension of F */ 15 PetscInt matops; /* Matrix format. 0 for stencil, 1 for random */ 16 PetscInt iter; /* Numer of iterations for ADMM */ 17 PetscReal hStart; /* Starting point for Taylor test */ 18 PetscReal hFactor; /* Taylor test step factor */ 19 PetscReal hMin; /* Taylor test end goal */ 20 PetscReal alpha; /* regularization constant applied to || x ||_p */ 21 PetscReal eps; /* small constant for approximating gradient of || x ||_1 */ 22 PetscReal mu; /* the augmented Lagrangian term in ADMM */ 23 PetscReal abstol; 24 PetscReal reltol; 25 Mat F; /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */ 26 Mat W; /* Workspace matrix. ATA */ 27 Mat Hm; /* Hessian Misfit*/ 28 Mat Hr; /* Hessian Reg*/ 29 Vec d; /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */ 30 Vec workLeft[NWORKLEFT]; /* Workspace for temporary vec */ 31 Vec workRight[NWORKRIGHT]; /* Workspace for temporary vec */ 32 NormType p; 33 PetscRandom rctx; 34 PetscBool taylor; /* Flag to determine whether to run Taylor test or not */ 35 PetscBool use_admm; /* Flag to determine whether to run Taylor test or not */ 36 }* UserCtx; 37 38 static PetscErrorCode CreateRHS(UserCtx ctx) 39 { 40 PetscFunctionBegin; 41 /* build the rhs d in ctx */ 42 PetscCall(VecCreate(PETSC_COMM_WORLD,&(ctx->d))); 43 PetscCall(VecSetSizes(ctx->d,PETSC_DECIDE,ctx->m)); 44 PetscCall(VecSetFromOptions(ctx->d)); 45 PetscCall(VecSetRandom(ctx->d,ctx->rctx)); 46 PetscFunctionReturn(0); 47 } 48 49 static PetscErrorCode CreateMatrix(UserCtx ctx) 50 { 51 PetscInt Istart,Iend,i,j,Ii,gridN,I_n, I_s, I_e, I_w; 52 #if defined(PETSC_USE_LOG) 53 PetscLogStage stage; 54 #endif 55 56 PetscFunctionBegin; 57 /* build the matrix F in ctx */ 58 PetscCall(MatCreate(PETSC_COMM_WORLD, &(ctx->F))); 59 PetscCall(MatSetSizes(ctx->F,PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n)); 60 PetscCall(MatSetType(ctx->F,MATAIJ)); /* TODO: Decide specific SetType other than dummy*/ 61 PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/ 62 PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL)); 63 PetscCall(MatSetUp(ctx->F)); 64 PetscCall(MatGetOwnershipRange(ctx->F,&Istart,&Iend)); 65 PetscCall(PetscLogStageRegister("Assembly", &stage)); 66 PetscCall(PetscLogStagePush(stage)); 67 68 /* Set matrix elements in 2-D five point stencil format. */ 69 if (!(ctx->matops)) { 70 PetscCheck(ctx->m == ctx->n,PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square"); 71 gridN = (PetscInt) PetscSqrtReal((PetscReal) ctx->m); 72 PetscCheck(gridN * gridN == ctx->m,PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square"); 73 for (Ii=Istart; Ii<Iend; Ii++) { 74 i = Ii / gridN; j = Ii % gridN; 75 I_n = i * gridN + j + 1; 76 if (j + 1 >= gridN) I_n = -1; 77 I_s = i * gridN + j - 1; 78 if (j - 1 < 0) I_s = -1; 79 I_e = (i + 1) * gridN + j; 80 if (i + 1 >= gridN) I_e = -1; 81 I_w = (i - 1) * gridN + j; 82 if (i - 1 < 0) I_w = -1; 83 PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES)); 84 PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES)); 85 PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES)); 86 PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES)); 87 PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES)); 88 } 89 } else PetscCall(MatSetRandom(ctx->F, ctx->rctx)); 90 PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY)); 91 PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY)); 92 PetscCall(PetscLogStagePop()); 93 /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */ 94 if (!(ctx->matops)) { 95 PetscCall(MatSetOption(ctx->F,MAT_SYMMETRIC,PETSC_TRUE)); 96 } 97 PetscCall(MatTransposeMatMult(ctx->F,ctx->F, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(ctx->W))); 98 /* Setup Hessian Workspace in same shape as W */ 99 PetscCall(MatDuplicate(ctx->W,MAT_DO_NOT_COPY_VALUES,&(ctx->Hm))); 100 PetscCall(MatDuplicate(ctx->W,MAT_DO_NOT_COPY_VALUES,&(ctx->Hr))); 101 PetscFunctionReturn(0); 102 } 103 104 static PetscErrorCode SetupWorkspace(UserCtx ctx) 105 { 106 PetscInt i; 107 108 PetscFunctionBegin; 109 PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0])); 110 for (i=1; i<NWORKLEFT; i++) { 111 PetscCall(VecDuplicate(ctx->workLeft[0], &(ctx->workLeft[i]))); 112 } 113 for (i=1; i<NWORKRIGHT; i++) { 114 PetscCall(VecDuplicate(ctx->workRight[0], &(ctx->workRight[i]))); 115 } 116 PetscFunctionReturn(0); 117 } 118 119 static PetscErrorCode ConfigureContext(UserCtx ctx) 120 { 121 PetscFunctionBegin; 122 ctx->m = 16; 123 ctx->n = 16; 124 ctx->eps = 1.e-3; 125 ctx->abstol = 1.e-4; 126 ctx->reltol = 1.e-2; 127 ctx->hStart = 1.; 128 ctx->hMin = 1.e-3; 129 ctx->hFactor = 0.5; 130 ctx->alpha = 1.; 131 ctx->mu = 1.0; 132 ctx->matops = 0; 133 ctx->iter = 10; 134 ctx->p = NORM_2; 135 ctx->taylor = PETSC_TRUE; 136 ctx->use_admm = PETSC_FALSE; 137 PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "ex4.c"); 138 PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &(ctx->m), NULL)); 139 PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &(ctx->n), NULL)); 140 PetscCall(PetscOptionsInt("-matrix_format","Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &(ctx->matops), NULL)); 141 PetscCall(PetscOptionsInt("-iter","Iteration number ADMM", "ex4.c", ctx->iter, &(ctx->iter), NULL)); 142 PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &(ctx->alpha), NULL)); 143 PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &(ctx->eps), NULL)); 144 PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &(ctx->mu), NULL)); 145 PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &(ctx->hStart), NULL)); 146 PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &(ctx->hFactor), NULL)); 147 PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &(ctx->hMin), NULL)); 148 PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &(ctx->abstol), NULL)); 149 PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &(ctx->reltol), NULL)); 150 PetscCall(PetscOptionsBool("-taylor","Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &(ctx->taylor), NULL)); 151 PetscCall(PetscOptionsBool("-use_admm","Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &(ctx->use_admm), NULL)); 152 PetscCall(PetscOptionsEnum("-p","Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *) &(ctx->p), NULL)); 153 PetscOptionsEnd(); 154 /* Creating random ctx */ 155 PetscCall(PetscRandomCreate(PETSC_COMM_WORLD,&(ctx->rctx))); 156 PetscCall(PetscRandomSetFromOptions(ctx->rctx)); 157 PetscCall(CreateMatrix(ctx)); 158 PetscCall(CreateRHS(ctx)); 159 PetscCall(SetupWorkspace(ctx)); 160 PetscFunctionReturn(0); 161 } 162 163 static PetscErrorCode DestroyContext(UserCtx *ctx) 164 { 165 PetscInt i; 166 167 PetscFunctionBegin; 168 PetscCall(MatDestroy(&((*ctx)->F))); 169 PetscCall(MatDestroy(&((*ctx)->W))); 170 PetscCall(MatDestroy(&((*ctx)->Hm))); 171 PetscCall(MatDestroy(&((*ctx)->Hr))); 172 PetscCall(VecDestroy(&((*ctx)->d))); 173 for (i=0; i<NWORKLEFT; i++) { 174 PetscCall(VecDestroy(&((*ctx)->workLeft[i]))); 175 } 176 for (i=0; i<NWORKRIGHT; i++) { 177 PetscCall(VecDestroy(&((*ctx)->workRight[i]))); 178 } 179 PetscCall(PetscRandomDestroy(&((*ctx)->rctx))); 180 PetscCall(PetscFree(*ctx)); 181 PetscFunctionReturn(0); 182 } 183 184 /* compute (1/2) * ||F x - d||^2 */ 185 static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx) 186 { 187 UserCtx ctx = (UserCtx) _ctx; 188 Vec y; 189 190 PetscFunctionBegin; 191 y = ctx->workLeft[0]; 192 PetscCall(MatMult(ctx->F, x, y)); 193 PetscCall(VecAXPY(y, -1., ctx->d)); 194 PetscCall(VecDot(y, y, J)); 195 *J *= 0.5; 196 PetscFunctionReturn(0); 197 } 198 199 /* compute V = FTFx - FTd */ 200 static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx) 201 { 202 UserCtx ctx = (UserCtx) _ctx; 203 Vec FTFx, FTd; 204 205 PetscFunctionBegin; 206 /* work1 is A^T Ax, work2 is Ab, W is A^T A*/ 207 FTFx = ctx->workRight[0]; 208 FTd = ctx->workRight[1]; 209 PetscCall(MatMult(ctx->W,x,FTFx)); 210 PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd)); 211 PetscCall(VecWAXPY(V, -1., FTd, FTFx)); 212 PetscFunctionReturn(0); 213 } 214 215 /* returns FTF */ 216 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 217 { 218 UserCtx ctx = (UserCtx) _ctx; 219 220 PetscFunctionBegin; 221 if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN)); 222 if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN)); 223 PetscFunctionReturn(0); 224 } 225 226 /* computes augment Lagrangian objective (with scaled dual): 227 * 0.5 * ||F x - d||^2 + 0.5 * mu ||x - z + u||^2 */ 228 static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx) 229 { 230 UserCtx ctx = (UserCtx) _ctx; 231 PetscReal mu, workNorm, misfit; 232 Vec z, u, temp; 233 234 PetscFunctionBegin; 235 mu = ctx->mu; 236 z = ctx->workRight[5]; 237 u = ctx->workRight[6]; 238 temp = ctx->workRight[10]; 239 /* misfit = f(x) */ 240 PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx)); 241 PetscCall(VecCopy(x,temp)); 242 /* temp = x - z + u */ 243 PetscCall(VecAXPBYPCZ(temp,-1.,1.,1.,z,u)); 244 /* workNorm = ||x - z + u||^2 */ 245 PetscCall(VecDot(temp, temp, &workNorm)); 246 /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */ 247 *J = misfit + 0.5 * mu * workNorm; 248 PetscFunctionReturn(0); 249 } 250 251 /* computes FTFx - FTd mu*(x - z + u) */ 252 static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx) 253 { 254 UserCtx ctx = (UserCtx) _ctx; 255 PetscReal mu; 256 Vec z, u, temp; 257 258 PetscFunctionBegin; 259 mu = ctx->mu; 260 z = ctx->workRight[5]; 261 u = ctx->workRight[6]; 262 temp = ctx->workRight[10]; 263 PetscCall(GradientMisfit(tao, x, V, _ctx)); 264 PetscCall(VecCopy(x, temp)); 265 /* temp = x - z + u */ 266 PetscCall(VecAXPBYPCZ(temp,-1.,1.,1.,z,u)); 267 /* V = FTFx - FTd mu*(x - z + u) */ 268 PetscCall(VecAXPY(V, mu, temp)); 269 PetscFunctionReturn(0); 270 } 271 272 /* returns FTF + diag(mu) */ 273 static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 274 { 275 UserCtx ctx = (UserCtx) _ctx; 276 277 PetscFunctionBegin; 278 PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN)); 279 PetscCall(MatShift(H, ctx->mu)); 280 if (Hpre != H) { 281 PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN)); 282 } 283 PetscFunctionReturn(0); 284 } 285 286 /* computes || x ||_p (mult by 0.5 in case of NORM_2) */ 287 static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx) 288 { 289 UserCtx ctx = (UserCtx) _ctx; 290 PetscReal norm; 291 292 PetscFunctionBegin; 293 *J = 0; 294 PetscCall(VecNorm (x, ctx->p, &norm)); 295 if (ctx->p == NORM_2) norm = 0.5 * norm * norm; 296 *J = ctx->alpha * norm; 297 PetscFunctionReturn(0); 298 } 299 300 /* NORM_2 Case: return x 301 * NORM_1 Case: x/(|x| + eps) 302 * Else: TODO */ 303 static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx) 304 { 305 UserCtx ctx = (UserCtx) _ctx; 306 PetscReal eps = ctx->eps; 307 308 PetscFunctionBegin; 309 if (ctx->p == NORM_2) { 310 PetscCall(VecCopy(x, V)); 311 } else if (ctx->p == NORM_1) { 312 PetscCall(VecCopy(x, ctx->workRight[1])); 313 PetscCall(VecAbs(ctx->workRight[1])); 314 PetscCall(VecShift(ctx->workRight[1], eps)); 315 PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1])); 316 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 317 PetscFunctionReturn(0); 318 } 319 320 /* NORM_2 Case: returns diag(mu) 321 * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) */ 322 static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 323 { 324 UserCtx ctx = (UserCtx) _ctx; 325 PetscReal eps = ctx->eps; 326 Vec copy1,copy2,copy3; 327 328 PetscFunctionBegin; 329 if (ctx->p == NORM_2) { 330 /* Identity matrix scaled by mu */ 331 PetscCall(MatZeroEntries(H)); 332 PetscCall(MatShift(H,ctx->mu)); 333 if (Hpre != H) { 334 PetscCall(MatZeroEntries(Hpre)); 335 PetscCall(MatShift(Hpre,ctx->mu)); 336 } 337 } else if (ctx->p == NORM_1) { 338 /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */ 339 copy1 = ctx->workRight[1]; 340 copy2 = ctx->workRight[2]; 341 copy3 = ctx->workRight[3]; 342 /* copy1 : 1/sqrt(x_i^2 + eps) */ 343 PetscCall(VecCopy(x, copy1)); 344 PetscCall(VecPow(copy1,2)); 345 PetscCall(VecShift(copy1, eps)); 346 PetscCall(VecSqrtAbs(copy1)); 347 PetscCall(VecReciprocal(copy1)); 348 /* copy2: x_i^2.*/ 349 PetscCall(VecCopy(x,copy2)); 350 PetscCall(VecPow(copy2,2)); 351 /* copy3: abs(x_i^2 + eps) */ 352 PetscCall(VecCopy(x,copy3)); 353 PetscCall(VecPow(copy3,2)); 354 PetscCall(VecShift(copy3, eps)); 355 PetscCall(VecAbs(copy3)); 356 /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */ 357 PetscCall(VecPointwiseDivide(copy2, copy2,copy3)); 358 PetscCall(VecScale(copy2, -1.)); 359 PetscCall(VecShift(copy2, 1.)); 360 PetscCall(VecAXPY(copy1,1.,copy2)); 361 PetscCall(VecScale(copy1, ctx->mu)); 362 PetscCall(MatZeroEntries(H)); 363 PetscCall(MatDiagonalSet(H, copy1,INSERT_VALUES)); 364 if (Hpre != H) { 365 PetscCall(MatZeroEntries(Hpre)); 366 PetscCall(MatDiagonalSet(Hpre, copy1,INSERT_VALUES)); 367 } 368 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 369 PetscFunctionReturn(0); 370 } 371 372 /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2 373 * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */ 374 static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx) 375 { 376 UserCtx ctx = (UserCtx) _ctx; 377 PetscReal mu, workNorm, reg; 378 Vec x, u, temp; 379 380 PetscFunctionBegin; 381 mu = ctx->mu; 382 x = ctx->workRight[4]; 383 u = ctx->workRight[6]; 384 temp = ctx->workRight[10]; 385 PetscCall(ObjectiveRegularization(tao, z, ®, _ctx)); 386 PetscCall(VecCopy(z,temp)); 387 /* temp = x + u -z */ 388 PetscCall(VecAXPBYPCZ(temp,1.,1.,-1.,x,u)); 389 /* workNorm = ||x + u - z ||^2 */ 390 PetscCall(VecDot(temp, temp, &workNorm)); 391 *J = reg + 0.5 * mu * workNorm; 392 PetscFunctionReturn(0); 393 } 394 395 /* NORM_2 Case: x - mu*(x + u - z) 396 * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z) 397 * Else: TODO */ 398 static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx) 399 { 400 UserCtx ctx = (UserCtx) _ctx; 401 PetscReal mu; 402 Vec x, u, temp; 403 404 PetscFunctionBegin; 405 mu = ctx->mu; 406 x = ctx->workRight[4]; 407 u = ctx->workRight[6]; 408 temp = ctx->workRight[10]; 409 PetscCall(GradientRegularization(tao, z, V, _ctx)); 410 PetscCall(VecCopy(z, temp)); 411 /* temp = x + u -z */ 412 PetscCall(VecAXPBYPCZ(temp,1.,1.,-1.,x,u)); 413 PetscCall(VecAXPY(V, -mu, temp)); 414 PetscFunctionReturn(0); 415 } 416 417 /* NORM_2 Case: returns diag(mu) 418 * NORM_1 Case: FTF + diag(mu) */ 419 static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 420 { 421 UserCtx ctx = (UserCtx) _ctx; 422 423 PetscFunctionBegin; 424 if (ctx->p == NORM_2) { 425 /* Identity matrix scaled by mu */ 426 PetscCall(MatZeroEntries(H)); 427 PetscCall(MatShift(H,ctx->mu)); 428 if (Hpre != H) { 429 PetscCall(MatZeroEntries(Hpre)); 430 PetscCall(MatShift(Hpre,ctx->mu)); 431 } 432 } else if (ctx->p == NORM_1) { 433 PetscCall(HessianMisfit(tao, x, H, Hpre, (void*) ctx)); 434 PetscCall(MatShift(H, ctx->mu)); 435 if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu)); 436 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 437 PetscFunctionReturn(0); 438 } 439 440 /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p 441 * NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */ 442 static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, void *ctx) 443 { 444 PetscReal Jm, Jr; 445 446 PetscFunctionBegin; 447 PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx)); 448 PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx)); 449 *J = Jm + Jr; 450 PetscFunctionReturn(0); 451 } 452 453 /* NORM_2 Case: FTFx - FTd + x 454 * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */ 455 static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, void *ctx) 456 { 457 UserCtx cntx = (UserCtx) ctx; 458 459 PetscFunctionBegin; 460 PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx)); 461 PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx)); 462 PetscCall(VecWAXPY(V,1,cntx->workRight[2],cntx->workRight[3])); 463 PetscFunctionReturn(0); 464 } 465 466 /* NORM_2 Case: diag(mu) + FTF 467 * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF */ 468 static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx) 469 { 470 Mat tempH; 471 472 PetscFunctionBegin; 473 PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH)); 474 PetscCall(HessianMisfit(tao, x, H, H, ctx)); 475 PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx)); 476 PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN)); 477 if (Hpre != H) { 478 PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN)); 479 } 480 PetscCall(MatDestroy(&tempH)); 481 PetscFunctionReturn(0); 482 } 483 484 static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x) 485 { 486 PetscInt i; 487 PetscReal u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm; 488 Tao tao1,tao2; 489 Vec xk,z,u,diff,zold,zdiff,temp; 490 PetscReal mu; 491 492 PetscFunctionBegin; 493 xk = ctx->workRight[4]; 494 z = ctx->workRight[5]; 495 u = ctx->workRight[6]; 496 diff = ctx->workRight[7]; 497 zold = ctx->workRight[8]; 498 zdiff = ctx->workRight[9]; 499 temp = ctx->workRight[11]; 500 mu = ctx->mu; 501 PetscCall(VecSet(u, 0.)); 502 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1)); 503 PetscCall(TaoSetType(tao1,TAONLS)); 504 PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void*) ctx)); 505 PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void*) ctx)); 506 PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void*) ctx)); 507 PetscCall(VecSet(xk, 0.)); 508 PetscCall(TaoSetSolution(tao1, xk)); 509 PetscCall(TaoSetOptionsPrefix(tao1, "misfit_")); 510 PetscCall(TaoSetFromOptions(tao1)); 511 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2)); 512 if (ctx->p == NORM_2) { 513 PetscCall(TaoSetType(tao2,TAONLS)); 514 PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void*) ctx)); 515 PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void*) ctx)); 516 PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void*) ctx)); 517 } 518 PetscCall(VecSet(z, 0.)); 519 PetscCall(TaoSetSolution(tao2, z)); 520 PetscCall(TaoSetOptionsPrefix(tao2, "reg_")); 521 PetscCall(TaoSetFromOptions(tao2)); 522 523 for (i=0; i<ctx->iter; i++) { 524 PetscCall(VecCopy(z,zold)); 525 PetscCall(TaoSolve(tao1)); /* Updates xk */ 526 if (ctx->p == NORM_1) { 527 PetscCall(VecWAXPY(temp,1.,xk,u)); 528 PetscCall(TaoSoftThreshold(temp,-ctx->alpha/mu,ctx->alpha/mu,z)); 529 } else { 530 PetscCall(TaoSolve(tao2)); /* Update zk */ 531 } 532 /* u = u + xk -z */ 533 PetscCall(VecAXPBYPCZ(u,1.,-1.,1.,xk,z)); 534 /* r_norm : norm(x-z) */ 535 PetscCall(VecWAXPY(diff,-1.,z,xk)); 536 PetscCall(VecNorm(diff,NORM_2,&r_norm)); 537 /* s_norm : norm(-mu(z-zold)) */ 538 PetscCall(VecWAXPY(zdiff, -1.,zold,z)); 539 PetscCall(VecNorm(zdiff,NORM_2,&s_norm)); 540 s_norm = s_norm * mu; 541 /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/ 542 PetscCall(VecNorm(xk,NORM_2,&x_norm)); 543 PetscCall(VecNorm(z,NORM_2,&z_norm)); 544 primal = PetscSqrtReal(ctx->n)*ctx->abstol + ctx->reltol*PetscMax(x_norm,z_norm); 545 /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/ 546 PetscCall(VecNorm(u,NORM_2,&u_norm)); 547 dual = PetscSqrtReal(ctx->n)*ctx->abstol + ctx->reltol*u_norm*mu; 548 PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1),"Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double) r_norm, (double) s_norm)); 549 if (r_norm < primal && s_norm < dual) break; 550 } 551 PetscCall(VecCopy(xk, x)); 552 PetscCall(TaoDestroy(&tao1)); 553 PetscCall(TaoDestroy(&tao2)); 554 PetscFunctionReturn(0); 555 } 556 557 /* Second order Taylor remainder convergence test */ 558 static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C) 559 { 560 PetscReal h,J,temp; 561 PetscInt i,j; 562 PetscInt numValues; 563 PetscReal Jx,Jxhat_comp,Jxhat_pred; 564 PetscReal *Js, *hs; 565 PetscReal gdotdx; 566 PetscReal minrate = PETSC_MAX_REAL; 567 MPI_Comm comm = PetscObjectComm((PetscObject)x); 568 Vec g, dx, xhat; 569 570 PetscFunctionBegin; 571 PetscCall(VecDuplicate(x, &g)); 572 PetscCall(VecDuplicate(x, &xhat)); 573 /* choose a perturbation direction */ 574 PetscCall(VecDuplicate(x, &dx)); 575 PetscCall(VecSetRandom(dx,ctx->rctx)); 576 /* evaluate objective at x: J(x) */ 577 PetscCall(TaoComputeObjective(tao, x, &Jx)); 578 /* evaluate gradient at x, save in vector g */ 579 PetscCall(TaoComputeGradient(tao, x, g)); 580 PetscCall(VecDot(g, dx, &gdotdx)); 581 582 for (numValues=0, h=ctx->hStart; h>=ctx->hMin; h*=ctx->hFactor) numValues++; 583 PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs)); 584 for (i=0, h=ctx->hStart; h>=ctx->hMin; h*=ctx->hFactor, i++) { 585 PetscCall(VecWAXPY(xhat, h, dx, x)); 586 PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp)); 587 /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */ 588 Jxhat_pred = Jx + h * gdotdx; 589 /* Vector to dJdm scalar? Dot?*/ 590 J = PetscAbsReal(Jxhat_comp - Jxhat_pred); 591 PetscCall(PetscPrintf (comm, "J(xhat): %g, predicted: %g, diff %g\n", (double) Jxhat_comp,(double) Jxhat_pred, (double) J)); 592 Js[i] = J; 593 hs[i] = h; 594 } 595 for (j=1; j<numValues; j++) { 596 temp = PetscLogReal(Js[j]/Js[j-1]) / PetscLogReal (hs[j]/hs[j-1]); 597 PetscCall(PetscPrintf (comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j-1, (double) temp)); 598 minrate = PetscMin(minrate, temp); 599 } 600 /* If O is not ~2, then the test is wrong */ 601 PetscCall(PetscFree2(Js, hs)); 602 *C = minrate; 603 PetscCall(VecDestroy(&dx)); 604 PetscCall(VecDestroy(&xhat)); 605 PetscCall(VecDestroy(&g)); 606 PetscFunctionReturn(0); 607 } 608 609 int main(int argc, char ** argv) 610 { 611 UserCtx ctx; 612 Tao tao; 613 Vec x; 614 Mat H; 615 616 PetscFunctionBeginUser; 617 PetscCall(PetscInitialize(&argc, &argv, NULL,help)); 618 PetscCall(PetscNew(&ctx)); 619 PetscCall(ConfigureContext(ctx)); 620 /* Define two functions that could pass as objectives to TaoSetObjective(): one 621 * for the misfit component, and one for the regularization component */ 622 /* ObjectiveMisfit() and ObjectiveRegularization() */ 623 624 /* Define a single function that calls both components adds them together: the complete objective, 625 * in the absence of a Tao implementation that handles separability */ 626 /* ObjectiveComplete() */ 627 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao)); 628 PetscCall(TaoSetType(tao,TAONM)); 629 PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void*) ctx)); 630 PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void*) ctx)); 631 PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H)); 632 PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void*) ctx)); 633 PetscCall(MatCreateVecs(ctx->F, NULL, &x)); 634 PetscCall(VecSet(x, 0.)); 635 PetscCall(TaoSetSolution(tao, x)); 636 PetscCall(TaoSetFromOptions(tao)); 637 if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx,x)); 638 else PetscCall(TaoSolve(tao)); 639 /* examine solution */ 640 PetscCall(VecViewFromOptions(x, NULL, "-view_sol")); 641 if (ctx->taylor) { 642 PetscReal rate; 643 PetscCall(TaylorTest(ctx, tao, x, &rate)); 644 } 645 PetscCall(MatDestroy(&H)); 646 PetscCall(TaoDestroy(&tao)); 647 PetscCall(VecDestroy(&x)); 648 PetscCall(DestroyContext(&ctx)); 649 PetscCall(PetscFinalize()); 650 return 0; 651 } 652 653 /*TEST 654 655 build: 656 requires: !complex 657 658 test: 659 suffix: 0 660 args: 661 662 test: 663 suffix: l1_1 664 args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1 665 666 test: 667 suffix: hessian_1 668 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls 669 670 test: 671 suffix: hessian_2 672 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls 673 674 test: 675 suffix: nm_1 676 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50 677 678 test: 679 suffix: nm_2 680 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50 681 682 test: 683 suffix: lmvm_1 684 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40 685 686 test: 687 suffix: lmvm_2 688 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15 689 690 test: 691 suffix: soft_threshold_admm_1 692 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm 693 694 test: 695 suffix: hessian_admm_1 696 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls 697 698 test: 699 suffix: hessian_admm_2 700 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls 701 702 test: 703 suffix: nm_admm_1 704 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm 705 706 test: 707 suffix: nm_admm_2 708 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7 709 710 test: 711 suffix: lmvm_admm_1 712 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm 713 714 test: 715 suffix: lmvm_admm_2 716 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm 717 718 TEST*/ 719