1 #include <petsc/private/taoimpl.h> /*I "petsctao.h" I*/ 2 3 typedef struct { 4 SNES snes; 5 PetscBool setfromoptionscalled; 6 } Tao_SNES; 7 8 static PetscErrorCode TaoSolve_SNES(Tao tao) 9 { 10 Tao_SNES *taosnes = (Tao_SNES *)tao->data; 11 PetscInt its; 12 13 PetscFunctionBegin; 14 /* TODO SNES fails if KSP reaches max_it, while TAO accepts whatever we got */ 15 PetscCall(SNESSolve(taosnes->snes, NULL, tao->solution)); 16 /* TODO REASONS */ 17 tao->reason = TAO_CONVERGED_USER; 18 PetscCall(SNESGetIterationNumber(taosnes->snes, &its)); 19 PetscCall(TaoSetIterationNumber(tao, its)); 20 PetscFunctionReturn(PETSC_SUCCESS); 21 } 22 23 static PetscErrorCode TaoDestroy_SNES(Tao tao) 24 { 25 Tao_SNES *taosnes = (Tao_SNES *)tao->data; 26 27 PetscFunctionBegin; 28 PetscCall(SNESDestroy(&taosnes->snes)); 29 PetscCall(PetscFree(tao->data)); 30 PetscFunctionReturn(PETSC_SUCCESS); 31 } 32 33 static PetscErrorCode TAOSNESObj(SNES snes, Vec X, PetscReal *f, void *ctx) 34 { 35 Tao tao = (Tao)ctx; 36 37 PetscFunctionBegin; 38 PetscCall(TaoComputeObjective(tao, X, f)); 39 PetscFunctionReturn(PETSC_SUCCESS); 40 } 41 42 static PetscErrorCode TAOSNESFunc(SNES snes, Vec X, Vec F, void *ctx) 43 { 44 Tao tao = (Tao)ctx; 45 46 PetscFunctionBegin; 47 PetscCall(TaoComputeGradient(tao, X, F)); 48 PetscFunctionReturn(PETSC_SUCCESS); 49 } 50 51 static PetscErrorCode TAOSNESJac(SNES snes, Vec X, Mat A, Mat P, void *ctx) 52 { 53 Tao tao = (Tao)ctx; 54 55 PetscFunctionBegin; 56 PetscCall(TaoComputeHessian(tao, X, A, P)); 57 PetscFunctionReturn(PETSC_SUCCESS); 58 } 59 60 static PetscErrorCode TAOSNESMonitor(SNES snes, PetscInt its, PetscReal fnorm, void *ctx) 61 { 62 Tao tao = (Tao)ctx; 63 PetscReal obj; 64 Vec X; 65 66 PetscFunctionBegin; 67 PetscCall(SNESGetSolution(snes, &X)); 68 PetscCall(TaoComputeObjective(tao, X, &obj)); 69 PetscCall(TaoSetIterationNumber(tao, its)); 70 PetscCall(TaoMonitor(tao, its, obj, fnorm, 0.0, 0.0)); 71 PetscFunctionReturn(PETSC_SUCCESS); 72 } 73 74 static PetscErrorCode TaoSetUp_SNES(Tao tao) 75 { 76 Tao_SNES *taosnes = (Tao_SNES *)tao->data; 77 Mat A, P; 78 const char *prefix; 79 80 PetscFunctionBegin; 81 PetscCall(TaoGetOptionsPrefix(tao, &prefix)); 82 PetscCall(SNESSetOptionsPrefix(taosnes->snes, prefix)); 83 PetscCall(SNESSetSolution(taosnes->snes, tao->solution)); 84 PetscCall(SNESSetObjective(taosnes->snes, TAOSNESObj, tao)); 85 PetscCall(SNESSetFunction(taosnes->snes, NULL, TAOSNESFunc, tao)); 86 PetscCall(SNESMonitorSet(taosnes->snes, TAOSNESMonitor, tao, NULL)); 87 PetscCall(TaoGetHessian(tao, &A, &P, NULL, NULL)); 88 if (A) PetscCall(SNESSetJacobian(taosnes->snes, A, P, TAOSNESJac, tao)); 89 if (taosnes->setfromoptionscalled) PetscCall(SNESSetFromOptions(taosnes->snes)); 90 taosnes->setfromoptionscalled = PETSC_FALSE; 91 PetscCall(SNESSetUp(taosnes->snes)); 92 PetscFunctionReturn(PETSC_SUCCESS); 93 } 94 95 static PetscErrorCode TaoSetFromOptions_SNES(Tao tao, PetscOptionItems PetscOptionsObject) 96 { 97 Tao_SNES *taosnes = (Tao_SNES *)tao->data; 98 99 PetscFunctionBegin; 100 taosnes->setfromoptionscalled = PETSC_TRUE; 101 PetscFunctionReturn(PETSC_SUCCESS); 102 } 103 104 static PetscErrorCode TaoView_SNES(Tao tao, PetscViewer viewer) 105 { 106 Tao_SNES *taosnes = (Tao_SNES *)tao->data; 107 108 PetscFunctionBegin; 109 PetscCall(SNESView(taosnes->snes, viewer)); 110 PetscFunctionReturn(PETSC_SUCCESS); 111 } 112 113 /*MC 114 TAOSNES - nonlinear solver using SNES 115 116 Level: advanced 117 118 .seealso: `TaoCreate()`, `Tao`, `TaoSetType()`, `TaoType` 119 M*/ 120 PETSC_EXTERN PetscErrorCode TaoCreate_SNES(Tao tao) 121 { 122 Tao_SNES *taosnes; 123 124 PetscFunctionBegin; 125 tao->ops->destroy = TaoDestroy_SNES; 126 tao->ops->setup = TaoSetUp_SNES; 127 tao->ops->setfromoptions = TaoSetFromOptions_SNES; 128 tao->ops->view = TaoView_SNES; 129 tao->ops->solve = TaoSolve_SNES; 130 131 PetscCall(TaoParametersInitialize(tao)); 132 133 PetscCall(PetscNew(&taosnes)); 134 tao->data = (void *)taosnes; 135 PetscCall(SNESCreate(PetscObjectComm((PetscObject)tao), &taosnes->snes)); 136 PetscCall(PetscObjectIncrementTabLevel((PetscObject)taosnes->snes, (PetscObject)tao, 1)); 137 PetscFunctionReturn(PETSC_SUCCESS); 138 } 139