Lines Matching refs:ctx
38 static PetscErrorCode CreateRHS(UserCtx ctx) in CreateRHS() argument
42 PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->d)); in CreateRHS()
43 PetscCall(VecSetSizes(ctx->d, PETSC_DECIDE, ctx->m)); in CreateRHS()
44 PetscCall(VecSetFromOptions(ctx->d)); in CreateRHS()
45 PetscCall(VecSetRandom(ctx->d, ctx->rctx)); in CreateRHS()
49 static PetscErrorCode CreateMatrix(UserCtx ctx) in CreateMatrix() argument
56 PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->F)); in CreateMatrix()
57 PetscCall(MatSetSizes(ctx->F, PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n)); in CreateMatrix()
58 …PetscCall(MatSetType(ctx->F, MATAIJ)); /* TODO: Decide specific SetType o… in CreateMatrix()
59 …PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/ in CreateMatrix()
60 PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL)); in CreateMatrix()
61 PetscCall(MatSetUp(ctx->F)); in CreateMatrix()
62 PetscCall(MatGetOwnershipRange(ctx->F, &Istart, &Iend)); in CreateMatrix()
67 if (!ctx->matops) { in CreateMatrix()
68 …PetscCheck(ctx->m == ctx->n, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square"); in CreateMatrix()
69 gridN = (PetscInt)PetscSqrtReal((PetscReal)ctx->m); in CreateMatrix()
70 …PetscCheck(gridN * gridN == ctx->m, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be s… in CreateMatrix()
82 PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES)); in CreateMatrix()
83 PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES)); in CreateMatrix()
84 PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES)); in CreateMatrix()
85 PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES)); in CreateMatrix()
86 PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES)); in CreateMatrix()
88 } else PetscCall(MatSetRandom(ctx->F, ctx->rctx)); in CreateMatrix()
89 PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY)); in CreateMatrix()
90 PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY)); in CreateMatrix()
93 if (!ctx->matops) PetscCall(MatSetOption(ctx->F, MAT_SYMMETRIC, PETSC_TRUE)); in CreateMatrix()
94 PetscCall(MatTransposeMatMult(ctx->F, ctx->F, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &ctx->W)); in CreateMatrix()
96 PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hm)); in CreateMatrix()
97 PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hr)); in CreateMatrix()
101 static PetscErrorCode SetupWorkspace(UserCtx ctx) in SetupWorkspace() argument
106 PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0])); in SetupWorkspace()
107 for (i = 1; i < NWORKLEFT; i++) PetscCall(VecDuplicate(ctx->workLeft[0], &ctx->workLeft[i])); in SetupWorkspace()
108 for (i = 1; i < NWORKRIGHT; i++) PetscCall(VecDuplicate(ctx->workRight[0], &ctx->workRight[i])); in SetupWorkspace()
112 static PetscErrorCode ConfigureContext(UserCtx ctx) in ConfigureContext() argument
115 ctx->m = 16; in ConfigureContext()
116 ctx->n = 16; in ConfigureContext()
117 ctx->eps = 1.e-3; in ConfigureContext()
118 ctx->abstol = 1.e-4; in ConfigureContext()
119 ctx->reltol = 1.e-2; in ConfigureContext()
120 ctx->hStart = 1.; in ConfigureContext()
121 ctx->hMin = 1.e-3; in ConfigureContext()
122 ctx->hFactor = 0.5; in ConfigureContext()
123 ctx->alpha = 1.; in ConfigureContext()
124 ctx->mu = 1.0; in ConfigureContext()
125 ctx->matops = 0; in ConfigureContext()
126 ctx->iter = 10; in ConfigureContext()
127 ctx->p = NORM_2; in ConfigureContext()
128 ctx->soft = PETSC_FALSE; in ConfigureContext()
129 ctx->taylor = PETSC_TRUE; in ConfigureContext()
130 ctx->use_admm = PETSC_FALSE; in ConfigureContext()
132 PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &ctx->m, NULL)); in ConfigureContext()
133 …PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &ctx->n, NULL… in ConfigureContext()
134 … "Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &ctx->matops, NUL… in ConfigureContext()
135 …PetscCall(PetscOptionsInt("-iter", "Iteration number ADMM", "ex4.c", ctx->iter, &ctx->iter, NULL)); in ConfigureContext()
136 …Real("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &ctx->alpha, NULL)… in ConfigureContext()
137 …| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &ctx->eps, NULL)); in ConfigureContext()
138 …tionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &ctx->mu, NULL)); in ConfigureContext()
139 …sReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &ctx->hStart, NUL… in ConfigureContext()
140 …"-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &ctx->hFactor, NU… in ConfigureContext()
141 …eal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &ctx->hMin, NULL)); in ConfigureContext()
142 …onsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &ctx->abstol, NUL… in ConfigureContext()
143 …onsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &ctx->reltol, NUL… in ConfigureContext()
144 …sBool("-taylor", "Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &ctx->taylor, NUL… in ConfigureContext()
145 … for testing soft threshold no-op case. Default is false.", "ex4.c", ctx->soft, &ctx->soft, NULL)); in ConfigureContext()
146 …Bool("-use_admm", "Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &ctx->use_admm, … in ConfigureContext()
147 …cOptionsEnum("-p", "Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *)&ctx->p, NULL… in ConfigureContext()
150 PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &ctx->rctx)); in ConfigureContext()
151 PetscCall(PetscRandomSetFromOptions(ctx->rctx)); in ConfigureContext()
152 PetscCall(CreateMatrix(ctx)); in ConfigureContext()
153 PetscCall(CreateRHS(ctx)); in ConfigureContext()
154 PetscCall(SetupWorkspace(ctx)); in ConfigureContext()
158 static PetscErrorCode DestroyContext(UserCtx *ctx) in DestroyContext() argument
163 PetscCall(MatDestroy(&(*ctx)->F)); in DestroyContext()
164 PetscCall(MatDestroy(&(*ctx)->W)); in DestroyContext()
165 PetscCall(MatDestroy(&(*ctx)->Hm)); in DestroyContext()
166 PetscCall(MatDestroy(&(*ctx)->Hr)); in DestroyContext()
167 PetscCall(VecDestroy(&(*ctx)->d)); in DestroyContext()
168 for (i = 0; i < NWORKLEFT; i++) PetscCall(VecDestroy(&(*ctx)->workLeft[i])); in DestroyContext()
169 for (i = 0; i < NWORKRIGHT; i++) PetscCall(VecDestroy(&(*ctx)->workRight[i])); in DestroyContext()
170 PetscCall(PetscRandomDestroy(&(*ctx)->rctx)); in DestroyContext()
171 PetscCall(PetscFree(*ctx)); in DestroyContext()
178 UserCtx ctx = (UserCtx)_ctx; in ObjectiveMisfit() local
182 y = ctx->workLeft[0]; in ObjectiveMisfit()
183 PetscCall(MatMult(ctx->F, x, y)); in ObjectiveMisfit()
184 PetscCall(VecAXPY(y, -1., ctx->d)); in ObjectiveMisfit()
193 UserCtx ctx = (UserCtx)_ctx; in GradientMisfit() local
198 FTFx = ctx->workRight[0]; in GradientMisfit()
199 FTd = ctx->workRight[1]; in GradientMisfit()
200 PetscCall(MatMult(ctx->W, x, FTFx)); in GradientMisfit()
201 PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd)); in GradientMisfit()
209 UserCtx ctx = (UserCtx)_ctx; in HessianMisfit() local
212 if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN)); in HessianMisfit()
213 if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN)); in HessianMisfit()
221 UserCtx ctx = (UserCtx)_ctx; in ObjectiveMisfitADMM() local
226 mu = ctx->mu; in ObjectiveMisfitADMM()
227 z = ctx->workRight[5]; in ObjectiveMisfitADMM()
228 u = ctx->workRight[6]; in ObjectiveMisfitADMM()
229 temp = ctx->workRight[10]; in ObjectiveMisfitADMM()
245 UserCtx ctx = (UserCtx)_ctx; in GradientMisfitADMM() local
250 mu = ctx->mu; in GradientMisfitADMM()
251 z = ctx->workRight[5]; in GradientMisfitADMM()
252 u = ctx->workRight[6]; in GradientMisfitADMM()
253 temp = ctx->workRight[10]; in GradientMisfitADMM()
266 UserCtx ctx = (UserCtx)_ctx; in HessianMisfitADMM() local
269 PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN)); in HessianMisfitADMM()
270 PetscCall(MatShift(H, ctx->mu)); in HessianMisfitADMM()
278 UserCtx ctx = (UserCtx)_ctx; in ObjectiveRegularization() local
283 PetscCall(VecNorm(x, ctx->p, &norm)); in ObjectiveRegularization()
284 if (ctx->p == NORM_2) norm = 0.5 * norm * norm; in ObjectiveRegularization()
285 *J = ctx->alpha * norm; in ObjectiveRegularization()
294 UserCtx ctx = (UserCtx)_ctx; in GradientRegularization() local
295 PetscReal eps = ctx->eps; in GradientRegularization()
298 if (ctx->p == NORM_2) { in GradientRegularization()
300 } else if (ctx->p == NORM_1) { in GradientRegularization()
301 PetscCall(VecCopy(x, ctx->workRight[1])); in GradientRegularization()
302 PetscCall(VecAbs(ctx->workRight[1])); in GradientRegularization()
303 PetscCall(VecShift(ctx->workRight[1], eps)); in GradientRegularization()
304 PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1])); in GradientRegularization()
313 UserCtx ctx = (UserCtx)_ctx; in HessianRegularization() local
314 PetscReal eps = ctx->eps; in HessianRegularization()
318 if (ctx->p == NORM_2) { in HessianRegularization()
321 PetscCall(MatShift(H, ctx->mu)); in HessianRegularization()
324 PetscCall(MatShift(Hpre, ctx->mu)); in HessianRegularization()
326 } else if (ctx->p == NORM_1) { in HessianRegularization()
328 copy1 = ctx->workRight[1]; in HessianRegularization()
329 copy2 = ctx->workRight[2]; in HessianRegularization()
330 copy3 = ctx->workRight[3]; in HessianRegularization()
350 PetscCall(VecScale(copy1, ctx->mu)); in HessianRegularization()
365 UserCtx ctx = (UserCtx)_ctx; in ObjectiveRegularizationADMM() local
370 mu = ctx->mu; in ObjectiveRegularizationADMM()
371 x = ctx->workRight[4]; in ObjectiveRegularizationADMM()
372 u = ctx->workRight[6]; in ObjectiveRegularizationADMM()
373 temp = ctx->workRight[10]; in ObjectiveRegularizationADMM()
389 UserCtx ctx = (UserCtx)_ctx; in GradientRegularizationADMM() local
394 mu = ctx->mu; in GradientRegularizationADMM()
395 x = ctx->workRight[4]; in GradientRegularizationADMM()
396 u = ctx->workRight[6]; in GradientRegularizationADMM()
397 temp = ctx->workRight[10]; in GradientRegularizationADMM()
410 UserCtx ctx = (UserCtx)_ctx; in HessianRegularizationADMM() local
413 if (ctx->p == NORM_2) { in HessianRegularizationADMM()
416 PetscCall(MatShift(H, ctx->mu)); in HessianRegularizationADMM()
419 PetscCall(MatShift(Hpre, ctx->mu)); in HessianRegularizationADMM()
421 } else if (ctx->p == NORM_1) { in HessianRegularizationADMM()
422 PetscCall(HessianMisfit(tao, x, H, Hpre, (void *)ctx)); in HessianRegularizationADMM()
423 PetscCall(MatShift(H, ctx->mu)); in HessianRegularizationADMM()
424 if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu)); in HessianRegularizationADMM()
431 static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, PetscCtx ctx) in ObjectiveComplete() argument
436 PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx)); in ObjectiveComplete()
437 PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx)); in ObjectiveComplete()
444 static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, PetscCtx ctx) in GradientComplete() argument
446 UserCtx cntx = (UserCtx)ctx; in GradientComplete()
449 PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx)); in GradientComplete()
450 PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx)); in GradientComplete()
457 static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, PetscCtx ctx) in HessianComplete() argument
463 PetscCall(HessianMisfit(tao, x, H, H, ctx)); in HessianComplete()
464 PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx)); in HessianComplete()
471 static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x) in TaoSolveADMM() argument
480 xk = ctx->workRight[4]; in TaoSolveADMM()
481 z = ctx->workRight[5]; in TaoSolveADMM()
482 u = ctx->workRight[6]; in TaoSolveADMM()
483 diff = ctx->workRight[7]; in TaoSolveADMM()
484 zold = ctx->workRight[8]; in TaoSolveADMM()
485 zdiff = ctx->workRight[9]; in TaoSolveADMM()
486 temp = ctx->workRight[11]; in TaoSolveADMM()
487 mu = ctx->mu; in TaoSolveADMM()
491 PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void *)ctx)); in TaoSolveADMM()
492 PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void *)ctx)); in TaoSolveADMM()
493 PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void *)ctx)); in TaoSolveADMM()
499 if (ctx->p == NORM_2) { in TaoSolveADMM()
501 PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void *)ctx)); in TaoSolveADMM()
502 PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void *)ctx)); in TaoSolveADMM()
503 PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void *)ctx)); in TaoSolveADMM()
510 for (i = 0; i < ctx->iter; i++) { in TaoSolveADMM()
513 if (ctx->p == NORM_1) { in TaoSolveADMM()
515 PetscCall(TaoSoftThreshold(temp, -ctx->alpha / mu, ctx->alpha / mu, z)); in TaoSolveADMM()
531 primal = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * PetscMax(x_norm, z_norm); in TaoSolveADMM()
534 dual = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * u_norm * mu; in TaoSolveADMM()
545 static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C) in TaylorTest() argument
562 PetscCall(VecSetRandom(dx, ctx->rctx)); in TaylorTest()
569 for (numValues = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor) numValues++; in TaylorTest()
571 for (i = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor, i++) { in TaylorTest()
598 UserCtx ctx; in main() local
605 PetscCall(PetscNew(&ctx)); in main()
606 PetscCall(ConfigureContext(ctx)); in main()
616 PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void *)ctx)); in main()
617 PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void *)ctx)); in main()
618 PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H)); in main()
619 PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void *)ctx)); in main()
620 PetscCall(MatCreateVecs(ctx->F, NULL, &x)); in main()
624 if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx, x)); in main()
628 if (ctx->taylor) { in main()
630 PetscCall(TaylorTest(ctx, tao, x, &rate)); in main()
632 if (ctx->soft) PetscCall(TaoSoftThreshold(x, 0., 0., x)); in main()
636 PetscCall(DestroyContext(&ctx)); in main()