xref: /petsc/src/binding/petsc4py/src/lib-petsc/custom.h (revision 9d47de495d3c23378050c1b4a410c12a375cb6c6)
1 #if !defined(PETSC4PY_CUSTOM_H)
2 #define PETSC4PY_CUSTOM_H
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/sfimpl.h>
6 #include <petsc/private/vecimpl.h>
7 #include <petsc/private/matimpl.h>
8 #include <petsc/private/pcimpl.h>
9 #include <petsc/private/kspimpl.h>
10 #include <petsc/private/snesimpl.h>
11 #include <petsc/private/tsimpl.h>
12 #include <petsc/private/taoimpl.h>
13 #include <petsc/private/deviceimpl.h>
14 #include <petsc/private/viewerimpl.h>
15 
16 /* ---------------------------------------------------------------- */
17 
/* Forward to PetscError(), filling in the current __LINE__ and __FILE__;
   FUNCT is passed through as the function-name argument. */
#define PetscERROR(comm, FUNCT, n, t, msg, arg) \
  PetscError(comm, __LINE__, FUNCT, __FILE__, n, t, msg, arg)
20 
21 /* ---------------------------------------------------------------- */
22 
23 typedef PetscErrorCode (*PetscErrorHandlerFunction)
24 (MPI_Comm,int,const char*,const char*,
25  PetscErrorCode,PetscErrorType,const char*,void*);
26 #define PetscTBEH PetscTraceBackErrorHandler
27 
28 /* ---------------------------------------------------------------- */
29 
30 PETSC_EXTERN PetscErrorCode (*PetscPythonMonitorSet_C)(PetscObject,const char*);
31 
32 /* ---------------------------------------------------------------- */
33 
34 static
PetscLogStageFindId(const char name[],PetscLogStage * stageid)35 PetscErrorCode PetscLogStageFindId(const char name[], PetscLogStage *stageid)
36 {
37   PetscLogState log_state = NULL;
38 
39   PetscFunctionBegin;
40   PetscAssertPointer(name,1);
41   PetscAssertPointer(stageid,2);
42   *stageid = -1;
43   if (!(log_state = petsc_log_state)) PetscFunctionReturn(PETSC_SUCCESS); /* logging is off ? */
44   PetscCall(PetscLogStageGetId(name, stageid));
45   PetscFunctionReturn(PETSC_SUCCESS);
46 }
47 
48 static
PetscLogClassFindId(const char name[],PetscClassId * classid)49 PetscErrorCode PetscLogClassFindId(const char name[], PetscClassId *classid)
50 {
51   PetscLogState log_state = NULL;
52 
53   PetscFunctionBegin;
54   PetscAssertPointer(name,1);
55   PetscAssertPointer(classid,2);
56   *classid = -1;
57   if (!(log_state = petsc_log_state)) PetscFunctionReturn(PETSC_SUCCESS); /* logging is off ? */
58   PetscCall(PetscLogClassGetClassId(name, classid));
59   PetscFunctionReturn(PETSC_SUCCESS);
60 }
61 
62 static
PetscLogEventFindId(const char name[],PetscLogEvent * eventid)63 PetscErrorCode PetscLogEventFindId(const char name[], PetscLogEvent *eventid)
64 {
65   PetscLogState log_state = NULL;
66 
67   PetscFunctionBegin;
68   PetscAssertPointer(name,1);
69   PetscAssertPointer(eventid,2);
70   *eventid = -1;
71   if (!(log_state = petsc_log_state)) PetscFunctionReturn(PETSC_SUCCESS); /* logging is off ? */
72   PetscCall(PetscLogEventGetId(name, eventid));
73   PetscFunctionReturn(PETSC_SUCCESS);
74 }
75 
76 static
PetscLogStageFindName(PetscLogStage stageid,const char * name[])77 PetscErrorCode PetscLogStageFindName(PetscLogStage stageid, const char *name[])
78 {
79   PetscLogState log_state = NULL;
80 
81   PetscFunctionBegin;
82   PetscAssertPointer(name,3);
83   *name = NULL;
84   if (!(log_state = petsc_log_state)) PetscFunctionReturn(PETSC_SUCCESS); /* logging is off ? */
85   PetscCall(PetscLogStageGetName(stageid, name));
86   PetscFunctionReturn(PETSC_SUCCESS);
87 }
88 
89 static
PetscLogClassFindName(PetscClassId classid,const char * name[])90 PetscErrorCode PetscLogClassFindName(PetscClassId classid, const char *name[])
91 {
92   PetscLogState log_state = NULL;
93 
94   PetscFunctionBegin;
95   PetscAssertPointer(name,3);
96   *name = 0;
97   if (!(log_state = petsc_log_state)) PetscFunctionReturn(PETSC_SUCCESS); /* logging is off ? */
98   PetscCall(PetscLogClassIdGetName(classid, name));
99   PetscFunctionReturn(PETSC_SUCCESS);
100 }
101 
102 static
PetscLogEventFindName(PetscLogEvent eventid,const char * name[])103 PetscErrorCode PetscLogEventFindName(PetscLogEvent eventid, const char *name[])
104 {
105   PetscLogState log_state = NULL;
106 
107   PetscFunctionBegin;
108   PetscAssertPointer(name,3);
109   *name = 0;
110   if (!(log_state = petsc_log_state)) PetscFunctionReturn(PETSC_SUCCESS); /* logging is off ? */
111   PetscCall(PetscLogEventGetName(eventid, name));
112   PetscFunctionReturn(PETSC_SUCCESS);
113 }
114 
115 /* ---------------------------------------------------------------- */
116 
117 static
PetscObjectComposedDataGetIntPy(PetscObject o,PetscInt id,PetscInt * v,PetscBool * exist)118 PetscErrorCode PetscObjectComposedDataGetIntPy(PetscObject o, PetscInt id, PetscInt *v, PetscBool *exist)
119 {
120   PetscFunctionBegin;
121   PetscCall(PetscObjectComposedDataGetInt(o,id,*v,*exist));
122   PetscFunctionReturn(PETSC_SUCCESS);
123 }
124 
125 static
PetscObjectComposedDataSetIntPy(PetscObject o,PetscInt id,PetscInt v)126 PetscErrorCode PetscObjectComposedDataSetIntPy(PetscObject o, PetscInt id, PetscInt v)
127 {
128   PetscFunctionBegin;
129   PetscCall(PetscObjectComposedDataSetInt(o,id,v));
130   PetscFunctionReturn(PETSC_SUCCESS);
131 }
132 
133 static
PetscObjectComposedDataRegisterPy(PetscInt * id)134 PetscErrorCode PetscObjectComposedDataRegisterPy(PetscInt *id)
135 {
136   PetscFunctionBegin;
137   PetscCall(PetscObjectComposedDataRegister(id));
138   PetscFunctionReturn(PETSC_SUCCESS);
139 }
140 
141 /* ---------------------------------------------------------------- */
142 
143 /* The object is not used so far. I expect PETSc will sooner or later support
144    a different device context for each object */
145 static
PetscObjectGetDeviceId(PetscObject o,PetscInt * id)146 PetscErrorCode PetscObjectGetDeviceId(PetscObject o, PetscInt *id)
147 {
148 #if defined(PETSC_HAVE_DEVICE)
149   PetscDeviceContext dctx;
150   PetscDevice device;
151 #endif
152   PetscFunctionBegin;
153   PetscValidHeader(o,1);
154 #if defined(PETSC_HAVE_DEVICE)
155   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
156   PetscCall(PetscDeviceContextGetDevice(dctx,&device));
157   PetscCall(PetscDeviceGetDeviceId(device,id));
158 #else
159   *id = 0;
160 #endif
161   PetscFunctionReturn(PETSC_SUCCESS);
162 }
163 
164 /* ---------------------------------------------------------------- */
165 
166 static
VecGetCurrentMemType(Vec v,PetscMemType * m)167 PetscErrorCode VecGetCurrentMemType(Vec v, PetscMemType *m)
168 {
169   PetscBool bound;
170 
171   PetscFunctionBegin;
172   PetscValidHeaderSpecific(v,VEC_CLASSID,1);
173   PetscAssertPointer(m,2);
174   *m = PETSC_MEMTYPE_HOST;
175   PetscCall(VecBoundToCPU(v,&bound));
176   if (!bound) {
177     VecType rtype;
178     char *iscuda = NULL, *iship = NULL, *iskok = NULL;
179 
180     PetscCall(VecGetRootType_Private(v,&rtype));
181     PetscCall(PetscStrstr(rtype,"cuda",&iscuda));
182     PetscCall(PetscStrstr(rtype,"hip",&iship));
183     PetscCall(PetscStrstr(rtype,"kokkos",&iskok));
184     if (iscuda)     *m = PETSC_MEMTYPE_CUDA;
185     else if (iship) *m = PETSC_MEMTYPE_HIP;
186     else if (iskok) *m = PETSC_MEMTYPE_KOKKOS;
187   }
188   PetscFunctionReturn(PETSC_SUCCESS);
189 }
190 
191 /* ---------------------------------------------------------------- */
192 
193 static
MatIsPreallocated(Mat A,PetscBool * flag)194 PetscErrorCode MatIsPreallocated(Mat A,PetscBool *flag)
195 {
196   PetscFunctionBegin;
197   PetscValidHeaderSpecific(A,MAT_CLASSID,1);
198   PetscAssertPointer(flag,2);
199   *flag = A->preallocated;
200   PetscFunctionReturn(PETSC_SUCCESS);
201 }
202 
203 static
MatHasPreallocationAIJ(Mat A,PetscBool * aij,PetscBool * baij,PetscBool * sbaij,PetscBool * is)204 PetscErrorCode MatHasPreallocationAIJ(Mat A,PetscBool *aij,PetscBool *baij,PetscBool *sbaij,PetscBool *is)
205 {
206   PetscErrorCodeFn *f = 0;
207 
208   PetscFunctionBegin;
209   PetscValidHeaderSpecific(A,MAT_CLASSID,1);
210   PetscValidType(A,1);
211   PetscAssertPointer(aij,2);
212   PetscAssertPointer(baij,3);
213   PetscAssertPointer(sbaij,4);
214   PetscAssertPointer(is,5);
215   *aij = *baij = *sbaij = *is = PETSC_FALSE;
216   if (!f) PetscCall(PetscObjectQueryFunction((PetscObject)A,"MatMPIAIJSetPreallocation_C",&f));
217   if (!f) PetscCall(PetscObjectQueryFunction((PetscObject)A,"MatSeqAIJSetPreallocation_C",&f));
218   if (f)  {*aij = PETSC_TRUE; goto done;}
219   if (!f) PetscCall(PetscObjectQueryFunction((PetscObject)A,"MatMPIBAIJSetPreallocation_C",&f));
220   if (!f) PetscCall(PetscObjectQueryFunction((PetscObject)A,"MatSeqBAIJSetPreallocation_C",&f));
221   if (f)  {*baij = PETSC_TRUE; goto done;}
222   if (!f) PetscCall(PetscObjectQueryFunction((PetscObject)A,"MatMPISBAIJSetPreallocation_C",&f));
223   if (!f) PetscCall(PetscObjectQueryFunction((PetscObject)A,"MatSeqSBAIJSetPreallocation_C",&f));
224   if (f)  {*sbaij = PETSC_TRUE; goto done;}
225   if (!f) PetscCall(PetscObjectQueryFunction((PetscObject)A,"MatISSetPreallocation_C",&f));
226   if (f)  {*is = PETSC_TRUE; goto done;}
227  done:
228   PetscFunctionReturn(PETSC_SUCCESS);
229 }
230 
231 #if !defined(MatNullSpaceFunction)
232 typedef PetscErrorCode MatNullSpaceFunction(MatNullSpace,Vec,void*);
233 #endif
234 
235 /* ---------------------------------------------------------------- */
236 
237 static
MatFactorInfoDefaults(PetscBool incomplete,PetscBool cholesky,MatFactorInfo * info)238 PetscErrorCode MatFactorInfoDefaults(PetscBool incomplete,PetscBool cholesky,MatFactorInfo *info)
239 {
240   PetscFunctionBegin;
241   PetscAssertPointer(info,2);
242   PetscCall(MatFactorInfoInitialize(info));
243   if (incomplete) {
244     info->levels         = (PetscReal)0;
245     info->diagonal_fill  = (PetscReal)0;
246     info->fill           = (PetscReal)1.0;
247     info->usedt          = (PetscReal)0;
248     info->dt             = (PetscReal)PETSC_DEFAULT;
249     info->dtcount        = (PetscReal)PETSC_DEFAULT;
250     info->dtcol          = (PetscReal)PETSC_DEFAULT;
251     info->zeropivot      = (PetscReal)100.0*PETSC_MACHINE_EPSILON;
252     info->pivotinblocks  = (PetscReal)1;
253   } else {
254     info->fill           = (PetscReal)5.0;
255     info->dtcol          = (PetscReal)1.e-6;
256     info->zeropivot      = (PetscReal)100.0*PETSC_MACHINE_EPSILON;
257     info->pivotinblocks  = (PetscReal)1;
258   }
259   if (incomplete) {
260     if (cholesky)
261       info->shifttype    = (PetscReal)MAT_SHIFT_POSITIVE_DEFINITE;
262     else
263       info->shifttype    = (PetscReal)MAT_SHIFT_NONZERO;
264     info->shiftamount    = (PetscReal)100.0*PETSC_MACHINE_EPSILON;
265   } else {
266     info->shifttype      = (PetscReal)MAT_SHIFT_NONE;
267     info->shiftamount    = (PetscReal)0.0;
268   }
269   PetscFunctionReturn(PETSC_SUCCESS);
270 }
271 
272 /* ---------------------------------------------------------------- */
273 
274 static
KSPSetIterationNumber(KSP ksp,PetscInt its)275 PetscErrorCode KSPSetIterationNumber(KSP ksp, PetscInt its)
276 {
277   PetscFunctionBegin;
278   PetscValidHeaderSpecific(ksp,KSP_CLASSID,1);
279   PetscCheck(its >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"iteration number must be nonnegative");
280   ksp->its = its;
281   PetscFunctionReturn(PETSC_SUCCESS);
282 }
283 
284 static
KSPSetResidualNorm(KSP ksp,PetscReal rnorm)285 PetscErrorCode KSPSetResidualNorm(KSP ksp, PetscReal rnorm)
286 {
287   PetscFunctionBegin;
288   PetscValidHeaderSpecific(ksp,KSP_CLASSID,1);
289   PetscCheck(rnorm >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"residual norm must be nonnegative");
290   ksp->rnorm = rnorm;
291   PetscFunctionReturn(PETSC_SUCCESS);
292 }
293 
294 static
KSPConvergenceTestCall(KSP ksp,PetscInt its,PetscReal rnorm,KSPConvergedReason * reason)295 PetscErrorCode KSPConvergenceTestCall(KSP ksp, PetscInt its, PetscReal rnorm, KSPConvergedReason *reason)
296 {
297   PetscFunctionBegin;
298   PetscValidHeaderSpecific(ksp,KSP_CLASSID,1);
299   PetscAssertPointer(reason,4);
300   PetscCheck(its >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"iteration number must be nonnegative");
301   PetscCheck(rnorm >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"residual norm must be nonnegative");
302   PetscCall((*ksp->converged)(ksp,its,rnorm,reason,ksp->cnvP));
303   PetscFunctionReturn(PETSC_SUCCESS);
304 }
305 
306 static
KSPSetConvergedReason(KSP ksp,KSPConvergedReason reason)307 PetscErrorCode KSPSetConvergedReason(KSP ksp, KSPConvergedReason reason)
308 {
309   PetscFunctionBegin;
310   PetscValidHeaderSpecific(ksp,KSP_CLASSID,1);
311   ksp->reason = reason;
312   PetscFunctionReturn(PETSC_SUCCESS);
313 }
314 
315 static
KSPConverged(KSP ksp,PetscInt iter,PetscReal rnorm,KSPConvergedReason * reason)316 PetscErrorCode KSPConverged(KSP ksp,PetscInt iter,PetscReal rnorm,KSPConvergedReason *reason)
317 {
318   PetscFunctionBegin;
319   PetscValidHeaderSpecific(ksp,KSP_CLASSID,1);
320   if (reason) PetscAssertPointer(reason,2);
321   if (!iter) ksp->rnorm0 = rnorm;
322   if (!iter) {
323     ksp->reason = KSP_CONVERGED_ITERATING;
324     ksp->ttol = PetscMax(rnorm*ksp->rtol,ksp->abstol);
325   }
326   if (ksp->converged) {
327     PetscCall(ksp->converged(ksp,iter,rnorm,&ksp->reason,ksp->cnvP));
328   } else {
329     PetscCall(KSPConvergedSkip(ksp,iter,rnorm,&ksp->reason,NULL));
330     /*PetscCall(KSPConvergedDefault(ksp,iter,rnorm,&ksp->reason,NULL));*/
331   }
332   ksp->rnorm = rnorm;
333   if (reason) *reason = ksp->reason;
334   PetscFunctionReturn(PETSC_SUCCESS);
335 }
336 
337 typedef struct {
338   PetscBool prepend_custom;
339   KSPConvergenceTestFn *convtest;
340   PetscCtxDestroyFn    *convdestroy;
341   KSPConvergenceTestFn *convtestcustom;
342   void *convctx;
343 } KSPConvergedNativeCtx;
344 
345 static
KSPConvergedNative_Private(KSP ksp,PetscInt n,PetscReal rnorm,KSPConvergedReason * reason,void * cctx)346 PetscErrorCode KSPConvergedNative_Private(KSP ksp, PetscInt n, PetscReal rnorm, KSPConvergedReason *reason, void *cctx)
347 {
348   KSPConvergedNativeCtx *ctx = (KSPConvergedNativeCtx *)cctx;
349 
350   PetscFunctionBegin;
351   *reason = KSP_CONVERGED_ITERATING;
352   if (ctx->prepend_custom) {
353     PetscCall((*ctx->convtestcustom)(ksp, n, rnorm, reason, NULL));
354     if (*reason) {
355       PetscCall(PetscInfo(ksp, "User provided prepended Python convergence test reason %s KSP iterations=%" PetscInt_FMT ", rnorm=%g\n", KSPConvergedReasons[*reason], n, (double)rnorm));
356       PetscFunctionReturn(PETSC_SUCCESS);
357     }
358   }
359   PetscCall((*ctx->convtest)(ksp, n, rnorm, reason, ctx->convctx));
360   if (*reason) {
361     PetscCall(PetscInfo(ksp, "Default convergence test reason %s KSP iterations=%" PetscInt_FMT ", rnorm=%g\n", KSPConvergedReasons[*reason], n, (double)rnorm));
362     PetscFunctionReturn(PETSC_SUCCESS);
363   }
364   if (!ctx->prepend_custom) {
365     PetscCall((*ctx->convtestcustom)(ksp, n, rnorm, reason, NULL));
366     if (*reason) PetscCall(PetscInfo(ksp, "User provide appended Python convergence test reason %s KSP iterations=%" PetscInt_FMT ", rnorm=%g\n", KSPConvergedReasons[*reason], n, (double)rnorm));
367   }
368   PetscFunctionReturn(PETSC_SUCCESS);
369 }
370 
KSPConvergedNative_Destroy(PetscCtxRt cctx)371 static PetscErrorCode KSPConvergedNative_Destroy(PetscCtxRt cctx)
372 {
373   KSPConvergedNativeCtx *ctx = *(KSPConvergedNativeCtx **)cctx;
374 
375   PetscFunctionBegin;
376   PetscCheck(ctx, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing context");
377   if (ctx->convdestroy) PetscCall((*ctx->convdestroy)(&ctx->convctx));
378   PetscCall(PetscFree(ctx));
379   PetscFunctionReturn(PETSC_SUCCESS);
380 }
381 
382 static
KSPAddConvergenceTest(KSP ksp,PetscErrorCode (* custom)(KSP,PetscInt,PetscReal,KSPConvergedReason *,void *),PetscBool prepend)383 PetscErrorCode KSPAddConvergenceTest(KSP ksp, PetscErrorCode (*custom)(KSP, PetscInt, PetscReal, KSPConvergedReason *, void *), PetscBool prepend)
384 {
385   KSPConvergedNativeCtx *ctx;
386 
387   PetscFunctionBegin;
388   PetscValidHeaderSpecific(ksp,KSP_CLASSID,1);
389   PetscValidLogicalCollectiveBool(ksp,prepend,3);
390   PetscCall(PetscNew(&ctx));
391   ctx->convtestcustom = custom;
392   ctx->prepend_custom = prepend;
393   PetscCall(KSPGetAndClearConvergenceTest(ksp, &ctx->convtest, &ctx->convctx, &ctx->convdestroy));
394   PetscCall(KSPSetConvergenceTest(ksp, KSPConvergedNative_Private, ctx, KSPConvergedNative_Destroy));
395   PetscFunctionReturn(PETSC_SUCCESS);
396 }
397 
398 static
KSPLogHistory(KSP ksp,PetscReal rnorm)399 PetscErrorCode KSPLogHistory(KSP ksp,PetscReal rnorm)
400 {
401   PetscFunctionBegin;
402   PetscValidHeaderSpecific(ksp,KSP_CLASSID,1);
403   PetscCall(KSPLogResidualHistory(ksp,rnorm));
404   PetscFunctionReturn(PETSC_SUCCESS);
405 }
406 
407 /* ---------------------------------------------------------------- */
408 
409 static
SNESConvergenceTestCall(SNES snes,PetscInt its,PetscReal xnorm,PetscReal ynorm,PetscReal fnorm,SNESConvergedReason * reason)410 PetscErrorCode SNESConvergenceTestCall(SNES snes, PetscInt its, PetscReal xnorm, PetscReal ynorm, PetscReal fnorm, SNESConvergedReason *reason)
411 {
412   PetscFunctionBegin;
413   PetscValidHeaderSpecific(snes,SNES_CLASSID,1);
414   PetscAssertPointer(reason,4);
415   PetscCheck(its >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"iteration number must be nonnegative");
416   PetscCheck(xnorm >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"solution norm must be nonnegative");
417   PetscCheck(ynorm >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"step norm must be nonnegative");
418   PetscCheck(fnorm >= 0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"function norm must be nonnegative");
419   PetscUseTypeMethod(snes,converged,its,xnorm,ynorm,fnorm,reason,snes->cnvP);
420   PetscFunctionReturn(PETSC_SUCCESS);
421 }
422 
423 static
SNESLogHistory(SNES snes,PetscReal rnorm,PetscInt lits)424 PetscErrorCode SNESLogHistory(SNES snes,PetscReal rnorm,PetscInt lits)
425 {
426   PetscFunctionBegin;
427   PetscValidHeaderSpecific(snes,SNES_CLASSID,1);
428   PetscCall(SNESLogConvergenceHistory(snes,rnorm,lits));
429   PetscFunctionReturn(PETSC_SUCCESS);
430 }
431 
432 static
SNESGetUseMFFD(SNES snes,PetscBool * flag)433 PetscErrorCode SNESGetUseMFFD(SNES snes,PetscBool *flag)
434 {
435   PetscErrorCode (*jac)(SNES,Vec,Mat,Mat,void*) = NULL;
436   Mat            J = NULL;
437 
438   PetscFunctionBegin;
439   PetscValidHeaderSpecific(snes,SNES_CLASSID,1);
440   PetscAssertPointer(flag,2);
441   *flag = PETSC_FALSE;
442   PetscCall(SNESGetJacobian(snes,&J,0,&jac,0));
443   if (J) PetscCall(PetscObjectTypeCompare((PetscObject)J,MATMFFD,flag));
444   else if (jac == MatMFFDComputeJacobian) *flag = PETSC_TRUE;
445   PetscFunctionReturn(PETSC_SUCCESS);
446 }
447 
448 static
SNESSetUseMFFD(SNES snes,PetscBool flag)449 PetscErrorCode SNESSetUseMFFD(SNES snes,PetscBool flag)
450 {
451   const char* prefix = NULL;
452   PetscBool   flg    = PETSC_FALSE;
453   Vec         r      = NULL;
454   Mat         A      = NULL,B = NULL,J = NULL;
455   void*       funP   = NULL;
456   void*       jacP   = NULL;
457 
458   PetscFunctionBegin;
459   PetscValidHeaderSpecific(snes,SNES_CLASSID,1);
460 
461   PetscCall(SNESGetUseMFFD(snes,&flg));
462   if (flg  &&  flag) PetscFunctionReturn(PETSC_SUCCESS);
463   if (!flg && !flag) PetscFunctionReturn(PETSC_SUCCESS);
464   if (flg  && !flag) {
465     SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONGSTATE,
466             "cannot change matrix-free once it is set");
467     PetscFunctionReturn(PETSC_ERR_ARG_WRONGSTATE);
468   }
469 
470   PetscCall(SNESGetOptionsPrefix(snes,&prefix));
471   PetscCall(SNESGetFunction(snes,&r,0,&funP));
472   PetscCall(SNESGetJacobian(snes,&A,&B,0,&jacP));
473   if (!r) {
474     SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONGSTATE,"SNESSetFunction() must be called first");
475     PetscFunctionReturn(PETSC_ERR_ARG_WRONGSTATE);
476   }
477   PetscCall(MatCreateSNESMF(snes,&J));
478   PetscCall(MatSetOptionsPrefix(J,prefix));
479   PetscCall(MatSetFromOptions(J));
480   if (!B) {
481     KSP       ksp;
482     PC        pc;
483     PetscBool shell,python;
484     PetscCall(SNESSetJacobian(snes,J,J,MatMFFDComputeJacobian,jacP));
485     PetscCall(SNESGetKSP(snes,&ksp));
486     PetscCall(KSPGetPC(ksp,&pc));
487     PetscCall(PetscObjectTypeCompare((PetscObject)pc,PCSHELL,&shell));
488     PetscCall(PetscObjectTypeCompare((PetscObject)pc,PCPYTHON,&python));
489     if (!shell && !python) PetscCall(PCSetType(pc,PCNONE));
490   } else PetscCall(SNESSetJacobian(snes,J,0,0,0));
491   PetscCall(MatDestroy(&J));
492   PetscFunctionReturn(PETSC_SUCCESS);
493 }
494 
495 static
SNESGetUseFDColoring(SNES snes,PetscBool * flag)496 PetscErrorCode SNESGetUseFDColoring(SNES snes,PetscBool *flag)
497 {
498   PetscErrorCode (*jac)(SNES,Vec,Mat,Mat,void*) = NULL;
499 
500   PetscFunctionBegin;
501   PetscValidHeaderSpecific(snes,SNES_CLASSID,1);
502   PetscAssertPointer(flag,2);
503   *flag = PETSC_FALSE;
504   PetscCall(SNESGetJacobian(snes,0,0,&jac,0));
505   if (jac == SNESComputeJacobianDefaultColor) *flag = PETSC_TRUE;
506   PetscFunctionReturn(PETSC_SUCCESS);
507 }
508 
509 static
SNESSetUseFDColoring(SNES snes,PetscBool flag)510 PetscErrorCode SNESSetUseFDColoring(SNES snes,PetscBool flag)
511 {
512   PetscBool      flg = PETSC_FALSE;
513   PetscErrorCode (*fun)(SNES,Vec,Vec,void*) = NULL;
514   void*          funP = NULL;
515   Mat            A = NULL,B = NULL;
516   PetscErrorCode (*jac)(SNES,Vec,Mat,Mat,void*) = NULL;
517   void*          jacP = NULL;
518 
519   PetscFunctionBegin;
520   PetscValidHeaderSpecific(snes,SNES_CLASSID,1);
521 
522   PetscCall(SNESGetUseFDColoring(snes,&flg));
523   if (flg  &&  flag) PetscFunctionReturn(PETSC_SUCCESS);
524   if (!flg && !flag) PetscFunctionReturn(PETSC_SUCCESS);
525   if (flg  && !flag) {
526     SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONGSTATE,
527             "cannot change colored finite differences once it is set");
528     PetscFunctionReturn(PETSC_ERR_ARG_WRONGSTATE);
529   }
530 
531   PetscCall(SNESGetFunction(snes,NULL,&fun,&funP));
532   PetscCall(SNESGetJacobian(snes,&A,&B,&jac,&jacP));
533   PetscCall(SNESSetJacobian(snes,A,B,SNESComputeJacobianDefaultColor,0));
534   {
535     DM     dm;
536     DMSNES sdm;
537     PetscCall(SNESGetDM(snes,&dm));
538     PetscCall(DMGetDMSNES(dm,&sdm));
539     PetscCall(DMSNESUnsetJacobianContext_Internal(dm));
540   }
541   PetscFunctionReturn(PETSC_SUCCESS);
542 }
543 
544 static
SNESComputeUpdate(SNES snes)545 PetscErrorCode SNESComputeUpdate(SNES snes)
546 {
547   PetscFunctionBegin;
548   PetscValidHeaderSpecific(snes,SNES_CLASSID,1);
549   PetscTryTypeMethod(snes, update, snes->iter);
550   PetscFunctionReturn(PETSC_SUCCESS);
551 }
552 
553 static
SNESGetUseKSP(SNES snes,PetscBool * flag)554 PetscErrorCode SNESGetUseKSP(SNES snes,PetscBool *flag)
555 {
556   PetscFunctionBegin;
557   PetscValidHeaderSpecific(snes,SNES_CLASSID,1);
558   PetscAssertPointer(flag,2);
559   *flag = snes->usesksp;
560   PetscFunctionReturn(PETSC_SUCCESS);
561 }
562 
563 static
SNESSetUseKSP(SNES snes,PetscBool flag)564 PetscErrorCode SNESSetUseKSP(SNES snes,PetscBool flag)
565 {
566   PetscFunctionBegin;
567   PetscValidHeaderSpecific(snes,SNES_CLASSID,1);
568   PetscValidLogicalCollectiveBool(snes,flag,2);
569   snes->usesksp = flag;
570   PetscFunctionReturn(PETSC_SUCCESS);
571 }
572 
573 /* ---------------------------------------------------------------- */
574 
575 static
TaoConverged(Tao tao,TaoConvergedReason * reason)576 PetscErrorCode TaoConverged(Tao tao, TaoConvergedReason *reason)
577 {
578   PetscFunctionBegin;
579   PetscValidHeaderSpecific(tao,TAO_CLASSID,1);
580   PetscAssertPointer(reason,2);
581   if (tao->ops->convergencetest) {
582     PetscUseTypeMethod(tao,convergencetest,tao->cnvP);
583   } else {
584     PetscCall(TaoDefaultConvergenceTest(tao,tao->cnvP));
585   }
586   *reason = tao->reason;
587   PetscFunctionReturn(PETSC_SUCCESS);
588 }
589 
590 static
TaoCheckReals(Tao tao,PetscReal f,PetscReal g)591 PetscErrorCode TaoCheckReals(Tao tao, PetscReal f, PetscReal g)
592 {
593   PetscFunctionBegin;
594   PetscValidHeaderSpecific(tao,TAO_CLASSID,1);
595   PetscCheck(!PetscIsInfOrNanReal(f) && !PetscIsInfOrNanReal(g),PetscObjectComm((PetscObject)tao),PETSC_ERR_USER,"User provided compute function generated infinity or NaN");
596   PetscFunctionReturn(PETSC_SUCCESS);
597 }
598 
599 static
TaoCreateDefaultKSP(Tao tao)600 PetscErrorCode TaoCreateDefaultKSP(Tao tao)
601 {
602   PetscFunctionBegin;
603   PetscValidHeaderSpecific(tao,TAO_CLASSID,1);
604   PetscCall(KSPDestroy(&tao->ksp));
605   PetscCall(KSPCreate(((PetscObject)tao)->comm,&tao->ksp));
606   PetscCall(PetscObjectIncrementTabLevel((PetscObject)tao->ksp,(PetscObject)tao,1));
607   PetscFunctionReturn(PETSC_SUCCESS);
608 }
609 
610 static
TaoCreateDefaultLineSearch(Tao tao)611 PetscErrorCode TaoCreateDefaultLineSearch(Tao tao)
612 {
613   PetscFunctionBegin;
614   PetscValidHeaderSpecific(tao,TAO_CLASSID,1);
615   PetscCall(TaoLineSearchDestroy(&tao->linesearch));
616   PetscCall(TaoLineSearchCreate(((PetscObject)tao)->comm,&tao->linesearch));
617   PetscCall(PetscObjectIncrementTabLevel((PetscObject)tao->linesearch,(PetscObject)tao,1));
618   PetscCall(TaoLineSearchSetType(tao->linesearch,TAOLINESEARCHMT));
619   PetscCall(TaoLineSearchUseTaoRoutines(tao->linesearch,tao));
620   PetscCall(TaoLineSearchSetInitialStepLength(tao->linesearch,1.0));
621   PetscFunctionReturn(PETSC_SUCCESS);
622 }
623 
624 static
TaoHasGradientRoutine(Tao tao,PetscBool * flg)625 PetscErrorCode TaoHasGradientRoutine(Tao tao, PetscBool* flg)
626 {
627   PetscFunctionBegin;
628   PetscValidHeaderSpecific(tao,TAO_CLASSID,1);
629   PetscAssertPointer(flg,2);
630   *flg = (PetscBool)(tao->ops->computegradient || tao->ops->computeobjectiveandgradient);
631   PetscFunctionReturn(PETSC_SUCCESS);
632 }
633 
#if 0
/* Disabled: set *flg to PETSC_TRUE when `tao` has a Hessian routine.
   The function pointer is converted to a PetscBool explicitly, in the same
   way TaoHasGradientRoutine() does. */
static PetscErrorCode TaoHasHessianRoutine(Tao tao, PetscBool *flg)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(tao, TAO_CLASSID, 1);
  PetscAssertPointer(flg, 2);
  *flg = (PetscBool)(tao->ops->computehessian != NULL);
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif
645 
646 static
TaoComputeUpdate(Tao tao,PetscReal * f)647 PetscErrorCode TaoComputeUpdate(Tao tao, PetscReal *f)
648 {
649   PetscFunctionBegin;
650   PetscValidHeaderSpecific(tao,TAO_CLASSID,1);
651   if (tao->ops->update) {
652     PetscUseTypeMethod(tao,update,tao->niter,tao->user_update);
653     PetscCall(TaoComputeObjective(tao,tao->solution,f));
654   }
655   PetscFunctionReturn(PETSC_SUCCESS);
656 }
657 
658 static
TaoGetVecs(Tao tao,Vec * X,Vec * G,Vec * S)659 PetscErrorCode TaoGetVecs(Tao tao, Vec *X, Vec *G, Vec *S)
660 {
661   PetscBool has_g;
662 
663   PetscFunctionBegin;
664   PetscValidHeaderSpecific(tao,TAO_CLASSID,1);
665   PetscCall(TaoHasGradientRoutine(tao,&has_g));
666   if (X) *X = tao->solution;
667   if (G) {
668     if (has_g && !tao->gradient) PetscCall(VecDuplicate(tao->solution,&tao->gradient));
669     *G = has_g ? tao->gradient : NULL;
670   }
671   if (S) {
672     if (has_g && !tao->stepdirection) PetscCall(VecDuplicate(tao->solution,&tao->stepdirection));
673     *S = has_g ? tao->stepdirection : NULL;
674   }
675   PetscFunctionReturn(PETSC_SUCCESS);
676 }
677 
678 static
TaoApplyLineSearch(Tao tao,PetscReal * f,PetscReal * s,TaoLineSearchConvergedReason * lsr)679 PetscErrorCode TaoApplyLineSearch(Tao tao, PetscReal* f, PetscReal *s, TaoLineSearchConvergedReason *lsr)
680 {
681   PetscFunctionBegin;
682   PetscValidHeaderSpecific(tao,TAO_CLASSID,1);
683   PetscAssertPointer(f,2);
684   PetscAssertPointer(s,3);
685   PetscCall(TaoLineSearchApply(tao->linesearch,tao->solution,f,tao->gradient,tao->stepdirection,s,lsr));
686   PetscCall(TaoAddLineSearchCounts(tao));
687   PetscFunctionReturn(PETSC_SUCCESS);
688 }
689 
690 /* ---------------------------------------------------------------- */
691 
692 static
DMDACreateND(MPI_Comm comm,PetscInt dim,PetscInt dof,PetscInt M,PetscInt N,PetscInt P,PetscInt m,PetscInt n,PetscInt p,const PetscInt lx[],const PetscInt ly[],const PetscInt lz[],DMBoundaryType bx,DMBoundaryType by,DMBoundaryType bz,DMDAStencilType stencil_type,PetscInt stencil_width,DM * dm)693 PetscErrorCode DMDACreateND(MPI_Comm comm,
694                             PetscInt dim,PetscInt dof,
695                             PetscInt M,PetscInt N,PetscInt P,
696                             PetscInt m,PetscInt n,PetscInt p,
697                             const PetscInt lx[],const PetscInt ly[],const PetscInt lz[],
698                             DMBoundaryType bx,DMBoundaryType by,DMBoundaryType bz,
699                             DMDAStencilType stencil_type,PetscInt stencil_width,
700                             DM *dm)
701 {
702   DM da;
703 
704   PetscFunctionBegin;
705   PetscAssertPointer(dm,18);
706   PetscCall(DMDACreate(comm,&da));
707   PetscCall(DMSetDimension(da,dim));
708   PetscCall(DMDASetDof(da,dof));
709   PetscCall(DMDASetSizes(da,M,N,P));
710   PetscCall(DMDASetNumProcs(da,m,n,p));
711   PetscCall(DMDASetOwnershipRanges(da,lx,ly,lz));
712   PetscCall(DMDASetBoundaryType(da,bx,by,bz));
713   PetscCall(DMDASetStencilType(da,stencil_type));
714   PetscCall(DMDASetStencilWidth(da,stencil_width));
715   *dm = (DM)da;
716   PetscFunctionReturn(PETSC_SUCCESS);
717 }
718 
719 static
PetscDeviceReference(PetscDevice device)720 PetscErrorCode PetscDeviceReference(PetscDevice device)
721 {
722   PetscFunctionBegin;
723   PetscCall(PetscDeviceReference_Internal(device));
724   PetscFunctionReturn(PETSC_SUCCESS);
725 }
726 
727 /* ---------------------------------------------------------------- */
728 
729 #endif/* PETSC4PY_CUSTOM_H*/
730 
731 /*
732   Local variables:
733   c-basic-offset: 2
734   indent-tabs-mode: nil
735   End:
736 */
737