1 static char help[] = "Simple example to test separable objective optimizers.\n"; 2 3 #include <petsc.h> 4 #include <petsctao.h> 5 #include <petscvec.h> 6 #include <petscmath.h> 7 8 #define NWORKLEFT 4 9 #define NWORKRIGHT 12 10 11 typedef struct _UserCtx { 12 PetscInt m; /* The row dimension of F */ 13 PetscInt n; /* The column dimension of F */ 14 PetscInt matops; /* Matrix format. 0 for stencil, 1 for random */ 15 PetscInt iter; /* Number of iterations for ADMM */ 16 PetscReal hStart; /* Starting point for Taylor test */ 17 PetscReal hFactor; /* Taylor test step factor */ 18 PetscReal hMin; /* Taylor test end goal */ 19 PetscReal alpha; /* regularization constant applied to || x ||_p */ 20 PetscReal eps; /* small constant for approximating gradient of || x ||_1 */ 21 PetscReal mu; /* the augmented Lagrangian term in ADMM */ 22 PetscReal abstol; 23 PetscReal reltol; 24 Mat F; /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */ 25 Mat W; /* Workspace matrix. ATA */ 26 Mat Hm; /* Hessian Misfit*/ 27 Mat Hr; /* Hessian Reg*/ 28 Vec d; /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */ 29 Vec workLeft[NWORKLEFT]; /* Workspace for temporary vec */ 30 Vec workRight[NWORKRIGHT]; /* Workspace for temporary vec */ 31 NormType p; 32 PetscRandom rctx; 33 PetscBool taylor; /* Flag to determine whether to run Taylor test or not */ 34 PetscBool use_admm; /* Flag to determine whether to run Taylor test or not */ 35 } *UserCtx; 36 37 static PetscErrorCode CreateRHS(UserCtx ctx) 38 { 39 PetscFunctionBegin; 40 /* build the rhs d in ctx */ 41 PetscCall(VecCreate(PETSC_COMM_WORLD, &(ctx->d))); 42 PetscCall(VecSetSizes(ctx->d, PETSC_DECIDE, ctx->m)); 43 PetscCall(VecSetFromOptions(ctx->d)); 44 PetscCall(VecSetRandom(ctx->d, ctx->rctx)); 45 PetscFunctionReturn(PETSC_SUCCESS); 46 } 47 48 static PetscErrorCode CreateMatrix(UserCtx ctx) 49 { 50 PetscInt Istart, Iend, i, j, Ii, gridN, I_n, I_s, I_e, I_w; 51 #if defined(PETSC_USE_LOG) 52 PetscLogStage stage; 53 #endif 54 55 PetscFunctionBegin; 56 /* build the matrix F in ctx */ 57 PetscCall(MatCreate(PETSC_COMM_WORLD, &(ctx->F))); 58 PetscCall(MatSetSizes(ctx->F, PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n)); 59 PetscCall(MatSetType(ctx->F, MATAIJ)); /* TODO: Decide specific SetType other than dummy*/ 60 PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/ 61 PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL)); 62 PetscCall(MatSetUp(ctx->F)); 63 PetscCall(MatGetOwnershipRange(ctx->F, &Istart, &Iend)); 64 PetscCall(PetscLogStageRegister("Assembly", &stage)); 65 PetscCall(PetscLogStagePush(stage)); 66 67 /* Set matrix elements in 2-D five point stencil format. */ 68 if (!(ctx->matops)) { 69 PetscCheck(ctx->m == ctx->n, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square"); 70 gridN = (PetscInt)PetscSqrtReal((PetscReal)ctx->m); 71 PetscCheck(gridN * gridN == ctx->m, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square"); 72 for (Ii = Istart; Ii < Iend; Ii++) { 73 i = Ii / gridN; 74 j = Ii % gridN; 75 I_n = i * gridN + j + 1; 76 if (j + 1 >= gridN) I_n = -1; 77 I_s = i * gridN + j - 1; 78 if (j - 1 < 0) I_s = -1; 79 I_e = (i + 1) * gridN + j; 80 if (i + 1 >= gridN) I_e = -1; 81 I_w = (i - 1) * gridN + j; 82 if (i - 1 < 0) I_w = -1; 83 PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES)); 84 PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES)); 85 PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES)); 86 PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES)); 87 PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES)); 88 } 89 } else PetscCall(MatSetRandom(ctx->F, ctx->rctx)); 90 PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY)); 91 PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY)); 92 PetscCall(PetscLogStagePop()); 93 /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */ 94 if (!(ctx->matops)) PetscCall(MatSetOption(ctx->F, MAT_SYMMETRIC, PETSC_TRUE)); 95 PetscCall(MatTransposeMatMult(ctx->F, ctx->F, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(ctx->W))); 96 /* Setup Hessian Workspace in same shape as W */ 97 PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &(ctx->Hm))); 98 PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &(ctx->Hr))); 99 PetscFunctionReturn(PETSC_SUCCESS); 100 } 101 102 static PetscErrorCode SetupWorkspace(UserCtx ctx) 103 { 104 PetscInt i; 105 106 PetscFunctionBegin; 107 PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0])); 108 for (i = 1; i < NWORKLEFT; i++) PetscCall(VecDuplicate(ctx->workLeft[0], &(ctx->workLeft[i]))); 109 for (i = 1; i < NWORKRIGHT; i++) PetscCall(VecDuplicate(ctx->workRight[0], &(ctx->workRight[i]))); 110 PetscFunctionReturn(PETSC_SUCCESS); 111 } 112 113 static PetscErrorCode ConfigureContext(UserCtx ctx) 114 { 115 PetscFunctionBegin; 116 ctx->m = 16; 117 ctx->n = 16; 118 ctx->eps = 1.e-3; 119 ctx->abstol = 1.e-4; 120 ctx->reltol = 1.e-2; 121 ctx->hStart = 1.; 122 ctx->hMin = 1.e-3; 123 ctx->hFactor = 0.5; 124 ctx->alpha = 1.; 125 ctx->mu = 1.0; 126 ctx->matops = 0; 127 ctx->iter = 10; 128 ctx->p = NORM_2; 129 ctx->taylor = PETSC_TRUE; 130 ctx->use_admm = PETSC_FALSE; 131 PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "ex4.c"); 132 PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &(ctx->m), NULL)); 133 PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &(ctx->n), NULL)); 134 PetscCall(PetscOptionsInt("-matrix_format", "Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &(ctx->matops), NULL)); 135 PetscCall(PetscOptionsInt("-iter", "Iteration number ADMM", "ex4.c", ctx->iter, &(ctx->iter), NULL)); 136 PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &(ctx->alpha), NULL)); 137 PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &(ctx->eps), NULL)); 138 PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &(ctx->mu), NULL)); 139 PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &(ctx->hStart), NULL)); 140 PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &(ctx->hFactor), NULL)); 141 PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &(ctx->hMin), NULL)); 142 PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &(ctx->abstol), NULL)); 143 PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &(ctx->reltol), NULL)); 144 PetscCall(PetscOptionsBool("-taylor", "Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &(ctx->taylor), NULL)); 145 PetscCall(PetscOptionsBool("-use_admm", "Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &(ctx->use_admm), NULL)); 146 PetscCall(PetscOptionsEnum("-p", "Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *)&(ctx->p), NULL)); 147 PetscOptionsEnd(); 148 /* Creating random ctx */ 149 PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &(ctx->rctx))); 150 PetscCall(PetscRandomSetFromOptions(ctx->rctx)); 151 PetscCall(CreateMatrix(ctx)); 152 PetscCall(CreateRHS(ctx)); 153 PetscCall(SetupWorkspace(ctx)); 154 PetscFunctionReturn(PETSC_SUCCESS); 155 } 156 157 static PetscErrorCode DestroyContext(UserCtx *ctx) 158 { 159 PetscInt i; 160 161 PetscFunctionBegin; 162 PetscCall(MatDestroy(&((*ctx)->F))); 163 PetscCall(MatDestroy(&((*ctx)->W))); 164 PetscCall(MatDestroy(&((*ctx)->Hm))); 165 PetscCall(MatDestroy(&((*ctx)->Hr))); 166 PetscCall(VecDestroy(&((*ctx)->d))); 167 for (i = 0; i < NWORKLEFT; i++) PetscCall(VecDestroy(&((*ctx)->workLeft[i]))); 168 for (i = 0; i < NWORKRIGHT; i++) PetscCall(VecDestroy(&((*ctx)->workRight[i]))); 169 PetscCall(PetscRandomDestroy(&((*ctx)->rctx))); 170 PetscCall(PetscFree(*ctx)); 171 PetscFunctionReturn(PETSC_SUCCESS); 172 } 173 174 /* compute (1/2) * ||F x - d||^2 */ 175 static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx) 176 { 177 UserCtx ctx = (UserCtx)_ctx; 178 Vec y; 179 180 PetscFunctionBegin; 181 y = ctx->workLeft[0]; 182 PetscCall(MatMult(ctx->F, x, y)); 183 PetscCall(VecAXPY(y, -1., ctx->d)); 184 PetscCall(VecDot(y, y, J)); 185 *J *= 0.5; 186 PetscFunctionReturn(PETSC_SUCCESS); 187 } 188 189 /* compute V = FTFx - FTd */ 190 static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx) 191 { 192 UserCtx ctx = (UserCtx)_ctx; 193 Vec FTFx, FTd; 194 195 PetscFunctionBegin; 196 /* work1 is A^T Ax, work2 is Ab, W is A^T A*/ 197 FTFx = ctx->workRight[0]; 198 FTd = ctx->workRight[1]; 199 PetscCall(MatMult(ctx->W, x, FTFx)); 200 PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd)); 201 PetscCall(VecWAXPY(V, -1., FTd, FTFx)); 202 PetscFunctionReturn(PETSC_SUCCESS); 203 } 204 205 /* returns FTF */ 206 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 207 { 208 UserCtx ctx = (UserCtx)_ctx; 209 210 PetscFunctionBegin; 211 if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN)); 212 if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN)); 213 PetscFunctionReturn(PETSC_SUCCESS); 214 } 215 216 /* computes augment Lagrangian objective (with scaled dual): 217 * 0.5 * ||F x - d||^2 + 0.5 * mu ||x - z + u||^2 */ 218 static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx) 219 { 220 UserCtx ctx = (UserCtx)_ctx; 221 PetscReal mu, workNorm, misfit; 222 Vec z, u, temp; 223 224 PetscFunctionBegin; 225 mu = ctx->mu; 226 z = ctx->workRight[5]; 227 u = ctx->workRight[6]; 228 temp = ctx->workRight[10]; 229 /* misfit = f(x) */ 230 PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx)); 231 PetscCall(VecCopy(x, temp)); 232 /* temp = x - z + u */ 233 PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u)); 234 /* workNorm = ||x - z + u||^2 */ 235 PetscCall(VecDot(temp, temp, &workNorm)); 236 /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */ 237 *J = misfit + 0.5 * mu * workNorm; 238 PetscFunctionReturn(PETSC_SUCCESS); 239 } 240 241 /* computes FTFx - FTd mu*(x - z + u) */ 242 static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx) 243 { 244 UserCtx ctx = (UserCtx)_ctx; 245 PetscReal mu; 246 Vec z, u, temp; 247 248 PetscFunctionBegin; 249 mu = ctx->mu; 250 z = ctx->workRight[5]; 251 u = ctx->workRight[6]; 252 temp = ctx->workRight[10]; 253 PetscCall(GradientMisfit(tao, x, V, _ctx)); 254 PetscCall(VecCopy(x, temp)); 255 /* temp = x - z + u */ 256 PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u)); 257 /* V = FTFx - FTd mu*(x - z + u) */ 258 PetscCall(VecAXPY(V, mu, temp)); 259 PetscFunctionReturn(PETSC_SUCCESS); 260 } 261 262 /* returns FTF + diag(mu) */ 263 static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 264 { 265 UserCtx ctx = (UserCtx)_ctx; 266 267 PetscFunctionBegin; 268 PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN)); 269 PetscCall(MatShift(H, ctx->mu)); 270 if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN)); 271 PetscFunctionReturn(PETSC_SUCCESS); 272 } 273 274 /* computes || x ||_p (mult by 0.5 in case of NORM_2) */ 275 static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx) 276 { 277 UserCtx ctx = (UserCtx)_ctx; 278 PetscReal norm; 279 280 PetscFunctionBegin; 281 *J = 0; 282 PetscCall(VecNorm(x, ctx->p, &norm)); 283 if (ctx->p == NORM_2) norm = 0.5 * norm * norm; 284 *J = ctx->alpha * norm; 285 PetscFunctionReturn(PETSC_SUCCESS); 286 } 287 288 /* NORM_2 Case: return x 289 * NORM_1 Case: x/(|x| + eps) 290 * Else: TODO */ 291 static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx) 292 { 293 UserCtx ctx = (UserCtx)_ctx; 294 PetscReal eps = ctx->eps; 295 296 PetscFunctionBegin; 297 if (ctx->p == NORM_2) { 298 PetscCall(VecCopy(x, V)); 299 } else if (ctx->p == NORM_1) { 300 PetscCall(VecCopy(x, ctx->workRight[1])); 301 PetscCall(VecAbs(ctx->workRight[1])); 302 PetscCall(VecShift(ctx->workRight[1], eps)); 303 PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1])); 304 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 305 PetscFunctionReturn(PETSC_SUCCESS); 306 } 307 308 /* NORM_2 Case: returns diag(mu) 309 * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) */ 310 static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 311 { 312 UserCtx ctx = (UserCtx)_ctx; 313 PetscReal eps = ctx->eps; 314 Vec copy1, copy2, copy3; 315 316 PetscFunctionBegin; 317 if (ctx->p == NORM_2) { 318 /* Identity matrix scaled by mu */ 319 PetscCall(MatZeroEntries(H)); 320 PetscCall(MatShift(H, ctx->mu)); 321 if (Hpre != H) { 322 PetscCall(MatZeroEntries(Hpre)); 323 PetscCall(MatShift(Hpre, ctx->mu)); 324 } 325 } else if (ctx->p == NORM_1) { 326 /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */ 327 copy1 = ctx->workRight[1]; 328 copy2 = ctx->workRight[2]; 329 copy3 = ctx->workRight[3]; 330 /* copy1 : 1/sqrt(x_i^2 + eps) */ 331 PetscCall(VecCopy(x, copy1)); 332 PetscCall(VecPow(copy1, 2)); 333 PetscCall(VecShift(copy1, eps)); 334 PetscCall(VecSqrtAbs(copy1)); 335 PetscCall(VecReciprocal(copy1)); 336 /* copy2: x_i^2.*/ 337 PetscCall(VecCopy(x, copy2)); 338 PetscCall(VecPow(copy2, 2)); 339 /* copy3: abs(x_i^2 + eps) */ 340 PetscCall(VecCopy(x, copy3)); 341 PetscCall(VecPow(copy3, 2)); 342 PetscCall(VecShift(copy3, eps)); 343 PetscCall(VecAbs(copy3)); 344 /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */ 345 PetscCall(VecPointwiseDivide(copy2, copy2, copy3)); 346 PetscCall(VecScale(copy2, -1.)); 347 PetscCall(VecShift(copy2, 1.)); 348 PetscCall(VecAXPY(copy1, 1., copy2)); 349 PetscCall(VecScale(copy1, ctx->mu)); 350 PetscCall(MatZeroEntries(H)); 351 PetscCall(MatDiagonalSet(H, copy1, INSERT_VALUES)); 352 if (Hpre != H) { 353 PetscCall(MatZeroEntries(Hpre)); 354 PetscCall(MatDiagonalSet(Hpre, copy1, INSERT_VALUES)); 355 } 356 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 357 PetscFunctionReturn(PETSC_SUCCESS); 358 } 359 360 /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2 361 * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */ 362 static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx) 363 { 364 UserCtx ctx = (UserCtx)_ctx; 365 PetscReal mu, workNorm, reg; 366 Vec x, u, temp; 367 368 PetscFunctionBegin; 369 mu = ctx->mu; 370 x = ctx->workRight[4]; 371 u = ctx->workRight[6]; 372 temp = ctx->workRight[10]; 373 PetscCall(ObjectiveRegularization(tao, z, ®, _ctx)); 374 PetscCall(VecCopy(z, temp)); 375 /* temp = x + u -z */ 376 PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u)); 377 /* workNorm = ||x + u - z ||^2 */ 378 PetscCall(VecDot(temp, temp, &workNorm)); 379 *J = reg + 0.5 * mu * workNorm; 380 PetscFunctionReturn(PETSC_SUCCESS); 381 } 382 383 /* NORM_2 Case: x - mu*(x + u - z) 384 * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z) 385 * Else: TODO */ 386 static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx) 387 { 388 UserCtx ctx = (UserCtx)_ctx; 389 PetscReal mu; 390 Vec x, u, temp; 391 392 PetscFunctionBegin; 393 mu = ctx->mu; 394 x = ctx->workRight[4]; 395 u = ctx->workRight[6]; 396 temp = ctx->workRight[10]; 397 PetscCall(GradientRegularization(tao, z, V, _ctx)); 398 PetscCall(VecCopy(z, temp)); 399 /* temp = x + u -z */ 400 PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u)); 401 PetscCall(VecAXPY(V, -mu, temp)); 402 PetscFunctionReturn(PETSC_SUCCESS); 403 } 404 405 /* NORM_2 Case: returns diag(mu) 406 * NORM_1 Case: FTF + diag(mu) */ 407 static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 408 { 409 UserCtx ctx = (UserCtx)_ctx; 410 411 PetscFunctionBegin; 412 if (ctx->p == NORM_2) { 413 /* Identity matrix scaled by mu */ 414 PetscCall(MatZeroEntries(H)); 415 PetscCall(MatShift(H, ctx->mu)); 416 if (Hpre != H) { 417 PetscCall(MatZeroEntries(Hpre)); 418 PetscCall(MatShift(Hpre, ctx->mu)); 419 } 420 } else if (ctx->p == NORM_1) { 421 PetscCall(HessianMisfit(tao, x, H, Hpre, (void *)ctx)); 422 PetscCall(MatShift(H, ctx->mu)); 423 if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu)); 424 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 425 PetscFunctionReturn(PETSC_SUCCESS); 426 } 427 428 /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p 429 * NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */ 430 static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, void *ctx) 431 { 432 PetscReal Jm, Jr; 433 434 PetscFunctionBegin; 435 PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx)); 436 PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx)); 437 *J = Jm + Jr; 438 PetscFunctionReturn(PETSC_SUCCESS); 439 } 440 441 /* NORM_2 Case: FTFx - FTd + x 442 * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */ 443 static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, void *ctx) 444 { 445 UserCtx cntx = (UserCtx)ctx; 446 447 PetscFunctionBegin; 448 PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx)); 449 PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx)); 450 PetscCall(VecWAXPY(V, 1, cntx->workRight[2], cntx->workRight[3])); 451 PetscFunctionReturn(PETSC_SUCCESS); 452 } 453 454 /* NORM_2 Case: diag(mu) + FTF 455 * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF */ 456 static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx) 457 { 458 Mat tempH; 459 460 PetscFunctionBegin; 461 PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH)); 462 PetscCall(HessianMisfit(tao, x, H, H, ctx)); 463 PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx)); 464 PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN)); 465 if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN)); 466 PetscCall(MatDestroy(&tempH)); 467 PetscFunctionReturn(PETSC_SUCCESS); 468 } 469 470 static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x) 471 { 472 PetscInt i; 473 PetscReal u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm; 474 Tao tao1, tao2; 475 Vec xk, z, u, diff, zold, zdiff, temp; 476 PetscReal mu; 477 478 PetscFunctionBegin; 479 xk = ctx->workRight[4]; 480 z = ctx->workRight[5]; 481 u = ctx->workRight[6]; 482 diff = ctx->workRight[7]; 483 zold = ctx->workRight[8]; 484 zdiff = ctx->workRight[9]; 485 temp = ctx->workRight[11]; 486 mu = ctx->mu; 487 PetscCall(VecSet(u, 0.)); 488 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1)); 489 PetscCall(TaoSetType(tao1, TAONLS)); 490 PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void *)ctx)); 491 PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void *)ctx)); 492 PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void *)ctx)); 493 PetscCall(VecSet(xk, 0.)); 494 PetscCall(TaoSetSolution(tao1, xk)); 495 PetscCall(TaoSetOptionsPrefix(tao1, "misfit_")); 496 PetscCall(TaoSetFromOptions(tao1)); 497 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2)); 498 if (ctx->p == NORM_2) { 499 PetscCall(TaoSetType(tao2, TAONLS)); 500 PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void *)ctx)); 501 PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void *)ctx)); 502 PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void *)ctx)); 503 } 504 PetscCall(VecSet(z, 0.)); 505 PetscCall(TaoSetSolution(tao2, z)); 506 PetscCall(TaoSetOptionsPrefix(tao2, "reg_")); 507 PetscCall(TaoSetFromOptions(tao2)); 508 509 for (i = 0; i < ctx->iter; i++) { 510 PetscCall(VecCopy(z, zold)); 511 PetscCall(TaoSolve(tao1)); /* Updates xk */ 512 if (ctx->p == NORM_1) { 513 PetscCall(VecWAXPY(temp, 1., xk, u)); 514 PetscCall(TaoSoftThreshold(temp, -ctx->alpha / mu, ctx->alpha / mu, z)); 515 } else { 516 PetscCall(TaoSolve(tao2)); /* Update zk */ 517 } 518 /* u = u + xk -z */ 519 PetscCall(VecAXPBYPCZ(u, 1., -1., 1., xk, z)); 520 /* r_norm : norm(x-z) */ 521 PetscCall(VecWAXPY(diff, -1., z, xk)); 522 PetscCall(VecNorm(diff, NORM_2, &r_norm)); 523 /* s_norm : norm(-mu(z-zold)) */ 524 PetscCall(VecWAXPY(zdiff, -1., zold, z)); 525 PetscCall(VecNorm(zdiff, NORM_2, &s_norm)); 526 s_norm = s_norm * mu; 527 /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/ 528 PetscCall(VecNorm(xk, NORM_2, &x_norm)); 529 PetscCall(VecNorm(z, NORM_2, &z_norm)); 530 primal = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * PetscMax(x_norm, z_norm); 531 /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/ 532 PetscCall(VecNorm(u, NORM_2, &u_norm)); 533 dual = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * u_norm * mu; 534 PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1), "Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double)r_norm, (double)s_norm)); 535 if (r_norm < primal && s_norm < dual) break; 536 } 537 PetscCall(VecCopy(xk, x)); 538 PetscCall(TaoDestroy(&tao1)); 539 PetscCall(TaoDestroy(&tao2)); 540 PetscFunctionReturn(PETSC_SUCCESS); 541 } 542 543 /* Second order Taylor remainder convergence test */ 544 static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C) 545 { 546 PetscReal h, J, temp; 547 PetscInt i, j; 548 PetscInt numValues; 549 PetscReal Jx, Jxhat_comp, Jxhat_pred; 550 PetscReal *Js, *hs; 551 PetscReal gdotdx; 552 PetscReal minrate = PETSC_MAX_REAL; 553 MPI_Comm comm = PetscObjectComm((PetscObject)x); 554 Vec g, dx, xhat; 555 556 PetscFunctionBegin; 557 PetscCall(VecDuplicate(x, &g)); 558 PetscCall(VecDuplicate(x, &xhat)); 559 /* choose a perturbation direction */ 560 PetscCall(VecDuplicate(x, &dx)); 561 PetscCall(VecSetRandom(dx, ctx->rctx)); 562 /* evaluate objective at x: J(x) */ 563 PetscCall(TaoComputeObjective(tao, x, &Jx)); 564 /* evaluate gradient at x, save in vector g */ 565 PetscCall(TaoComputeGradient(tao, x, g)); 566 PetscCall(VecDot(g, dx, &gdotdx)); 567 568 for (numValues = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor) numValues++; 569 PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs)); 570 for (i = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor, i++) { 571 PetscCall(VecWAXPY(xhat, h, dx, x)); 572 PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp)); 573 /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */ 574 Jxhat_pred = Jx + h * gdotdx; 575 /* Vector to dJdm scalar? Dot?*/ 576 J = PetscAbsReal(Jxhat_comp - Jxhat_pred); 577 PetscCall(PetscPrintf(comm, "J(xhat): %g, predicted: %g, diff %g\n", (double)Jxhat_comp, (double)Jxhat_pred, (double)J)); 578 Js[i] = J; 579 hs[i] = h; 580 } 581 for (j = 1; j < numValues; j++) { 582 temp = PetscLogReal(Js[j] / Js[j - 1]) / PetscLogReal(hs[j] / hs[j - 1]); 583 PetscCall(PetscPrintf(comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j - 1, (double)temp)); 584 minrate = PetscMin(minrate, temp); 585 } 586 /* If O is not ~2, then the test is wrong */ 587 PetscCall(PetscFree2(Js, hs)); 588 *C = minrate; 589 PetscCall(VecDestroy(&dx)); 590 PetscCall(VecDestroy(&xhat)); 591 PetscCall(VecDestroy(&g)); 592 PetscFunctionReturn(PETSC_SUCCESS); 593 } 594 595 int main(int argc, char **argv) 596 { 597 UserCtx ctx; 598 Tao tao; 599 Vec x; 600 Mat H; 601 602 PetscFunctionBeginUser; 603 PetscCall(PetscInitialize(&argc, &argv, NULL, help)); 604 PetscCall(PetscNew(&ctx)); 605 PetscCall(ConfigureContext(ctx)); 606 /* Define two functions that could pass as objectives to TaoSetObjective(): one 607 * for the misfit component, and one for the regularization component */ 608 /* ObjectiveMisfit() and ObjectiveRegularization() */ 609 610 /* Define a single function that calls both components adds them together: the complete objective, 611 * in the absence of a Tao implementation that handles separability */ 612 /* ObjectiveComplete() */ 613 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao)); 614 PetscCall(TaoSetType(tao, TAONM)); 615 PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void *)ctx)); 616 PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void *)ctx)); 617 PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H)); 618 PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void *)ctx)); 619 PetscCall(MatCreateVecs(ctx->F, NULL, &x)); 620 PetscCall(VecSet(x, 0.)); 621 PetscCall(TaoSetSolution(tao, x)); 622 PetscCall(TaoSetFromOptions(tao)); 623 if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx, x)); 624 else PetscCall(TaoSolve(tao)); 625 /* examine solution */ 626 PetscCall(VecViewFromOptions(x, NULL, "-view_sol")); 627 if (ctx->taylor) { 628 PetscReal rate; 629 PetscCall(TaylorTest(ctx, tao, x, &rate)); 630 } 631 PetscCall(MatDestroy(&H)); 632 PetscCall(TaoDestroy(&tao)); 633 PetscCall(VecDestroy(&x)); 634 PetscCall(DestroyContext(&ctx)); 635 PetscCall(PetscFinalize()); 636 return 0; 637 } 638 639 /*TEST 640 641 build: 642 requires: !complex 643 644 test: 645 suffix: 0 646 args: 647 648 test: 649 suffix: l1_1 650 args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1 651 652 test: 653 suffix: hessian_1 654 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls 655 656 test: 657 suffix: hessian_2 658 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls 659 660 test: 661 suffix: nm_1 662 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50 663 664 test: 665 suffix: nm_2 666 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50 667 668 test: 669 suffix: lmvm_1 670 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40 671 672 test: 673 suffix: lmvm_2 674 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15 675 676 test: 677 suffix: soft_threshold_admm_1 678 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm 679 680 test: 681 suffix: hessian_admm_1 682 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls 683 684 test: 685 suffix: hessian_admm_2 686 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls 687 688 test: 689 suffix: nm_admm_1 690 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm 691 692 test: 693 suffix: nm_admm_2 694 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7 695 696 test: 697 suffix: lmvm_admm_1 698 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm 699 700 test: 701 suffix: lmvm_admm_2 702 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm 703 704 TEST*/ 705