1 #include <../src/tao/bound/impls/bqnk/bqnk.h> /*I "petsctao.h" I*/ 2 #include <petscksp.h> 3 4 static PetscErrorCode TaoBQNKComputeHessian(Tao tao) 5 { 6 TAO_BNK *bnk = (TAO_BNK *)tao->data; 7 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 8 PetscReal gnorm2, delta; 9 10 PetscFunctionBegin; 11 /* Alias the LMVM matrix into the TAO hessian */ 12 if (tao->hessian) PetscCall(MatDestroy(&tao->hessian)); 13 if (tao->hessian_pre) PetscCall(MatDestroy(&tao->hessian_pre)); 14 PetscCall(PetscObjectReference((PetscObject)bqnk->B)); 15 tao->hessian = bqnk->B; 16 PetscCall(PetscObjectReference((PetscObject)bqnk->B)); 17 tao->hessian_pre = bqnk->B; 18 /* Update the Hessian with the latest solution */ 19 if (bqnk->is_spd) { 20 gnorm2 = bnk->gnorm * bnk->gnorm; 21 if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON; 22 if (bnk->f == 0.0) { 23 delta = 2.0 / gnorm2; 24 } else { 25 delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2; 26 } 27 PetscCall(MatLMVMSymBroydenSetDelta(bqnk->B, delta)); 28 } 29 PetscCall(MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient)); 30 PetscCall(MatLMVMResetShift(tao->hessian)); 31 /* Prepare the reduced sub-matrices for the inactive set */ 32 PetscCall(MatDestroy(&bnk->H_inactive)); 33 if (bnk->active_idx) { 34 PetscCall(MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive)); 35 PetscCall(PCLMVMSetIS(bqnk->pc, bnk->inactive_idx)); 36 } else { 37 PetscCall(PetscObjectReference((PetscObject)tao->hessian)); 38 bnk->H_inactive = tao->hessian; 39 PetscCall(PCLMVMClearIS(bqnk->pc)); 40 } 41 PetscCall(MatDestroy(&bnk->Hpre_inactive)); 42 PetscCall(PetscObjectReference((PetscObject)bnk->H_inactive)); 43 bnk->Hpre_inactive = bnk->H_inactive; 44 PetscFunctionReturn(PETSC_SUCCESS); 45 } 46 47 static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type) 48 { 49 TAO_BNK *bnk = (TAO_BNK *)tao->data; 50 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 51 52 PetscFunctionBegin; 53 PetscCall(TaoBNKComputeStep(tao, shift, ksp_reason, step_type)); 54 if (*ksp_reason < 0) { 55 /* Krylov solver failed to converge so reset the LMVM matrix */ 56 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE)); 57 PetscCall(MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient)); 58 } 59 PetscFunctionReturn(PETSC_SUCCESS); 60 } 61 62 PetscErrorCode TaoSolve_BQNK(Tao tao) 63 { 64 TAO_BNK *bnk = (TAO_BNK *)tao->data; 65 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 66 Mat_LMVM *lmvm = (Mat_LMVM *)bqnk->B->data; 67 Mat_LMVM *J0; 68 PetscBool flg = PETSC_FALSE; 69 70 PetscFunctionBegin; 71 if (!tao->recycle) { 72 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE)); 73 lmvm->nresets = 0; 74 if (lmvm->J0) { 75 PetscCall(PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg)); 76 if (flg) { 77 J0 = (Mat_LMVM *)lmvm->J0->data; 78 J0->nresets = 0; 79 } 80 } 81 } 82 PetscCall((*bqnk->solve)(tao)); 83 PetscFunctionReturn(PETSC_SUCCESS); 84 } 85 86 PetscErrorCode TaoSetUp_BQNK(Tao tao) 87 { 88 TAO_BNK *bnk = (TAO_BNK *)tao->data; 89 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 90 PetscInt n, N; 91 PetscBool is_lmvm, is_set, is_sym; 92 93 PetscFunctionBegin; 94 PetscCall(TaoSetUp_BNK(tao)); 95 PetscCall(VecGetLocalSize(tao->solution, &n)); 96 PetscCall(VecGetSize(tao->solution, &N)); 97 PetscCall(MatSetSizes(bqnk->B, n, n, N, N)); 98 PetscCall(MatLMVMAllocate(bqnk->B, tao->solution, bnk->unprojected_gradient)); 99 PetscCall(PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm)); 100 PetscCheck(is_lmvm, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Matrix must be an LMVM-type"); 101 PetscCall(MatIsSymmetricKnown(bqnk->B, &is_set, &is_sym)); 102 PetscCheck(is_set && is_sym, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM matrix must be symmetric"); 103 PetscCall(KSPGetPC(tao->ksp, &bqnk->pc)); 104 PetscCall(PCSetType(bqnk->pc, PCLMVM)); 105 PetscCall(PCLMVMSetMatLMVM(bqnk->pc, bqnk->B)); 106 PetscFunctionReturn(PETSC_SUCCESS); 107 } 108 109 static PetscErrorCode TaoSetFromOptions_BQNK(Tao tao, PetscOptionItems PetscOptionsObject) 110 { 111 TAO_BNK *bnk = (TAO_BNK *)tao->data; 112 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 113 PetscBool is_set; 114 115 PetscFunctionBegin; 116 PetscCall(TaoSetFromOptions_BNK(tao, PetscOptionsObject)); 117 if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION; 118 PetscCall(MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix)); 119 PetscCall(MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_")); 120 PetscCall(MatSetFromOptions(bqnk->B)); 121 PetscCall(MatIsSPDKnown(bqnk->B, &is_set, &bqnk->is_spd)); 122 if (!is_set) bqnk->is_spd = PETSC_FALSE; 123 PetscFunctionReturn(PETSC_SUCCESS); 124 } 125 126 static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer) 127 { 128 TAO_BNK *bnk = (TAO_BNK *)tao->data; 129 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 130 PetscBool isascii; 131 132 PetscFunctionBegin; 133 PetscCall(TaoView_BNK(tao, viewer)); 134 PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii)); 135 if (isascii) { 136 PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO)); 137 PetscCall(MatView(bqnk->B, viewer)); 138 PetscCall(PetscViewerPopFormat(viewer)); 139 } 140 PetscFunctionReturn(PETSC_SUCCESS); 141 } 142 143 static PetscErrorCode TaoDestroy_BQNK(Tao tao) 144 { 145 TAO_BNK *bnk = (TAO_BNK *)tao->data; 146 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 147 148 PetscFunctionBegin; 149 PetscCall(MatDestroy(&bnk->Hpre_inactive)); 150 PetscCall(MatDestroy(&bnk->H_inactive)); 151 PetscCall(MatDestroy(&bqnk->B)); 152 PetscCall(PetscFree(bnk->ctx)); 153 PetscCall(TaoDestroy_BNK(tao)); 154 PetscFunctionReturn(PETSC_SUCCESS); 155 } 156 157 PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao) 158 { 159 TAO_BNK *bnk; 160 TAO_BQNK *bqnk; 161 162 PetscFunctionBegin; 163 PetscCall(TaoCreate_BNK(tao)); 164 tao->ops->solve = TaoSolve_BQNK; 165 tao->ops->setfromoptions = TaoSetFromOptions_BQNK; 166 tao->ops->destroy = TaoDestroy_BQNK; 167 tao->ops->view = TaoView_BQNK; 168 tao->ops->setup = TaoSetUp_BQNK; 169 170 bnk = (TAO_BNK *)tao->data; 171 bnk->computehessian = TaoBQNKComputeHessian; 172 bnk->computestep = TaoBQNKComputeStep; 173 bnk->init_type = BNK_INIT_DIRECTION; 174 175 PetscCall(PetscNew(&bqnk)); 176 bnk->ctx = (void *)bqnk; 177 bqnk->is_spd = PETSC_TRUE; 178 179 PetscCall(MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B)); 180 PetscCall(PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1)); 181 PetscCall(MatSetType(bqnk->B, MATLMVMSR1)); 182 PetscFunctionReturn(PETSC_SUCCESS); 183 } 184 185 /*@ 186 TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid 187 only for quasi-Newton family of methods. 188 189 Input Parameter: 190 . tao - `Tao` solver context 191 192 Output Parameter: 193 . B - LMVM matrix 194 195 Level: advanced 196 197 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoSetLMVMMatrix()` 198 @*/ 199 PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B) 200 { 201 TAO_BNK *bnk = (TAO_BNK *)tao->data; 202 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 203 PetscBool flg = PETSC_FALSE; 204 205 PetscFunctionBegin; 206 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "")); 207 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms"); 208 *B = bqnk->B; 209 PetscFunctionReturn(PETSC_SUCCESS); 210 } 211 212 /*@ 213 TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid 214 only for quasi-Newton family of methods. 215 216 QN family of methods create their own LMVM matrices and users who wish to 217 manipulate this matrix should use TaoGetLMVMMatrix() instead. 218 219 Input Parameters: 220 + tao - Tao solver context 221 - B - LMVM matrix 222 223 Level: advanced 224 225 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoGetLMVMMatrix()` 226 @*/ 227 PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B) 228 { 229 TAO_BNK *bnk = (TAO_BNK *)tao->data; 230 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 231 PetscBool flg = PETSC_FALSE; 232 233 PetscFunctionBegin; 234 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "")); 235 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms"); 236 PetscCall(PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg)); 237 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Given matrix is not an LMVM matrix"); 238 if (bqnk->B) PetscCall(MatDestroy(&bqnk->B)); 239 PetscCall(PetscObjectReference((PetscObject)B)); 240 bqnk->B = B; 241 PetscFunctionReturn(PETSC_SUCCESS); 242 } 243