1 static char help[] = "Simple example to test separable objective optimizers.\n"; 2 3 #include <petsc.h> 4 #include <petsctao.h> 5 #include <petscvec.h> 6 #include <petscmath.h> 7 8 #define NWORKLEFT 4 9 #define NWORKRIGHT 12 10 11 typedef struct _UserCtx { 12 PetscInt m; /* The row dimension of F */ 13 PetscInt n; /* The column dimension of F */ 14 PetscInt matops; /* Matrix format. 0 for stencil, 1 for random */ 15 PetscInt iter; /* Number of iterations for ADMM */ 16 PetscReal hStart; /* Starting point for Taylor test */ 17 PetscReal hFactor; /* Taylor test step factor */ 18 PetscReal hMin; /* Taylor test end goal */ 19 PetscReal alpha; /* regularization constant applied to || x ||_p */ 20 PetscReal eps; /* small constant for approximating gradient of || x ||_1 */ 21 PetscReal mu; /* the augmented Lagrangian term in ADMM */ 22 PetscReal abstol; 23 PetscReal reltol; 24 Mat F; /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */ 25 Mat W; /* Workspace matrix. ATA */ 26 Mat Hm; /* Hessian Misfit*/ 27 Mat Hr; /* Hessian Reg*/ 28 Vec d; /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */ 29 Vec workLeft[NWORKLEFT]; /* Workspace for temporary vec */ 30 Vec workRight[NWORKRIGHT]; /* Workspace for temporary vec */ 31 NormType p; 32 PetscRandom rctx; 33 PetscBool soft; 34 PetscBool taylor; /* Flag to determine whether to run Taylor test or not */ 35 PetscBool use_admm; /* Flag to determine whether to run Taylor test or not */ 36 } *UserCtx; 37 38 static PetscErrorCode CreateRHS(UserCtx ctx) 39 { 40 PetscFunctionBegin; 41 /* build the rhs d in ctx */ 42 PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->d)); 43 PetscCall(VecSetSizes(ctx->d, PETSC_DECIDE, ctx->m)); 44 PetscCall(VecSetFromOptions(ctx->d)); 45 PetscCall(VecSetRandom(ctx->d, ctx->rctx)); 46 PetscFunctionReturn(PETSC_SUCCESS); 47 } 48 49 static PetscErrorCode CreateMatrix(UserCtx ctx) 50 { 51 PetscInt Istart, Iend, i, j, Ii, gridN, I_n, I_s, I_e, I_w; 52 PetscLogStage stage; 53 54 PetscFunctionBegin; 55 /* build the matrix F in ctx */ 56 PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->F)); 57 PetscCall(MatSetSizes(ctx->F, PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n)); 58 PetscCall(MatSetType(ctx->F, MATAIJ)); /* TODO: Decide specific SetType other than dummy*/ 59 PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/ 60 PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL)); 61 PetscCall(MatSetUp(ctx->F)); 62 PetscCall(MatGetOwnershipRange(ctx->F, &Istart, &Iend)); 63 PetscCall(PetscLogStageRegister("Assembly", &stage)); 64 PetscCall(PetscLogStagePush(stage)); 65 66 /* Set matrix elements in 2-D five point stencil format. */ 67 if (!ctx->matops) { 68 PetscCheck(ctx->m == ctx->n, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square"); 69 gridN = (PetscInt)PetscSqrtReal((PetscReal)ctx->m); 70 PetscCheck(gridN * gridN == ctx->m, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square"); 71 for (Ii = Istart; Ii < Iend; Ii++) { 72 i = Ii / gridN; 73 j = Ii % gridN; 74 I_n = i * gridN + j + 1; 75 if (j + 1 >= gridN) I_n = -1; 76 I_s = i * gridN + j - 1; 77 if (j - 1 < 0) I_s = -1; 78 I_e = (i + 1) * gridN + j; 79 if (i + 1 >= gridN) I_e = -1; 80 I_w = (i - 1) * gridN + j; 81 if (i - 1 < 0) I_w = -1; 82 PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES)); 83 PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES)); 84 PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES)); 85 PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES)); 86 PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES)); 87 } 88 } else PetscCall(MatSetRandom(ctx->F, ctx->rctx)); 89 PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY)); 90 PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY)); 91 PetscCall(PetscLogStagePop()); 92 /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */ 93 if (!ctx->matops) PetscCall(MatSetOption(ctx->F, MAT_SYMMETRIC, PETSC_TRUE)); 94 PetscCall(MatTransposeMatMult(ctx->F, ctx->F, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &ctx->W)); 95 /* Setup Hessian Workspace in same shape as W */ 96 PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hm)); 97 PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hr)); 98 PetscFunctionReturn(PETSC_SUCCESS); 99 } 100 101 static PetscErrorCode SetupWorkspace(UserCtx ctx) 102 { 103 PetscInt i; 104 105 PetscFunctionBegin; 106 PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0])); 107 for (i = 1; i < NWORKLEFT; i++) PetscCall(VecDuplicate(ctx->workLeft[0], &ctx->workLeft[i])); 108 for (i = 1; i < NWORKRIGHT; i++) PetscCall(VecDuplicate(ctx->workRight[0], &ctx->workRight[i])); 109 PetscFunctionReturn(PETSC_SUCCESS); 110 } 111 112 static PetscErrorCode ConfigureContext(UserCtx ctx) 113 { 114 PetscFunctionBegin; 115 ctx->m = 16; 116 ctx->n = 16; 117 ctx->eps = 1.e-3; 118 ctx->abstol = 1.e-4; 119 ctx->reltol = 1.e-2; 120 ctx->hStart = 1.; 121 ctx->hMin = 1.e-3; 122 ctx->hFactor = 0.5; 123 ctx->alpha = 1.; 124 ctx->mu = 1.0; 125 ctx->matops = 0; 126 ctx->iter = 10; 127 ctx->p = NORM_2; 128 ctx->soft = PETSC_FALSE; 129 ctx->taylor = PETSC_TRUE; 130 ctx->use_admm = PETSC_FALSE; 131 PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "ex4.c"); 132 PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &ctx->m, NULL)); 133 PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &ctx->n, NULL)); 134 PetscCall(PetscOptionsInt("-matrix_format", "Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &ctx->matops, NULL)); 135 PetscCall(PetscOptionsInt("-iter", "Iteration number ADMM", "ex4.c", ctx->iter, &ctx->iter, NULL)); 136 PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &ctx->alpha, NULL)); 137 PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &ctx->eps, NULL)); 138 PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &ctx->mu, NULL)); 139 PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &ctx->hStart, NULL)); 140 PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &ctx->hFactor, NULL)); 141 PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &ctx->hMin, NULL)); 142 PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &ctx->abstol, NULL)); 143 PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &ctx->reltol, NULL)); 144 PetscCall(PetscOptionsBool("-taylor", "Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &ctx->taylor, NULL)); 145 PetscCall(PetscOptionsBool("-soft", "Flag for testing soft threshold no-op case. Default is false.", "ex4.c", ctx->soft, &ctx->soft, NULL)); 146 PetscCall(PetscOptionsBool("-use_admm", "Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &ctx->use_admm, NULL)); 147 PetscCall(PetscOptionsEnum("-p", "Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *)&ctx->p, NULL)); 148 PetscOptionsEnd(); 149 /* Creating random ctx */ 150 PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &ctx->rctx)); 151 PetscCall(PetscRandomSetFromOptions(ctx->rctx)); 152 PetscCall(CreateMatrix(ctx)); 153 PetscCall(CreateRHS(ctx)); 154 PetscCall(SetupWorkspace(ctx)); 155 PetscFunctionReturn(PETSC_SUCCESS); 156 } 157 158 static PetscErrorCode DestroyContext(UserCtx *ctx) 159 { 160 PetscInt i; 161 162 PetscFunctionBegin; 163 PetscCall(MatDestroy(&(*ctx)->F)); 164 PetscCall(MatDestroy(&(*ctx)->W)); 165 PetscCall(MatDestroy(&(*ctx)->Hm)); 166 PetscCall(MatDestroy(&(*ctx)->Hr)); 167 PetscCall(VecDestroy(&(*ctx)->d)); 168 for (i = 0; i < NWORKLEFT; i++) PetscCall(VecDestroy(&(*ctx)->workLeft[i])); 169 for (i = 0; i < NWORKRIGHT; i++) PetscCall(VecDestroy(&(*ctx)->workRight[i])); 170 PetscCall(PetscRandomDestroy(&(*ctx)->rctx)); 171 PetscCall(PetscFree(*ctx)); 172 PetscFunctionReturn(PETSC_SUCCESS); 173 } 174 175 /* compute (1/2) * ||F x - d||^2 */ 176 static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx) 177 { 178 UserCtx ctx = (UserCtx)_ctx; 179 Vec y; 180 181 PetscFunctionBegin; 182 y = ctx->workLeft[0]; 183 PetscCall(MatMult(ctx->F, x, y)); 184 PetscCall(VecAXPY(y, -1., ctx->d)); 185 PetscCall(VecDot(y, y, J)); 186 *J *= 0.5; 187 PetscFunctionReturn(PETSC_SUCCESS); 188 } 189 190 /* compute V = FTFx - FTd */ 191 static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx) 192 { 193 UserCtx ctx = (UserCtx)_ctx; 194 Vec FTFx, FTd; 195 196 PetscFunctionBegin; 197 /* work1 is A^T Ax, work2 is Ab, W is A^T A*/ 198 FTFx = ctx->workRight[0]; 199 FTd = ctx->workRight[1]; 200 PetscCall(MatMult(ctx->W, x, FTFx)); 201 PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd)); 202 PetscCall(VecWAXPY(V, -1., FTd, FTFx)); 203 PetscFunctionReturn(PETSC_SUCCESS); 204 } 205 206 /* returns FTF */ 207 static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 208 { 209 UserCtx ctx = (UserCtx)_ctx; 210 211 PetscFunctionBegin; 212 if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN)); 213 if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN)); 214 PetscFunctionReturn(PETSC_SUCCESS); 215 } 216 217 /* computes augment Lagrangian objective (with scaled dual): 218 * 0.5 * ||F x - d||^2 + 0.5 * mu ||x - z + u||^2 */ 219 static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx) 220 { 221 UserCtx ctx = (UserCtx)_ctx; 222 PetscReal mu, workNorm, misfit; 223 Vec z, u, temp; 224 225 PetscFunctionBegin; 226 mu = ctx->mu; 227 z = ctx->workRight[5]; 228 u = ctx->workRight[6]; 229 temp = ctx->workRight[10]; 230 /* misfit = f(x) */ 231 PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx)); 232 PetscCall(VecCopy(x, temp)); 233 /* temp = x - z + u */ 234 PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u)); 235 /* workNorm = ||x - z + u||^2 */ 236 PetscCall(VecDot(temp, temp, &workNorm)); 237 /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */ 238 *J = misfit + 0.5 * mu * workNorm; 239 PetscFunctionReturn(PETSC_SUCCESS); 240 } 241 242 /* computes FTFx - FTd mu*(x - z + u) */ 243 static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx) 244 { 245 UserCtx ctx = (UserCtx)_ctx; 246 PetscReal mu; 247 Vec z, u, temp; 248 249 PetscFunctionBegin; 250 mu = ctx->mu; 251 z = ctx->workRight[5]; 252 u = ctx->workRight[6]; 253 temp = ctx->workRight[10]; 254 PetscCall(GradientMisfit(tao, x, V, _ctx)); 255 PetscCall(VecCopy(x, temp)); 256 /* temp = x - z + u */ 257 PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u)); 258 /* V = FTFx - FTd mu*(x - z + u) */ 259 PetscCall(VecAXPY(V, mu, temp)); 260 PetscFunctionReturn(PETSC_SUCCESS); 261 } 262 263 /* returns FTF + diag(mu) */ 264 static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 265 { 266 UserCtx ctx = (UserCtx)_ctx; 267 268 PetscFunctionBegin; 269 PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN)); 270 PetscCall(MatShift(H, ctx->mu)); 271 if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN)); 272 PetscFunctionReturn(PETSC_SUCCESS); 273 } 274 275 /* computes || x ||_p (mult by 0.5 in case of NORM_2) */ 276 static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx) 277 { 278 UserCtx ctx = (UserCtx)_ctx; 279 PetscReal norm; 280 281 PetscFunctionBegin; 282 *J = 0; 283 PetscCall(VecNorm(x, ctx->p, &norm)); 284 if (ctx->p == NORM_2) norm = 0.5 * norm * norm; 285 *J = ctx->alpha * norm; 286 PetscFunctionReturn(PETSC_SUCCESS); 287 } 288 289 /* NORM_2 Case: return x 290 * NORM_1 Case: x/(|x| + eps) 291 * Else: TODO */ 292 static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx) 293 { 294 UserCtx ctx = (UserCtx)_ctx; 295 PetscReal eps = ctx->eps; 296 297 PetscFunctionBegin; 298 if (ctx->p == NORM_2) { 299 PetscCall(VecCopy(x, V)); 300 } else if (ctx->p == NORM_1) { 301 PetscCall(VecCopy(x, ctx->workRight[1])); 302 PetscCall(VecAbs(ctx->workRight[1])); 303 PetscCall(VecShift(ctx->workRight[1], eps)); 304 PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1])); 305 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 306 PetscFunctionReturn(PETSC_SUCCESS); 307 } 308 309 /* NORM_2 Case: returns diag(mu) 310 * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) */ 311 static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 312 { 313 UserCtx ctx = (UserCtx)_ctx; 314 PetscReal eps = ctx->eps; 315 Vec copy1, copy2, copy3; 316 317 PetscFunctionBegin; 318 if (ctx->p == NORM_2) { 319 /* Identity matrix scaled by mu */ 320 PetscCall(MatZeroEntries(H)); 321 PetscCall(MatShift(H, ctx->mu)); 322 if (Hpre != H) { 323 PetscCall(MatZeroEntries(Hpre)); 324 PetscCall(MatShift(Hpre, ctx->mu)); 325 } 326 } else if (ctx->p == NORM_1) { 327 /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */ 328 copy1 = ctx->workRight[1]; 329 copy2 = ctx->workRight[2]; 330 copy3 = ctx->workRight[3]; 331 /* copy1 : 1/sqrt(x_i^2 + eps) */ 332 PetscCall(VecCopy(x, copy1)); 333 PetscCall(VecPow(copy1, 2)); 334 PetscCall(VecShift(copy1, eps)); 335 PetscCall(VecSqrtAbs(copy1)); 336 PetscCall(VecReciprocal(copy1)); 337 /* copy2: x_i^2.*/ 338 PetscCall(VecCopy(x, copy2)); 339 PetscCall(VecPow(copy2, 2)); 340 /* copy3: abs(x_i^2 + eps) */ 341 PetscCall(VecCopy(x, copy3)); 342 PetscCall(VecPow(copy3, 2)); 343 PetscCall(VecShift(copy3, eps)); 344 PetscCall(VecAbs(copy3)); 345 /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */ 346 PetscCall(VecPointwiseDivide(copy2, copy2, copy3)); 347 PetscCall(VecScale(copy2, -1.)); 348 PetscCall(VecShift(copy2, 1.)); 349 PetscCall(VecAXPY(copy1, 1., copy2)); 350 PetscCall(VecScale(copy1, ctx->mu)); 351 PetscCall(MatZeroEntries(H)); 352 PetscCall(MatDiagonalSet(H, copy1, INSERT_VALUES)); 353 if (Hpre != H) { 354 PetscCall(MatZeroEntries(Hpre)); 355 PetscCall(MatDiagonalSet(Hpre, copy1, INSERT_VALUES)); 356 } 357 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 358 PetscFunctionReturn(PETSC_SUCCESS); 359 } 360 361 /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2 362 * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */ 363 static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx) 364 { 365 UserCtx ctx = (UserCtx)_ctx; 366 PetscReal mu, workNorm, reg; 367 Vec x, u, temp; 368 369 PetscFunctionBegin; 370 mu = ctx->mu; 371 x = ctx->workRight[4]; 372 u = ctx->workRight[6]; 373 temp = ctx->workRight[10]; 374 PetscCall(ObjectiveRegularization(tao, z, ®, _ctx)); 375 PetscCall(VecCopy(z, temp)); 376 /* temp = x + u -z */ 377 PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u)); 378 /* workNorm = ||x + u - z ||^2 */ 379 PetscCall(VecDot(temp, temp, &workNorm)); 380 *J = reg + 0.5 * mu * workNorm; 381 PetscFunctionReturn(PETSC_SUCCESS); 382 } 383 384 /* NORM_2 Case: x - mu*(x + u - z) 385 * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z) 386 * Else: TODO */ 387 static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx) 388 { 389 UserCtx ctx = (UserCtx)_ctx; 390 PetscReal mu; 391 Vec x, u, temp; 392 393 PetscFunctionBegin; 394 mu = ctx->mu; 395 x = ctx->workRight[4]; 396 u = ctx->workRight[6]; 397 temp = ctx->workRight[10]; 398 PetscCall(GradientRegularization(tao, z, V, _ctx)); 399 PetscCall(VecCopy(z, temp)); 400 /* temp = x + u -z */ 401 PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u)); 402 PetscCall(VecAXPY(V, -mu, temp)); 403 PetscFunctionReturn(PETSC_SUCCESS); 404 } 405 406 /* NORM_2 Case: returns diag(mu) 407 * NORM_1 Case: FTF + diag(mu) */ 408 static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 409 { 410 UserCtx ctx = (UserCtx)_ctx; 411 412 PetscFunctionBegin; 413 if (ctx->p == NORM_2) { 414 /* Identity matrix scaled by mu */ 415 PetscCall(MatZeroEntries(H)); 416 PetscCall(MatShift(H, ctx->mu)); 417 if (Hpre != H) { 418 PetscCall(MatZeroEntries(Hpre)); 419 PetscCall(MatShift(Hpre, ctx->mu)); 420 } 421 } else if (ctx->p == NORM_1) { 422 PetscCall(HessianMisfit(tao, x, H, Hpre, (void *)ctx)); 423 PetscCall(MatShift(H, ctx->mu)); 424 if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu)); 425 } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 426 PetscFunctionReturn(PETSC_SUCCESS); 427 } 428 429 /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p 430 * NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */ 431 static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, PetscCtx ctx) 432 { 433 PetscReal Jm, Jr; 434 435 PetscFunctionBegin; 436 PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx)); 437 PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx)); 438 *J = Jm + Jr; 439 PetscFunctionReturn(PETSC_SUCCESS); 440 } 441 442 /* NORM_2 Case: FTFx - FTd + x 443 * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */ 444 static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, PetscCtx ctx) 445 { 446 UserCtx cntx = (UserCtx)ctx; 447 448 PetscFunctionBegin; 449 PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx)); 450 PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx)); 451 PetscCall(VecWAXPY(V, 1, cntx->workRight[2], cntx->workRight[3])); 452 PetscFunctionReturn(PETSC_SUCCESS); 453 } 454 455 /* NORM_2 Case: diag(mu) + FTF 456 * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF */ 457 static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, PetscCtx ctx) 458 { 459 Mat tempH; 460 461 PetscFunctionBegin; 462 PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH)); 463 PetscCall(HessianMisfit(tao, x, H, H, ctx)); 464 PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx)); 465 PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN)); 466 if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN)); 467 PetscCall(MatDestroy(&tempH)); 468 PetscFunctionReturn(PETSC_SUCCESS); 469 } 470 471 static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x) 472 { 473 PetscInt i; 474 PetscReal u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm; 475 Tao tao1, tao2; 476 Vec xk, z, u, diff, zold, zdiff, temp; 477 PetscReal mu; 478 479 PetscFunctionBegin; 480 xk = ctx->workRight[4]; 481 z = ctx->workRight[5]; 482 u = ctx->workRight[6]; 483 diff = ctx->workRight[7]; 484 zold = ctx->workRight[8]; 485 zdiff = ctx->workRight[9]; 486 temp = ctx->workRight[11]; 487 mu = ctx->mu; 488 PetscCall(VecSet(u, 0.)); 489 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1)); 490 PetscCall(TaoSetType(tao1, TAONLS)); 491 PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void *)ctx)); 492 PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void *)ctx)); 493 PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void *)ctx)); 494 PetscCall(VecSet(xk, 0.)); 495 PetscCall(TaoSetSolution(tao1, xk)); 496 PetscCall(TaoSetOptionsPrefix(tao1, "misfit_")); 497 PetscCall(TaoSetFromOptions(tao1)); 498 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2)); 499 if (ctx->p == NORM_2) { 500 PetscCall(TaoSetType(tao2, TAONLS)); 501 PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void *)ctx)); 502 PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void *)ctx)); 503 PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void *)ctx)); 504 } 505 PetscCall(VecSet(z, 0.)); 506 PetscCall(TaoSetSolution(tao2, z)); 507 PetscCall(TaoSetOptionsPrefix(tao2, "reg_")); 508 PetscCall(TaoSetFromOptions(tao2)); 509 510 for (i = 0; i < ctx->iter; i++) { 511 PetscCall(VecCopy(z, zold)); 512 PetscCall(TaoSolve(tao1)); /* Updates xk */ 513 if (ctx->p == NORM_1) { 514 PetscCall(VecWAXPY(temp, 1., xk, u)); 515 PetscCall(TaoSoftThreshold(temp, -ctx->alpha / mu, ctx->alpha / mu, z)); 516 } else { 517 PetscCall(TaoSolve(tao2)); /* Update zk */ 518 } 519 /* u = u + xk -z */ 520 PetscCall(VecAXPBYPCZ(u, 1., -1., 1., xk, z)); 521 /* r_norm : norm(x-z) */ 522 PetscCall(VecWAXPY(diff, -1., z, xk)); 523 PetscCall(VecNorm(diff, NORM_2, &r_norm)); 524 /* s_norm : norm(-mu(z-zold)) */ 525 PetscCall(VecWAXPY(zdiff, -1., zold, z)); 526 PetscCall(VecNorm(zdiff, NORM_2, &s_norm)); 527 s_norm = s_norm * mu; 528 /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/ 529 PetscCall(VecNorm(xk, NORM_2, &x_norm)); 530 PetscCall(VecNorm(z, NORM_2, &z_norm)); 531 primal = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * PetscMax(x_norm, z_norm); 532 /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/ 533 PetscCall(VecNorm(u, NORM_2, &u_norm)); 534 dual = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * u_norm * mu; 535 PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1), "Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double)r_norm, (double)s_norm)); 536 if (r_norm < primal && s_norm < dual) break; 537 } 538 PetscCall(VecCopy(xk, x)); 539 PetscCall(TaoDestroy(&tao1)); 540 PetscCall(TaoDestroy(&tao2)); 541 PetscFunctionReturn(PETSC_SUCCESS); 542 } 543 544 /* Second order Taylor remainder convergence test */ 545 static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C) 546 { 547 PetscReal h, J, temp; 548 PetscInt i, j; 549 PetscInt numValues; 550 PetscReal Jx, Jxhat_comp, Jxhat_pred; 551 PetscReal *Js, *hs; 552 PetscReal gdotdx; 553 PetscReal minrate = PETSC_MAX_REAL; 554 MPI_Comm comm = PetscObjectComm((PetscObject)x); 555 Vec g, dx, xhat; 556 557 PetscFunctionBegin; 558 PetscCall(VecDuplicate(x, &g)); 559 PetscCall(VecDuplicate(x, &xhat)); 560 /* choose a perturbation direction */ 561 PetscCall(VecDuplicate(x, &dx)); 562 PetscCall(VecSetRandom(dx, ctx->rctx)); 563 /* evaluate objective at x: J(x) */ 564 PetscCall(TaoComputeObjective(tao, x, &Jx)); 565 /* evaluate gradient at x, save in vector g */ 566 PetscCall(TaoComputeGradient(tao, x, g)); 567 PetscCall(VecDot(g, dx, &gdotdx)); 568 569 for (numValues = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor) numValues++; 570 PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs)); 571 for (i = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor, i++) { 572 PetscCall(VecWAXPY(xhat, h, dx, x)); 573 PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp)); 574 /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */ 575 Jxhat_pred = Jx + h * gdotdx; 576 /* Vector to dJdm scalar? Dot?*/ 577 J = PetscAbsReal(Jxhat_comp - Jxhat_pred); 578 PetscCall(PetscPrintf(comm, "J(xhat): %g, predicted: %g, diff %g\n", (double)Jxhat_comp, (double)Jxhat_pred, (double)J)); 579 Js[i] = J; 580 hs[i] = h; 581 } 582 for (j = 1; j < numValues; j++) { 583 temp = PetscLogReal(Js[j] / Js[j - 1]) / PetscLogReal(hs[j] / hs[j - 1]); 584 PetscCall(PetscPrintf(comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j - 1, (double)temp)); 585 minrate = PetscMin(minrate, temp); 586 } 587 /* If O is not ~2, then the test is wrong */ 588 PetscCall(PetscFree2(Js, hs)); 589 *C = minrate; 590 PetscCall(VecDestroy(&dx)); 591 PetscCall(VecDestroy(&xhat)); 592 PetscCall(VecDestroy(&g)); 593 PetscFunctionReturn(PETSC_SUCCESS); 594 } 595 596 int main(int argc, char **argv) 597 { 598 UserCtx ctx; 599 Tao tao; 600 Vec x; 601 Mat H; 602 603 PetscFunctionBeginUser; 604 PetscCall(PetscInitialize(&argc, &argv, NULL, help)); 605 PetscCall(PetscNew(&ctx)); 606 PetscCall(ConfigureContext(ctx)); 607 /* Define two functions that could pass as objectives to TaoSetObjective(): one 608 * for the misfit component, and one for the regularization component */ 609 /* ObjectiveMisfit() and ObjectiveRegularization() */ 610 611 /* Define a single function that calls both components adds them together: the complete objective, 612 * in the absence of a Tao implementation that handles separability */ 613 /* ObjectiveComplete() */ 614 PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao)); 615 PetscCall(TaoSetType(tao, TAONM)); 616 PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void *)ctx)); 617 PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void *)ctx)); 618 PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H)); 619 PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void *)ctx)); 620 PetscCall(MatCreateVecs(ctx->F, NULL, &x)); 621 PetscCall(VecSet(x, 0.)); 622 PetscCall(TaoSetSolution(tao, x)); 623 PetscCall(TaoSetFromOptions(tao)); 624 if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx, x)); 625 else PetscCall(TaoSolve(tao)); 626 /* examine solution */ 627 PetscCall(VecViewFromOptions(x, NULL, "-view_sol")); 628 if (ctx->taylor) { 629 PetscReal rate; 630 PetscCall(TaylorTest(ctx, tao, x, &rate)); 631 } 632 if (ctx->soft) PetscCall(TaoSoftThreshold(x, 0., 0., x)); 633 PetscCall(MatDestroy(&H)); 634 PetscCall(TaoDestroy(&tao)); 635 PetscCall(VecDestroy(&x)); 636 PetscCall(DestroyContext(&ctx)); 637 PetscCall(PetscFinalize()); 638 return 0; 639 } 640 641 /*TEST 642 643 build: 644 requires: !complex 645 646 test: 647 suffix: 0 648 args: 649 650 test: 651 suffix: l1_1 652 args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1 653 654 test: 655 suffix: hessian_1 656 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls 657 658 test: 659 suffix: hessian_2 660 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls 661 662 test: 663 suffix: nm_1 664 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50 665 666 test: 667 suffix: nm_2 668 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50 669 670 test: 671 suffix: lmvm_1 672 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40 673 674 test: 675 suffix: lmvm_2 676 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15 677 678 test: 679 suffix: soft_threshold_admm_1 680 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm 681 682 test: 683 suffix: hessian_admm_1 684 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls 685 686 test: 687 suffix: hessian_admm_2 688 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls 689 690 test: 691 suffix: nm_admm_1 692 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm 693 694 test: 695 suffix: nm_admm_2 696 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7 697 698 test: 699 suffix: lmvm_admm_1 700 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm 701 702 test: 703 suffix: lmvm_admm_2 704 args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm 705 706 test: 707 suffix: soft 708 args: -taylor 0 -soft 1 709 output_file: output/empty.out 710 711 TEST*/ 712