1 #include <../src/tao/bound/impls/bqnk/bqnk.h> /*I "petsctao.h" I*/
2 #include <petscksp.h>
3
TaoBQNKComputeHessian(Tao tao)4 static PetscErrorCode TaoBQNKComputeHessian(Tao tao)
5 {
6 TAO_BNK *bnk = (TAO_BNK *)tao->data;
7 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
8 PetscReal gnorm2, delta;
9
10 PetscFunctionBegin;
11 /* Alias the LMVM matrix into the TAO hessian */
12 if (tao->hessian) PetscCall(MatDestroy(&tao->hessian));
13 if (tao->hessian_pre) PetscCall(MatDestroy(&tao->hessian_pre));
14 PetscCall(PetscObjectReference((PetscObject)bqnk->B));
15 tao->hessian = bqnk->B;
16 PetscCall(PetscObjectReference((PetscObject)bqnk->B));
17 tao->hessian_pre = bqnk->B;
18 /* Update the Hessian with the latest solution */
19 if (bqnk->is_spd) {
20 gnorm2 = bnk->gnorm * bnk->gnorm;
21 if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON;
22 if (bnk->f == 0.0) {
23 delta = 2.0 / gnorm2;
24 } else {
25 delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2;
26 }
27 PetscCall(MatLMVMSymBroydenSetDelta(bqnk->B, delta));
28 }
29 PetscCall(MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient));
30 PetscCall(MatLMVMResetShift(tao->hessian));
31 /* Prepare the reduced sub-matrices for the inactive set */
32 PetscCall(MatDestroy(&bnk->H_inactive));
33 if (bnk->active_idx) {
34 PetscCall(MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive));
35 PetscCall(PCLMVMSetIS(bqnk->pc, bnk->inactive_idx));
36 } else {
37 PetscCall(PetscObjectReference((PetscObject)tao->hessian));
38 bnk->H_inactive = tao->hessian;
39 PetscCall(PCLMVMClearIS(bqnk->pc));
40 }
41 PetscCall(MatDestroy(&bnk->Hpre_inactive));
42 PetscCall(PetscObjectReference((PetscObject)bnk->H_inactive));
43 bnk->Hpre_inactive = bnk->H_inactive;
44 PetscFunctionReturn(PETSC_SUCCESS);
45 }
46
TaoBQNKComputeStep(Tao tao,PetscBool shift,KSPConvergedReason * ksp_reason,PetscInt * step_type)47 static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type)
48 {
49 TAO_BNK *bnk = (TAO_BNK *)tao->data;
50 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
51
52 PetscFunctionBegin;
53 PetscCall(TaoBNKComputeStep(tao, shift, ksp_reason, step_type));
54 if (*ksp_reason < 0) {
55 /* Krylov solver failed to converge so reset the LMVM matrix */
56 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
57 PetscCall(MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient));
58 }
59 PetscFunctionReturn(PETSC_SUCCESS);
60 }
61
TaoSolve_BQNK(Tao tao)62 PetscErrorCode TaoSolve_BQNK(Tao tao)
63 {
64 TAO_BNK *bnk = (TAO_BNK *)tao->data;
65 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
66 Mat_LMVM *lmvm = (Mat_LMVM *)bqnk->B->data;
67 Mat_LMVM *J0;
68 PetscBool flg = PETSC_FALSE;
69
70 PetscFunctionBegin;
71 if (!tao->recycle) {
72 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
73 lmvm->nresets = 0;
74 if (lmvm->J0) {
75 PetscCall(PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg));
76 if (flg) {
77 J0 = (Mat_LMVM *)lmvm->J0->data;
78 J0->nresets = 0;
79 }
80 }
81 }
82 PetscCall((*bqnk->solve)(tao));
83 PetscFunctionReturn(PETSC_SUCCESS);
84 }
85
TaoSetUp_BQNK(Tao tao)86 PetscErrorCode TaoSetUp_BQNK(Tao tao)
87 {
88 TAO_BNK *bnk = (TAO_BNK *)tao->data;
89 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
90 PetscInt n, N;
91 PetscBool is_lmvm, is_set, is_sym;
92
93 PetscFunctionBegin;
94 PetscCall(TaoSetUp_BNK(tao));
95 PetscCall(VecGetLocalSize(tao->solution, &n));
96 PetscCall(VecGetSize(tao->solution, &N));
97 PetscCall(MatSetSizes(bqnk->B, n, n, N, N));
98 PetscCall(MatLMVMAllocate(bqnk->B, tao->solution, bnk->unprojected_gradient));
99 PetscCall(PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm));
100 PetscCheck(is_lmvm, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Matrix must be an LMVM-type");
101 PetscCall(MatIsSymmetricKnown(bqnk->B, &is_set, &is_sym));
102 PetscCheck(is_set && is_sym, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM matrix must be symmetric");
103 PetscCall(KSPGetPC(tao->ksp, &bqnk->pc));
104 PetscCall(PCSetType(bqnk->pc, PCLMVM));
105 PetscCall(PCLMVMSetMatLMVM(bqnk->pc, bqnk->B));
106 PetscFunctionReturn(PETSC_SUCCESS);
107 }
108
TaoSetFromOptions_BQNK(Tao tao,PetscOptionItems PetscOptionsObject)109 static PetscErrorCode TaoSetFromOptions_BQNK(Tao tao, PetscOptionItems PetscOptionsObject)
110 {
111 TAO_BNK *bnk = (TAO_BNK *)tao->data;
112 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
113 PetscBool is_set;
114
115 PetscFunctionBegin;
116 PetscCall(TaoSetFromOptions_BNK(tao, PetscOptionsObject));
117 if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION;
118 PetscCall(MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix));
119 PetscCall(MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_"));
120 PetscCall(MatSetFromOptions(bqnk->B));
121 PetscCall(MatIsSPDKnown(bqnk->B, &is_set, &bqnk->is_spd));
122 if (!is_set) bqnk->is_spd = PETSC_FALSE;
123 PetscFunctionReturn(PETSC_SUCCESS);
124 }
125
TaoView_BQNK(Tao tao,PetscViewer viewer)126 static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer)
127 {
128 TAO_BNK *bnk = (TAO_BNK *)tao->data;
129 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
130 PetscBool isascii;
131
132 PetscFunctionBegin;
133 PetscCall(TaoView_BNK(tao, viewer));
134 PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
135 if (isascii) {
136 PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO));
137 PetscCall(MatView(bqnk->B, viewer));
138 PetscCall(PetscViewerPopFormat(viewer));
139 }
140 PetscFunctionReturn(PETSC_SUCCESS);
141 }
142
TaoDestroy_BQNK(Tao tao)143 static PetscErrorCode TaoDestroy_BQNK(Tao tao)
144 {
145 TAO_BNK *bnk = (TAO_BNK *)tao->data;
146 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
147
148 PetscFunctionBegin;
149 PetscCall(MatDestroy(&bnk->Hpre_inactive));
150 PetscCall(MatDestroy(&bnk->H_inactive));
151 PetscCall(MatDestroy(&bqnk->B));
152 PetscCall(PetscFree(bnk->ctx));
153 PetscCall(TaoDestroy_BNK(tao));
154 PetscFunctionReturn(PETSC_SUCCESS);
155 }
156
TaoCreate_BQNK(Tao tao)157 PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao)
158 {
159 TAO_BNK *bnk;
160 TAO_BQNK *bqnk;
161
162 PetscFunctionBegin;
163 PetscCall(TaoCreate_BNK(tao));
164 tao->ops->solve = TaoSolve_BQNK;
165 tao->ops->setfromoptions = TaoSetFromOptions_BQNK;
166 tao->ops->destroy = TaoDestroy_BQNK;
167 tao->ops->view = TaoView_BQNK;
168 tao->ops->setup = TaoSetUp_BQNK;
169
170 bnk = (TAO_BNK *)tao->data;
171 bnk->computehessian = TaoBQNKComputeHessian;
172 bnk->computestep = TaoBQNKComputeStep;
173 bnk->init_type = BNK_INIT_DIRECTION;
174
175 PetscCall(PetscNew(&bqnk));
176 bnk->ctx = (void *)bqnk;
177 bqnk->is_spd = PETSC_TRUE;
178
179 PetscCall(MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B));
180 PetscCall(PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1));
181 PetscCall(MatSetType(bqnk->B, MATLMVMSR1));
182 PetscFunctionReturn(PETSC_SUCCESS);
183 }
184
185 /*@
186 TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid
187 only for quasi-Newton family of methods.
188
189 Input Parameter:
190 . tao - `Tao` solver context
191
192 Output Parameter:
193 . B - LMVM matrix
194
195 Level: advanced
196
197 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoSetLMVMMatrix()`
198 @*/
TaoGetLMVMMatrix(Tao tao,Mat * B)199 PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B)
200 {
201 TAO_BNK *bnk = (TAO_BNK *)tao->data;
202 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
203 PetscBool flg = PETSC_FALSE;
204
205 PetscFunctionBegin;
206 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
207 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
208 *B = bqnk->B;
209 PetscFunctionReturn(PETSC_SUCCESS);
210 }
211
212 /*@
213 TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid
214 only for quasi-Newton family of methods.
215
216 QN family of methods create their own LMVM matrices and users who wish to
217 manipulate this matrix should use TaoGetLMVMMatrix() instead.
218
219 Input Parameters:
220 + tao - Tao solver context
221 - B - LMVM matrix
222
223 Level: advanced
224
225 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoGetLMVMMatrix()`
226 @*/
TaoSetLMVMMatrix(Tao tao,Mat B)227 PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B)
228 {
229 TAO_BNK *bnk = (TAO_BNK *)tao->data;
230 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
231 PetscBool flg = PETSC_FALSE;
232
233 PetscFunctionBegin;
234 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
235 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
236 PetscCall(PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg));
237 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Given matrix is not an LMVM matrix");
238 if (bqnk->B) PetscCall(MatDestroy(&bqnk->B));
239 PetscCall(PetscObjectReference((PetscObject)B));
240 bqnk->B = B;
241 PetscFunctionReturn(PETSC_SUCCESS);
242 }
243