1c4762a1bSJed Brown static char help[] = "Simple example to test separable objective optimizers.\n"; 2c4762a1bSJed Brown 3c4762a1bSJed Brown #include <petsc.h> 4c4762a1bSJed Brown #include <petsctao.h> 5c4762a1bSJed Brown #include <petscvec.h> 6c4762a1bSJed Brown #include <petscmath.h> 7c4762a1bSJed Brown 8c4762a1bSJed Brown #define NWORKLEFT 4 9c4762a1bSJed Brown #define NWORKRIGHT 12 10c4762a1bSJed Brown 119371c9d4SSatish Balay typedef struct _UserCtx { 12c4762a1bSJed Brown PetscInt m; /* The row dimension of F */ 13c4762a1bSJed Brown PetscInt n; /* The column dimension of F */ 14c4762a1bSJed Brown PetscInt matops; /* Matrix format. 0 for stencil, 1 for random */ 15be87f6c0SPierre Jolivet PetscInt iter; /* Number of iterations for ADMM */ 16c4762a1bSJed Brown PetscReal hStart; /* Starting point for Taylor test */ 17c4762a1bSJed Brown PetscReal hFactor; /* Taylor test step factor */ 18c4762a1bSJed Brown PetscReal hMin; /* Taylor test end goal */ 19c4762a1bSJed Brown PetscReal alpha; /* regularization constant applied to || x ||_p */ 20c4762a1bSJed Brown PetscReal eps; /* small constant for approximating gradient of || x ||_1 */ 21c4762a1bSJed Brown PetscReal mu; /* the augmented Lagrangian term in ADMM */ 22c4762a1bSJed Brown PetscReal abstol; 23c4762a1bSJed Brown PetscReal reltol; 24c4762a1bSJed Brown Mat F; /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */ 25c4762a1bSJed Brown Mat W; /* Workspace matrix. ATA */ 26c4762a1bSJed Brown Mat Hm; /* Hessian Misfit*/ 27c4762a1bSJed Brown Mat Hr; /* Hessian Reg*/ 28c4762a1bSJed Brown Vec d; /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */ 29c4762a1bSJed Brown Vec workLeft[NWORKLEFT]; /* Workspace for temporary vec */ 30c4762a1bSJed Brown Vec workRight[NWORKRIGHT]; /* Workspace for temporary vec */ 31c4762a1bSJed Brown NormType p; 32c4762a1bSJed Brown PetscRandom rctx; 3384430a0dSHansol Suh PetscBool soft; 34c4762a1bSJed Brown PetscBool taylor; /* Flag to determine whether to run Taylor test or not */ 35c4762a1bSJed Brown PetscBool use_admm; /* Flag to determine whether to run Taylor test or not */ 36c4762a1bSJed Brown } *UserCtx; 37c4762a1bSJed Brown 38d71ae5a4SJacob Faibussowitsch static PetscErrorCode CreateRHS(UserCtx ctx) 39d71ae5a4SJacob Faibussowitsch { 40c4762a1bSJed Brown PetscFunctionBegin; 41c4762a1bSJed Brown /* build the rhs d in ctx */ 42*f4f49eeaSPierre Jolivet PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->d)); 439566063dSJacob Faibussowitsch PetscCall(VecSetSizes(ctx->d, PETSC_DECIDE, ctx->m)); 449566063dSJacob Faibussowitsch PetscCall(VecSetFromOptions(ctx->d)); 459566063dSJacob Faibussowitsch PetscCall(VecSetRandom(ctx->d, ctx->rctx)); 463ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 47c4762a1bSJed Brown } 48c4762a1bSJed Brown 49d71ae5a4SJacob Faibussowitsch static PetscErrorCode CreateMatrix(UserCtx ctx) 50d71ae5a4SJacob Faibussowitsch { 51c4762a1bSJed Brown PetscInt Istart, Iend, i, j, Ii, gridN, I_n, I_s, I_e, I_w; 52c4762a1bSJed Brown PetscLogStage stage; 53c4762a1bSJed Brown 54c4762a1bSJed Brown PetscFunctionBegin; 55c4762a1bSJed Brown /* build the matrix F in ctx */ 56*f4f49eeaSPierre Jolivet PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->F)); 579566063dSJacob Faibussowitsch PetscCall(MatSetSizes(ctx->F, PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n)); 589566063dSJacob Faibussowitsch PetscCall(MatSetType(ctx->F, MATAIJ)); /* TODO: Decide specific SetType other than dummy*/ 599566063dSJacob Faibussowitsch PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/ 609566063dSJacob Faibussowitsch PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL)); 619566063dSJacob Faibussowitsch PetscCall(MatSetUp(ctx->F)); 629566063dSJacob Faibussowitsch PetscCall(MatGetOwnershipRange(ctx->F, &Istart, &Iend)); 639566063dSJacob Faibussowitsch PetscCall(PetscLogStageRegister("Assembly", &stage)); 649566063dSJacob Faibussowitsch PetscCall(PetscLogStagePush(stage)); 65c4762a1bSJed Brown 663c859ba3SBarry Smith /* Set matrix elements in 2-D five point stencil format. */ 67*f4f49eeaSPierre Jolivet if (!ctx->matops) { 683c859ba3SBarry Smith PetscCheck(ctx->m == ctx->n, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square"); 69c4762a1bSJed Brown gridN = (PetscInt)PetscSqrtReal((PetscReal)ctx->m); 703c859ba3SBarry Smith PetscCheck(gridN * gridN == ctx->m, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square"); 71c4762a1bSJed Brown for (Ii = Istart; Ii < Iend; Ii++) { 729371c9d4SSatish Balay i = Ii / gridN; 739371c9d4SSatish Balay j = Ii % gridN; 74c4762a1bSJed Brown I_n = i * gridN + j + 1; 75c4762a1bSJed Brown if (j + 1 >= gridN) I_n = -1; 76c4762a1bSJed Brown I_s = i * gridN + j - 1; 77c4762a1bSJed Brown if (j - 1 < 0) I_s = -1; 78c4762a1bSJed Brown I_e = (i + 1) * gridN + j; 79c4762a1bSJed Brown if (i + 1 >= gridN) I_e = -1; 80c4762a1bSJed Brown I_w = (i - 1) * gridN + j; 81c4762a1bSJed Brown if (i - 1 < 0) I_w = -1; 829566063dSJacob Faibussowitsch PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES)); 839566063dSJacob Faibussowitsch PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES)); 849566063dSJacob Faibussowitsch PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES)); 859566063dSJacob Faibussowitsch PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES)); 869566063dSJacob Faibussowitsch PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES)); 87c4762a1bSJed Brown } 889566063dSJacob Faibussowitsch } else PetscCall(MatSetRandom(ctx->F, ctx->rctx)); 899566063dSJacob Faibussowitsch PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY)); 909566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY)); 919566063dSJacob Faibussowitsch PetscCall(PetscLogStagePop()); 92c4762a1bSJed Brown /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */ 93*f4f49eeaSPierre Jolivet if (!ctx->matops) PetscCall(MatSetOption(ctx->F, MAT_SYMMETRIC, PETSC_TRUE)); 94*f4f49eeaSPierre Jolivet PetscCall(MatTransposeMatMult(ctx->F, ctx->F, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &ctx->W)); 95c4762a1bSJed Brown /* Setup Hessian Workspace in same shape as W */ 96*f4f49eeaSPierre Jolivet PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hm)); 97*f4f49eeaSPierre Jolivet PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hr)); 983ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 99c4762a1bSJed Brown } 100c4762a1bSJed Brown 101d71ae5a4SJacob Faibussowitsch static PetscErrorCode SetupWorkspace(UserCtx ctx) 102d71ae5a4SJacob Faibussowitsch { 103c4762a1bSJed Brown PetscInt i; 104c4762a1bSJed Brown 105c4762a1bSJed Brown PetscFunctionBegin; 1069566063dSJacob Faibussowitsch PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0])); 107*f4f49eeaSPierre Jolivet for (i = 1; i < NWORKLEFT; i++) PetscCall(VecDuplicate(ctx->workLeft[0], &ctx->workLeft[i])); 108*f4f49eeaSPierre Jolivet for (i = 1; i < NWORKRIGHT; i++) PetscCall(VecDuplicate(ctx->workRight[0], &ctx->workRight[i])); 1093ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 110c4762a1bSJed Brown } 111c4762a1bSJed Brown 112d71ae5a4SJacob Faibussowitsch static PetscErrorCode ConfigureContext(UserCtx ctx) 113d71ae5a4SJacob Faibussowitsch { 114c4762a1bSJed Brown PetscFunctionBegin; 115c4762a1bSJed Brown ctx->m = 16; 116c4762a1bSJed Brown ctx->n = 16; 117c4762a1bSJed Brown ctx->eps = 1.e-3; 118c4762a1bSJed Brown ctx->abstol = 1.e-4; 119c4762a1bSJed Brown ctx->reltol = 1.e-2; 120c4762a1bSJed Brown ctx->hStart = 1.; 121c4762a1bSJed Brown ctx->hMin = 1.e-3; 122c4762a1bSJed Brown ctx->hFactor = 0.5; 123c4762a1bSJed Brown ctx->alpha = 1.; 124c4762a1bSJed Brown ctx->mu = 1.0; 125c4762a1bSJed Brown ctx->matops = 0; 126c4762a1bSJed Brown ctx->iter = 10; 127c4762a1bSJed Brown ctx->p = NORM_2; 12884430a0dSHansol Suh ctx->soft = PETSC_FALSE; 129c4762a1bSJed Brown ctx->taylor = PETSC_TRUE; 130c4762a1bSJed Brown ctx->use_admm = PETSC_FALSE; 131d0609cedSBarry Smith PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "ex4.c"); 132*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &ctx->m, NULL)); 133*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &ctx->n, NULL)); 134*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsInt("-matrix_format", "Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &ctx->matops, NULL)); 135*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsInt("-iter", "Iteration number ADMM", "ex4.c", ctx->iter, &ctx->iter, NULL)); 136*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &ctx->alpha, NULL)); 137*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &ctx->eps, NULL)); 138*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &ctx->mu, NULL)); 139*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &ctx->hStart, NULL)); 140*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &ctx->hFactor, NULL)); 141*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &ctx->hMin, NULL)); 142*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &ctx->abstol, NULL)); 143*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &ctx->reltol, NULL)); 144*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsBool("-taylor", "Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &ctx->taylor, NULL)); 145*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsBool("-soft", "Flag for testing soft threshold no-op case. Default is false.", "ex4.c", ctx->soft, &ctx->soft, NULL)); 146*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsBool("-use_admm", "Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &ctx->use_admm, NULL)); 147*f4f49eeaSPierre Jolivet PetscCall(PetscOptionsEnum("-p", "Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *)&ctx->p, NULL)); 148d0609cedSBarry Smith PetscOptionsEnd(); 149c4762a1bSJed Brown /* Creating random ctx */ 150*f4f49eeaSPierre Jolivet PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &ctx->rctx)); 1519566063dSJacob Faibussowitsch PetscCall(PetscRandomSetFromOptions(ctx->rctx)); 1529566063dSJacob Faibussowitsch PetscCall(CreateMatrix(ctx)); 1539566063dSJacob Faibussowitsch PetscCall(CreateRHS(ctx)); 1549566063dSJacob Faibussowitsch PetscCall(SetupWorkspace(ctx)); 1553ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 156c4762a1bSJed Brown } 157c4762a1bSJed Brown 158d71ae5a4SJacob Faibussowitsch static PetscErrorCode DestroyContext(UserCtx *ctx) 159d71ae5a4SJacob Faibussowitsch { 160c4762a1bSJed Brown PetscInt i; 161c4762a1bSJed Brown 162c4762a1bSJed Brown PetscFunctionBegin; 1639566063dSJacob Faibussowitsch PetscCall(MatDestroy(&((*ctx)->F))); 1649566063dSJacob Faibussowitsch PetscCall(MatDestroy(&((*ctx)->W))); 1659566063dSJacob Faibussowitsch PetscCall(MatDestroy(&((*ctx)->Hm))); 1669566063dSJacob Faibussowitsch PetscCall(MatDestroy(&((*ctx)->Hr))); 1679566063dSJacob Faibussowitsch PetscCall(VecDestroy(&((*ctx)->d))); 16848a46eb9SPierre Jolivet for (i = 0; i < NWORKLEFT; i++) PetscCall(VecDestroy(&((*ctx)->workLeft[i]))); 16948a46eb9SPierre Jolivet for (i = 0; i < NWORKRIGHT; i++) PetscCall(VecDestroy(&((*ctx)->workRight[i]))); 1709566063dSJacob Faibussowitsch PetscCall(PetscRandomDestroy(&((*ctx)->rctx))); 1719566063dSJacob Faibussowitsch PetscCall(PetscFree(*ctx)); 1723ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 173c4762a1bSJed Brown } 174c4762a1bSJed Brown 175c4762a1bSJed Brown /* compute (1/2) * ||F x - d||^2 */ 176d71ae5a4SJacob Faibussowitsch static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx) 177d71ae5a4SJacob Faibussowitsch { 178c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 179c4762a1bSJed Brown Vec y; 180c4762a1bSJed Brown 181c4762a1bSJed Brown PetscFunctionBegin; 182c4762a1bSJed Brown y = ctx->workLeft[0]; 1839566063dSJacob Faibussowitsch PetscCall(MatMult(ctx->F, x, y)); 1849566063dSJacob Faibussowitsch PetscCall(VecAXPY(y, -1., ctx->d)); 1859566063dSJacob Faibussowitsch PetscCall(VecDot(y, y, J)); 186c4762a1bSJed Brown *J *= 0.5; 1873ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 188c4762a1bSJed Brown } 189c4762a1bSJed Brown 190c4762a1bSJed Brown /* compute V = FTFx - FTd */ 191d71ae5a4SJacob Faibussowitsch static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx) 192d71ae5a4SJacob Faibussowitsch { 193c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 194c4762a1bSJed Brown Vec FTFx, FTd; 195c4762a1bSJed Brown 196c4762a1bSJed Brown PetscFunctionBegin; 197c4762a1bSJed Brown /* work1 is A^T Ax, work2 is Ab, W is A^T A*/ 198c4762a1bSJed Brown FTFx = ctx->workRight[0]; 199c4762a1bSJed Brown FTd = ctx->workRight[1]; 2009566063dSJacob Faibussowitsch PetscCall(MatMult(ctx->W, x, FTFx)); 2019566063dSJacob Faibussowitsch PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd)); 2029566063dSJacob Faibussowitsch PetscCall(VecWAXPY(V, -1., FTd, FTFx)); 2033ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 204c4762a1bSJed Brown } 205c4762a1bSJed Brown 206c4762a1bSJed Brown /* returns FTF */ 207d71ae5a4SJacob Faibussowitsch static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 208d71ae5a4SJacob Faibussowitsch { 209c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 210c4762a1bSJed Brown 211c4762a1bSJed Brown PetscFunctionBegin; 2129566063dSJacob Faibussowitsch if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN)); 2139566063dSJacob Faibussowitsch if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN)); 2143ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 215c4762a1bSJed Brown } 216c4762a1bSJed Brown 217c4762a1bSJed Brown /* computes augment Lagrangian objective (with scaled dual): 218c4762a1bSJed Brown * 0.5 * ||F x - d||^2 + 0.5 * mu ||x - z + u||^2 */ 219d71ae5a4SJacob Faibussowitsch static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx) 220d71ae5a4SJacob Faibussowitsch { 221c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 222c4762a1bSJed Brown PetscReal mu, workNorm, misfit; 223c4762a1bSJed Brown Vec z, u, temp; 224c4762a1bSJed Brown 225c4762a1bSJed Brown PetscFunctionBegin; 226c4762a1bSJed Brown mu = ctx->mu; 227c4762a1bSJed Brown z = ctx->workRight[5]; 228c4762a1bSJed Brown u = ctx->workRight[6]; 229c4762a1bSJed Brown temp = ctx->workRight[10]; 230c4762a1bSJed Brown /* misfit = f(x) */ 2319566063dSJacob Faibussowitsch PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx)); 2329566063dSJacob Faibussowitsch PetscCall(VecCopy(x, temp)); 233c4762a1bSJed Brown /* temp = x - z + u */ 2349566063dSJacob Faibussowitsch PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u)); 235c4762a1bSJed Brown /* workNorm = ||x - z + u||^2 */ 2369566063dSJacob Faibussowitsch PetscCall(VecDot(temp, temp, &workNorm)); 237c4762a1bSJed Brown /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */ 238c4762a1bSJed Brown *J = misfit + 0.5 * mu * workNorm; 2393ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 240c4762a1bSJed Brown } 241c4762a1bSJed Brown 242c4762a1bSJed Brown /* computes FTFx - FTd mu*(x - z + u) */ 243d71ae5a4SJacob Faibussowitsch static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx) 244d71ae5a4SJacob Faibussowitsch { 245c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 246c4762a1bSJed Brown PetscReal mu; 247c4762a1bSJed Brown Vec z, u, temp; 248c4762a1bSJed Brown 249c4762a1bSJed Brown PetscFunctionBegin; 250c4762a1bSJed Brown mu = ctx->mu; 251c4762a1bSJed Brown z = ctx->workRight[5]; 252c4762a1bSJed Brown u = ctx->workRight[6]; 253c4762a1bSJed Brown temp = ctx->workRight[10]; 2549566063dSJacob Faibussowitsch PetscCall(GradientMisfit(tao, x, V, _ctx)); 2559566063dSJacob Faibussowitsch PetscCall(VecCopy(x, temp)); 256c4762a1bSJed Brown /* temp = x - z + u */ 2579566063dSJacob Faibussowitsch PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u)); 258c4762a1bSJed Brown /* V = FTFx - FTd mu*(x - z + u) */ 2599566063dSJacob Faibussowitsch PetscCall(VecAXPY(V, mu, temp)); 2603ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 261c4762a1bSJed Brown } 262c4762a1bSJed Brown 263c4762a1bSJed Brown /* returns FTF + diag(mu) */ 264d71ae5a4SJacob Faibussowitsch static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 265d71ae5a4SJacob Faibussowitsch { 266c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 267c4762a1bSJed Brown 268c4762a1bSJed Brown PetscFunctionBegin; 2699566063dSJacob Faibussowitsch PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN)); 2709566063dSJacob Faibussowitsch PetscCall(MatShift(H, ctx->mu)); 27148a46eb9SPierre Jolivet if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN)); 2723ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 273c4762a1bSJed Brown } 274c4762a1bSJed Brown 275c4762a1bSJed Brown /* computes || x ||_p (mult by 0.5 in case of NORM_2) */ 276d71ae5a4SJacob Faibussowitsch static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx) 277d71ae5a4SJacob Faibussowitsch { 278c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 279c4762a1bSJed Brown PetscReal norm; 280c4762a1bSJed Brown 281c4762a1bSJed Brown PetscFunctionBegin; 282c4762a1bSJed Brown *J = 0; 2839566063dSJacob Faibussowitsch PetscCall(VecNorm(x, ctx->p, &norm)); 284c4762a1bSJed Brown if (ctx->p == NORM_2) norm = 0.5 * norm * norm; 285c4762a1bSJed Brown *J = ctx->alpha * norm; 2863ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 287c4762a1bSJed Brown } 288c4762a1bSJed Brown 289c4762a1bSJed Brown /* NORM_2 Case: return x 290c4762a1bSJed Brown * NORM_1 Case: x/(|x| + eps) 291c4762a1bSJed Brown * Else: TODO */ 292d71ae5a4SJacob Faibussowitsch static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx) 293d71ae5a4SJacob Faibussowitsch { 294c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 295c4762a1bSJed Brown PetscReal eps = ctx->eps; 296c4762a1bSJed Brown 297c4762a1bSJed Brown PetscFunctionBegin; 298c4762a1bSJed Brown if (ctx->p == NORM_2) { 2999566063dSJacob Faibussowitsch PetscCall(VecCopy(x, V)); 300c4762a1bSJed Brown } else if (ctx->p == NORM_1) { 3019566063dSJacob Faibussowitsch PetscCall(VecCopy(x, ctx->workRight[1])); 3029566063dSJacob Faibussowitsch PetscCall(VecAbs(ctx->workRight[1])); 3039566063dSJacob Faibussowitsch PetscCall(VecShift(ctx->workRight[1], eps)); 3049566063dSJacob Faibussowitsch PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1])); 305c4762a1bSJed Brown } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 3063ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 307c4762a1bSJed Brown } 308c4762a1bSJed Brown 309c4762a1bSJed Brown /* NORM_2 Case: returns diag(mu) 310c4762a1bSJed Brown * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) */ 311d71ae5a4SJacob Faibussowitsch static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 312d71ae5a4SJacob Faibussowitsch { 313c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 314c4762a1bSJed Brown PetscReal eps = ctx->eps; 315c4762a1bSJed Brown Vec copy1, copy2, copy3; 316c4762a1bSJed Brown 317c4762a1bSJed Brown PetscFunctionBegin; 318c4762a1bSJed Brown if (ctx->p == NORM_2) { 319c4762a1bSJed Brown /* Identity matrix scaled by mu */ 3209566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(H)); 3219566063dSJacob Faibussowitsch PetscCall(MatShift(H, ctx->mu)); 322c4762a1bSJed Brown if (Hpre != H) { 3239566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(Hpre)); 3249566063dSJacob Faibussowitsch PetscCall(MatShift(Hpre, ctx->mu)); 325c4762a1bSJed Brown } 326c4762a1bSJed Brown } else if (ctx->p == NORM_1) { 327c4762a1bSJed Brown /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */ 328c4762a1bSJed Brown copy1 = ctx->workRight[1]; 329c4762a1bSJed Brown copy2 = ctx->workRight[2]; 330c4762a1bSJed Brown copy3 = ctx->workRight[3]; 331c4762a1bSJed Brown /* copy1 : 1/sqrt(x_i^2 + eps) */ 3329566063dSJacob Faibussowitsch PetscCall(VecCopy(x, copy1)); 3339566063dSJacob Faibussowitsch PetscCall(VecPow(copy1, 2)); 3349566063dSJacob Faibussowitsch PetscCall(VecShift(copy1, eps)); 3359566063dSJacob Faibussowitsch PetscCall(VecSqrtAbs(copy1)); 3369566063dSJacob Faibussowitsch PetscCall(VecReciprocal(copy1)); 337c4762a1bSJed Brown /* copy2: x_i^2.*/ 3389566063dSJacob Faibussowitsch PetscCall(VecCopy(x, copy2)); 3399566063dSJacob Faibussowitsch PetscCall(VecPow(copy2, 2)); 340c4762a1bSJed Brown /* copy3: abs(x_i^2 + eps) */ 3419566063dSJacob Faibussowitsch PetscCall(VecCopy(x, copy3)); 3429566063dSJacob Faibussowitsch PetscCall(VecPow(copy3, 2)); 3439566063dSJacob Faibussowitsch PetscCall(VecShift(copy3, eps)); 3449566063dSJacob Faibussowitsch PetscCall(VecAbs(copy3)); 345c4762a1bSJed Brown /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */ 3469566063dSJacob Faibussowitsch PetscCall(VecPointwiseDivide(copy2, copy2, copy3)); 3479566063dSJacob Faibussowitsch PetscCall(VecScale(copy2, -1.)); 3489566063dSJacob Faibussowitsch PetscCall(VecShift(copy2, 1.)); 3499566063dSJacob Faibussowitsch PetscCall(VecAXPY(copy1, 1., copy2)); 3509566063dSJacob Faibussowitsch PetscCall(VecScale(copy1, ctx->mu)); 3519566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(H)); 3529566063dSJacob Faibussowitsch PetscCall(MatDiagonalSet(H, copy1, INSERT_VALUES)); 353c4762a1bSJed Brown if (Hpre != H) { 3549566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(Hpre)); 3559566063dSJacob Faibussowitsch PetscCall(MatDiagonalSet(Hpre, copy1, INSERT_VALUES)); 356c4762a1bSJed Brown } 357c4762a1bSJed Brown } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 3583ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 359c4762a1bSJed Brown } 360c4762a1bSJed Brown 361c4762a1bSJed Brown /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2 362c4762a1bSJed Brown * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */ 363d71ae5a4SJacob Faibussowitsch static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx) 364d71ae5a4SJacob Faibussowitsch { 365c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 366c4762a1bSJed Brown PetscReal mu, workNorm, reg; 367c4762a1bSJed Brown Vec x, u, temp; 368c4762a1bSJed Brown 369c4762a1bSJed Brown PetscFunctionBegin; 370c4762a1bSJed Brown mu = ctx->mu; 371c4762a1bSJed Brown x = ctx->workRight[4]; 372c4762a1bSJed Brown u = ctx->workRight[6]; 373c4762a1bSJed Brown temp = ctx->workRight[10]; 3749566063dSJacob Faibussowitsch PetscCall(ObjectiveRegularization(tao, z, ®, _ctx)); 3759566063dSJacob Faibussowitsch PetscCall(VecCopy(z, temp)); 376c4762a1bSJed Brown /* temp = x + u -z */ 3779566063dSJacob Faibussowitsch PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u)); 378c4762a1bSJed Brown /* workNorm = ||x + u - z ||^2 */ 3799566063dSJacob Faibussowitsch PetscCall(VecDot(temp, temp, &workNorm)); 380c4762a1bSJed Brown *J = reg + 0.5 * mu * workNorm; 3813ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 382c4762a1bSJed Brown } 383c4762a1bSJed Brown 384c4762a1bSJed Brown /* NORM_2 Case: x - mu*(x + u - z) 385c4762a1bSJed Brown * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z) 386c4762a1bSJed Brown * Else: TODO */ 387d71ae5a4SJacob Faibussowitsch static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx) 388d71ae5a4SJacob Faibussowitsch { 389c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 390c4762a1bSJed Brown PetscReal mu; 391c4762a1bSJed Brown Vec x, u, temp; 392c4762a1bSJed Brown 393c4762a1bSJed Brown PetscFunctionBegin; 394c4762a1bSJed Brown mu = ctx->mu; 395c4762a1bSJed Brown x = ctx->workRight[4]; 396c4762a1bSJed Brown u = ctx->workRight[6]; 397c4762a1bSJed Brown temp = ctx->workRight[10]; 3989566063dSJacob Faibussowitsch PetscCall(GradientRegularization(tao, z, V, _ctx)); 3999566063dSJacob Faibussowitsch PetscCall(VecCopy(z, temp)); 400c4762a1bSJed Brown /* temp = x + u -z */ 4019566063dSJacob Faibussowitsch PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u)); 4029566063dSJacob Faibussowitsch PetscCall(VecAXPY(V, -mu, temp)); 4033ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 404c4762a1bSJed Brown } 405c4762a1bSJed Brown 406c4762a1bSJed Brown /* NORM_2 Case: returns diag(mu) 407c4762a1bSJed Brown * NORM_1 Case: FTF + diag(mu) */ 408d71ae5a4SJacob Faibussowitsch static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) 409d71ae5a4SJacob Faibussowitsch { 410c4762a1bSJed Brown UserCtx ctx = (UserCtx)_ctx; 411c4762a1bSJed Brown 412c4762a1bSJed Brown PetscFunctionBegin; 413c4762a1bSJed Brown if (ctx->p == NORM_2) { 414c4762a1bSJed Brown /* Identity matrix scaled by mu */ 4159566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(H)); 4169566063dSJacob Faibussowitsch PetscCall(MatShift(H, ctx->mu)); 417c4762a1bSJed Brown if (Hpre != H) { 4189566063dSJacob Faibussowitsch PetscCall(MatZeroEntries(Hpre)); 4199566063dSJacob Faibussowitsch PetscCall(MatShift(Hpre, ctx->mu)); 420c4762a1bSJed Brown } 421c4762a1bSJed Brown } else if (ctx->p == NORM_1) { 4229566063dSJacob Faibussowitsch PetscCall(HessianMisfit(tao, x, H, Hpre, (void *)ctx)); 4239566063dSJacob Faibussowitsch PetscCall(MatShift(H, ctx->mu)); 4249566063dSJacob Faibussowitsch if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu)); 425c4762a1bSJed Brown } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2"); 4263ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 427c4762a1bSJed Brown } 428c4762a1bSJed Brown 429c4762a1bSJed Brown /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p 430c4762a1bSJed Brown * NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */ 431d71ae5a4SJacob Faibussowitsch static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, void *ctx) 432d71ae5a4SJacob Faibussowitsch { 433c4762a1bSJed Brown PetscReal Jm, Jr; 434c4762a1bSJed Brown 435c4762a1bSJed Brown PetscFunctionBegin; 4369566063dSJacob Faibussowitsch PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx)); 4379566063dSJacob Faibussowitsch PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx)); 438c4762a1bSJed Brown *J = Jm + Jr; 4393ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 440c4762a1bSJed Brown } 441c4762a1bSJed Brown 442c4762a1bSJed Brown /* NORM_2 Case: FTFx - FTd + x 443c4762a1bSJed Brown * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */ 444d71ae5a4SJacob Faibussowitsch static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, void *ctx) 445d71ae5a4SJacob Faibussowitsch { 446c4762a1bSJed Brown UserCtx cntx = (UserCtx)ctx; 447c4762a1bSJed Brown 448c4762a1bSJed Brown PetscFunctionBegin; 4499566063dSJacob Faibussowitsch PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx)); 4509566063dSJacob Faibussowitsch PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx)); 4519566063dSJacob Faibussowitsch PetscCall(VecWAXPY(V, 1, cntx->workRight[2], cntx->workRight[3])); 4523ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 453c4762a1bSJed Brown } 454c4762a1bSJed Brown 455c4762a1bSJed Brown /* NORM_2 Case: diag(mu) + FTF 456c4762a1bSJed Brown * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF */ 457d71ae5a4SJacob Faibussowitsch static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx) 458d71ae5a4SJacob Faibussowitsch { 459c4762a1bSJed Brown Mat tempH; 460c4762a1bSJed Brown 461c4762a1bSJed Brown PetscFunctionBegin; 4629566063dSJacob Faibussowitsch PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH)); 4639566063dSJacob Faibussowitsch PetscCall(HessianMisfit(tao, x, H, H, ctx)); 4649566063dSJacob Faibussowitsch PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx)); 4659566063dSJacob Faibussowitsch PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN)); 46648a46eb9SPierre Jolivet if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN)); 4679566063dSJacob Faibussowitsch PetscCall(MatDestroy(&tempH)); 4683ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 469c4762a1bSJed Brown } 470c4762a1bSJed Brown 471d71ae5a4SJacob Faibussowitsch static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x) 472d71ae5a4SJacob Faibussowitsch { 473c4762a1bSJed Brown PetscInt i; 474c4762a1bSJed Brown PetscReal u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm; 475c4762a1bSJed Brown Tao tao1, tao2; 476c4762a1bSJed Brown Vec xk, z, u, diff, zold, zdiff, temp; 477c4762a1bSJed Brown PetscReal mu; 478c4762a1bSJed Brown 479c4762a1bSJed Brown PetscFunctionBegin; 480c4762a1bSJed Brown xk = ctx->workRight[4]; 481c4762a1bSJed Brown z = ctx->workRight[5]; 482c4762a1bSJed Brown u = ctx->workRight[6]; 483c4762a1bSJed Brown diff = ctx->workRight[7]; 484c4762a1bSJed Brown zold = ctx->workRight[8]; 485c4762a1bSJed Brown zdiff = ctx->workRight[9]; 486c4762a1bSJed Brown temp = ctx->workRight[11]; 487c4762a1bSJed Brown mu = ctx->mu; 4889566063dSJacob Faibussowitsch PetscCall(VecSet(u, 0.)); 4899566063dSJacob Faibussowitsch PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1)); 4909566063dSJacob Faibussowitsch PetscCall(TaoSetType(tao1, TAONLS)); 4919566063dSJacob Faibussowitsch PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void *)ctx)); 4929566063dSJacob Faibussowitsch PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void *)ctx)); 4939566063dSJacob Faibussowitsch PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void *)ctx)); 4949566063dSJacob Faibussowitsch PetscCall(VecSet(xk, 0.)); 4959566063dSJacob Faibussowitsch PetscCall(TaoSetSolution(tao1, xk)); 4969566063dSJacob Faibussowitsch PetscCall(TaoSetOptionsPrefix(tao1, "misfit_")); 4979566063dSJacob Faibussowitsch PetscCall(TaoSetFromOptions(tao1)); 4989566063dSJacob Faibussowitsch PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2)); 499c4762a1bSJed Brown if (ctx->p == NORM_2) { 5009566063dSJacob Faibussowitsch PetscCall(TaoSetType(tao2, TAONLS)); 5019566063dSJacob Faibussowitsch PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void *)ctx)); 5029566063dSJacob Faibussowitsch PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void *)ctx)); 5039566063dSJacob Faibussowitsch PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void *)ctx)); 504c4762a1bSJed Brown } 5059566063dSJacob Faibussowitsch PetscCall(VecSet(z, 0.)); 5069566063dSJacob Faibussowitsch PetscCall(TaoSetSolution(tao2, z)); 5079566063dSJacob Faibussowitsch PetscCall(TaoSetOptionsPrefix(tao2, "reg_")); 5089566063dSJacob Faibussowitsch PetscCall(TaoSetFromOptions(tao2)); 509c4762a1bSJed Brown 510c4762a1bSJed Brown for (i = 0; i < ctx->iter; i++) { 5119566063dSJacob Faibussowitsch PetscCall(VecCopy(z, zold)); 5129566063dSJacob Faibussowitsch PetscCall(TaoSolve(tao1)); /* Updates xk */ 513c4762a1bSJed Brown if (ctx->p == NORM_1) { 5149566063dSJacob Faibussowitsch PetscCall(VecWAXPY(temp, 1., xk, u)); 5159566063dSJacob Faibussowitsch PetscCall(TaoSoftThreshold(temp, -ctx->alpha / mu, ctx->alpha / mu, z)); 516c4762a1bSJed Brown } else { 5179566063dSJacob Faibussowitsch PetscCall(TaoSolve(tao2)); /* Update zk */ 518c4762a1bSJed Brown } 519c4762a1bSJed Brown /* u = u + xk -z */ 5209566063dSJacob Faibussowitsch PetscCall(VecAXPBYPCZ(u, 1., -1., 1., xk, z)); 521c4762a1bSJed Brown /* r_norm : norm(x-z) */ 5229566063dSJacob Faibussowitsch PetscCall(VecWAXPY(diff, -1., z, xk)); 5239566063dSJacob Faibussowitsch PetscCall(VecNorm(diff, NORM_2, &r_norm)); 524c4762a1bSJed Brown /* s_norm : norm(-mu(z-zold)) */ 5259566063dSJacob Faibussowitsch PetscCall(VecWAXPY(zdiff, -1., zold, z)); 5269566063dSJacob Faibussowitsch PetscCall(VecNorm(zdiff, NORM_2, &s_norm)); 527c4762a1bSJed Brown s_norm = s_norm * mu; 528c4762a1bSJed Brown /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/ 5299566063dSJacob Faibussowitsch PetscCall(VecNorm(xk, NORM_2, &x_norm)); 5309566063dSJacob Faibussowitsch PetscCall(VecNorm(z, NORM_2, &z_norm)); 531c4762a1bSJed Brown primal = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * PetscMax(x_norm, z_norm); 532c4762a1bSJed Brown /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/ 5339566063dSJacob Faibussowitsch PetscCall(VecNorm(u, NORM_2, &u_norm)); 534c4762a1bSJed Brown dual = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * u_norm * mu; 53563a3b9bcSJacob Faibussowitsch PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1), "Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double)r_norm, (double)s_norm)); 536c4762a1bSJed Brown if (r_norm < primal && s_norm < dual) break; 537c4762a1bSJed Brown } 5389566063dSJacob Faibussowitsch PetscCall(VecCopy(xk, x)); 5399566063dSJacob Faibussowitsch PetscCall(TaoDestroy(&tao1)); 5409566063dSJacob Faibussowitsch PetscCall(TaoDestroy(&tao2)); 5413ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 542c4762a1bSJed Brown } 543c4762a1bSJed Brown 544c4762a1bSJed Brown /* Second order Taylor remainder convergence test */ 545d71ae5a4SJacob Faibussowitsch static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C) 546d71ae5a4SJacob Faibussowitsch { 547c4762a1bSJed Brown PetscReal h, J, temp; 548c4762a1bSJed Brown PetscInt i, j; 549c4762a1bSJed Brown PetscInt numValues; 550c4762a1bSJed Brown PetscReal Jx, Jxhat_comp, Jxhat_pred; 551c4762a1bSJed Brown PetscReal *Js, *hs; 552c4762a1bSJed Brown PetscReal gdotdx; 553c4762a1bSJed Brown PetscReal minrate = PETSC_MAX_REAL; 554c4762a1bSJed Brown MPI_Comm comm = PetscObjectComm((PetscObject)x); 555c4762a1bSJed Brown Vec g, dx, xhat; 556c4762a1bSJed Brown 557c4762a1bSJed Brown PetscFunctionBegin; 5589566063dSJacob Faibussowitsch PetscCall(VecDuplicate(x, &g)); 5599566063dSJacob Faibussowitsch PetscCall(VecDuplicate(x, &xhat)); 560c4762a1bSJed Brown /* choose a perturbation direction */ 5619566063dSJacob Faibussowitsch PetscCall(VecDuplicate(x, &dx)); 5629566063dSJacob Faibussowitsch PetscCall(VecSetRandom(dx, ctx->rctx)); 563c4762a1bSJed Brown /* evaluate objective at x: J(x) */ 5649566063dSJacob Faibussowitsch PetscCall(TaoComputeObjective(tao, x, &Jx)); 565c4762a1bSJed Brown /* evaluate gradient at x, save in vector g */ 5669566063dSJacob Faibussowitsch PetscCall(TaoComputeGradient(tao, x, g)); 5679566063dSJacob Faibussowitsch PetscCall(VecDot(g, dx, &gdotdx)); 568c4762a1bSJed Brown 569c4762a1bSJed Brown for (numValues = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor) numValues++; 5709566063dSJacob Faibussowitsch PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs)); 571c4762a1bSJed Brown for (i = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor, i++) { 5729566063dSJacob Faibussowitsch PetscCall(VecWAXPY(xhat, h, dx, x)); 5739566063dSJacob Faibussowitsch PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp)); 574c4762a1bSJed Brown /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */ 575c4762a1bSJed Brown Jxhat_pred = Jx + h * gdotdx; 576c4762a1bSJed Brown /* Vector to dJdm scalar? Dot?*/ 577c4762a1bSJed Brown J = PetscAbsReal(Jxhat_comp - Jxhat_pred); 5789566063dSJacob Faibussowitsch PetscCall(PetscPrintf(comm, "J(xhat): %g, predicted: %g, diff %g\n", (double)Jxhat_comp, (double)Jxhat_pred, (double)J)); 579c4762a1bSJed Brown Js[i] = J; 580c4762a1bSJed Brown hs[i] = h; 581c4762a1bSJed Brown } 582c4762a1bSJed Brown for (j = 1; j < numValues; j++) { 583c4762a1bSJed Brown temp = PetscLogReal(Js[j] / Js[j - 1]) / PetscLogReal(hs[j] / hs[j - 1]); 58463a3b9bcSJacob Faibussowitsch PetscCall(PetscPrintf(comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j - 1, (double)temp)); 585c4762a1bSJed Brown minrate = PetscMin(minrate, temp); 586c4762a1bSJed Brown } 587c4762a1bSJed Brown /* If O is not ~2, then the test is wrong */ 5889566063dSJacob Faibussowitsch PetscCall(PetscFree2(Js, hs)); 589c4762a1bSJed Brown *C = minrate; 5909566063dSJacob Faibussowitsch PetscCall(VecDestroy(&dx)); 5919566063dSJacob Faibussowitsch PetscCall(VecDestroy(&xhat)); 5929566063dSJacob Faibussowitsch PetscCall(VecDestroy(&g)); 5933ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 594c4762a1bSJed Brown } 595c4762a1bSJed Brown 596d71ae5a4SJacob Faibussowitsch int main(int argc, char **argv) 597d71ae5a4SJacob Faibussowitsch { 598c4762a1bSJed Brown UserCtx ctx; 599c4762a1bSJed Brown Tao tao; 600c4762a1bSJed Brown Vec x; 601c4762a1bSJed Brown Mat H; 602c4762a1bSJed Brown 603327415f7SBarry Smith PetscFunctionBeginUser; 6049566063dSJacob Faibussowitsch PetscCall(PetscInitialize(&argc, &argv, NULL, help)); 6059566063dSJacob Faibussowitsch PetscCall(PetscNew(&ctx)); 6069566063dSJacob Faibussowitsch PetscCall(ConfigureContext(ctx)); 607a82e8c82SStefano Zampini /* Define two functions that could pass as objectives to TaoSetObjective(): one 608c4762a1bSJed Brown * for the misfit component, and one for the regularization component */ 609c4762a1bSJed Brown /* ObjectiveMisfit() and ObjectiveRegularization() */ 610c4762a1bSJed Brown 611c4762a1bSJed Brown /* Define a single function that calls both components adds them together: the complete objective, 612c4762a1bSJed Brown * in the absence of a Tao implementation that handles separability */ 613c4762a1bSJed Brown /* ObjectiveComplete() */ 6149566063dSJacob Faibussowitsch PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao)); 6159566063dSJacob Faibussowitsch PetscCall(TaoSetType(tao, TAONM)); 6169566063dSJacob Faibussowitsch PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void *)ctx)); 6179566063dSJacob Faibussowitsch PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void *)ctx)); 6189566063dSJacob Faibussowitsch PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H)); 6199566063dSJacob Faibussowitsch PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void *)ctx)); 6209566063dSJacob Faibussowitsch PetscCall(MatCreateVecs(ctx->F, NULL, &x)); 6219566063dSJacob Faibussowitsch PetscCall(VecSet(x, 0.)); 6229566063dSJacob Faibussowitsch PetscCall(TaoSetSolution(tao, x)); 6239566063dSJacob Faibussowitsch PetscCall(TaoSetFromOptions(tao)); 6241baa6e33SBarry Smith if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx, x)); 6251baa6e33SBarry Smith else PetscCall(TaoSolve(tao)); 626c4762a1bSJed Brown /* examine solution */ 6279566063dSJacob Faibussowitsch PetscCall(VecViewFromOptions(x, NULL, "-view_sol")); 628c4762a1bSJed Brown if (ctx->taylor) { 629c4762a1bSJed Brown PetscReal rate; 6309566063dSJacob Faibussowitsch PetscCall(TaylorTest(ctx, tao, x, &rate)); 631c4762a1bSJed Brown } 63284430a0dSHansol Suh if (ctx->soft) { PetscCall(TaoSoftThreshold(x, 0., 0., x)); } 6339566063dSJacob Faibussowitsch PetscCall(MatDestroy(&H)); 6349566063dSJacob Faibussowitsch PetscCall(TaoDestroy(&tao)); 6359566063dSJacob Faibussowitsch PetscCall(VecDestroy(&x)); 6369566063dSJacob Faibussowitsch PetscCall(DestroyContext(&ctx)); 6379566063dSJacob Faibussowitsch PetscCall(PetscFinalize()); 638b122ec5aSJacob Faibussowitsch return 0; 639c4762a1bSJed Brown } 640c4762a1bSJed Brown 641c4762a1bSJed Brown /*TEST 642c4762a1bSJed Brown 643c4762a1bSJed Brown build: 644c4762a1bSJed Brown requires: !complex 645c4762a1bSJed Brown 646c4762a1bSJed Brown test: 647c4762a1bSJed Brown suffix: 0 648c4762a1bSJed Brown args: 649c4762a1bSJed Brown 650c4762a1bSJed Brown test: 651c4762a1bSJed Brown suffix: l1_1 652c4762a1bSJed Brown args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1 653c4762a1bSJed Brown 654c4762a1bSJed Brown test: 655c4762a1bSJed Brown suffix: hessian_1 656c5f5e425SStefano Zampini args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls 657c4762a1bSJed Brown 658c4762a1bSJed Brown test: 659c4762a1bSJed Brown suffix: hessian_2 660c5f5e425SStefano Zampini args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls 661c4762a1bSJed Brown 662c4762a1bSJed Brown test: 663c4762a1bSJed Brown suffix: nm_1 664c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50 665c4762a1bSJed Brown 666c4762a1bSJed Brown test: 667c4762a1bSJed Brown suffix: nm_2 668c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50 669c4762a1bSJed Brown 670c4762a1bSJed Brown test: 671c4762a1bSJed Brown suffix: lmvm_1 672c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40 673c4762a1bSJed Brown 674c4762a1bSJed Brown test: 675c4762a1bSJed Brown suffix: lmvm_2 676c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15 677c4762a1bSJed Brown 678c4762a1bSJed Brown test: 679c4762a1bSJed Brown suffix: soft_threshold_admm_1 680c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm 681c4762a1bSJed Brown 682c4762a1bSJed Brown test: 683c4762a1bSJed Brown suffix: hessian_admm_1 684c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls 685c4762a1bSJed Brown 686c4762a1bSJed Brown test: 687c4762a1bSJed Brown suffix: hessian_admm_2 688c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls 689c4762a1bSJed Brown 690c4762a1bSJed Brown test: 691c4762a1bSJed Brown suffix: nm_admm_1 692c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm 693c4762a1bSJed Brown 694c4762a1bSJed Brown test: 695c4762a1bSJed Brown suffix: nm_admm_2 696c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7 697c4762a1bSJed Brown 698c4762a1bSJed Brown test: 699c4762a1bSJed Brown suffix: lmvm_admm_1 700c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm 701c4762a1bSJed Brown 702c4762a1bSJed Brown test: 703c4762a1bSJed Brown suffix: lmvm_admm_2 704c4762a1bSJed Brown args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm 705c4762a1bSJed Brown 70684430a0dSHansol Suh test: 70784430a0dSHansol Suh suffix: soft 70884430a0dSHansol Suh args: -taylor 0 -soft 1 70984430a0dSHansol Suh 710c4762a1bSJed Brown TEST*/ 711