xref: /petsc/src/tao/constrained/tutorials/maros.c (revision 58d68138c660dfb4e9f5b03334792cd4f2ffd7cc)
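/* Program usage: mpiexec -n 1 maros [-help] [all TAO options] */

/* ----------------------------------------------------------------------
   Solves the quadratic program
       min  (1/2)*x'*H*x + d'*x
       s.t. Aeq*x == beq,  x >= 0
   with the TAOIPM interior-point solver. The data d, H, Aeq and beq are
   read from PETSc binary files in the directory named by -cutername
   (default HS21).
---------------------------------------------------------------------- */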
6 
7 #include <petsctao.h>
8 
/* Text printed by PetscInitialize() when the program is run with -help. */
static char help[] = "Solves min (1/2)x'Hx + d'x subject to Aeq x = beq and x >= 0 with TAOIPM.\n"
                     "Problem data (binary files f, H, Aeq, Beq) are read from the directory\n"
                     "given by -cutername <name> (default HS21).\n";
10 
11 /*
12    User-defined application context - contains data needed by the
13    application-provided call-back routines, FormFunction(),
14    FormGradient(), and FormHessian().
15 */
16 
17 /*
18    x,d in R^n
19    f in R
20    bin in R^mi
21    beq in R^me
22    Aeq in R^(me x n)
23    Ain in R^(mi x n)
24    H in R^(n x n)
25    min f=(1/2)*x'*H*x + d'*x
26    s.t.  Aeq*x == beq
27          Ain*x >= bin
28 */
29 typedef struct {
30   char     name[32];
31   PetscInt n;  /* Length x */
32   PetscInt me; /* number of equality constraints */
33   PetscInt mi; /* number of inequality constraints */
34   PetscInt m;  /* me+mi */
35   Mat      Aeq, Ain, H;
36   Vec      beq, bin, d;
37 } AppCtx;
38 
39 /* -------- User-defined Routines --------- */
40 
41 PetscErrorCode InitializeProblem(AppCtx *);
42 PetscErrorCode DestroyProblem(AppCtx *);
43 PetscErrorCode FormFunctionGradient(Tao, Vec, PetscReal *, Vec, void *);
44 PetscErrorCode FormHessian(Tao, Vec, Mat, Mat, void *);
45 PetscErrorCode FormInequalityConstraints(Tao, Vec, Vec, void *);
46 PetscErrorCode FormEqualityConstraints(Tao, Vec, Vec, void *);
47 PetscErrorCode FormInequalityJacobian(Tao, Vec, Mat, Mat, void *);
48 PetscErrorCode FormEqualityJacobian(Tao, Vec, Mat, Mat, void *);
49 
50 PetscErrorCode main(int argc, char **argv) {
51   PetscMPIInt        size;
52   Vec                x; /* solution */
53   KSP                ksp;
54   PC                 pc;
55   Vec                ceq, cin;
56   PetscBool          flg; /* A return value when checking for use options */
57   Tao                tao; /* Tao solver context */
58   TaoConvergedReason reason;
59   AppCtx             user; /* application context */
60 
61   /* Initialize TAO,PETSc */
62   PetscFunctionBeginUser;
63   PetscCall(PetscInitialize(&argc, &argv, (char *)0, help));
64   PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
65   /* Specify default parameters for the problem, check for command-line overrides */
66   PetscCall(PetscStrncpy(user.name, "HS21", sizeof(user.name)));
67   PetscCall(PetscOptionsGetString(NULL, NULL, "-cutername", user.name, sizeof(user.name), &flg));
68 
69   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n---- MAROS Problem %s -----\n", user.name));
70   PetscCall(InitializeProblem(&user));
71   PetscCall(VecDuplicate(user.d, &x));
72   PetscCall(VecDuplicate(user.beq, &ceq));
73   PetscCall(VecDuplicate(user.bin, &cin));
74   PetscCall(VecSet(x, 1.0));
75 
76   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
77   PetscCall(TaoSetType(tao, TAOIPM));
78   PetscCall(TaoSetSolution(tao, x));
79   PetscCall(TaoSetObjectiveAndGradient(tao, NULL, FormFunctionGradient, (void *)&user));
80   PetscCall(TaoSetEqualityConstraintsRoutine(tao, ceq, FormEqualityConstraints, (void *)&user));
81   PetscCall(TaoSetInequalityConstraintsRoutine(tao, cin, FormInequalityConstraints, (void *)&user));
82   PetscCall(TaoSetInequalityBounds(tao, user.bin, NULL));
83   PetscCall(TaoSetJacobianEqualityRoutine(tao, user.Aeq, user.Aeq, FormEqualityJacobian, (void *)&user));
84   PetscCall(TaoSetJacobianInequalityRoutine(tao, user.Ain, user.Ain, FormInequalityJacobian, (void *)&user));
85   PetscCall(TaoSetHessian(tao, user.H, user.H, FormHessian, (void *)&user));
86   PetscCall(TaoGetKSP(tao, &ksp));
87   PetscCall(KSPGetPC(ksp, &pc));
88   PetscCall(PCSetType(pc, PCLU));
89   /*
90       This algorithm produces matrices with zeros along the diagonal therefore we need to use
91     SuperLU which does partial pivoting
92   */
93   PetscCall(PCFactorSetMatSolverType(pc, MATSOLVERSUPERLU));
94   PetscCall(KSPSetType(ksp, KSPPREONLY));
95   PetscCall(TaoSetTolerances(tao, 0, 0, 0));
96 
97   PetscCall(TaoSetFromOptions(tao));
98   PetscCall(TaoSolve(tao));
99   PetscCall(TaoGetConvergedReason(tao, &reason));
100   if (reason < 0) {
101     PetscCall(PetscPrintf(MPI_COMM_WORLD, "TAO failed to converge due to %s.\n", TaoConvergedReasons[reason]));
102   } else {
103     PetscCall(PetscPrintf(MPI_COMM_WORLD, "Optimization completed with status %s.\n", TaoConvergedReasons[reason]));
104   }
105 
106   PetscCall(DestroyProblem(&user));
107   PetscCall(VecDestroy(&x));
108   PetscCall(VecDestroy(&ceq));
109   PetscCall(VecDestroy(&cin));
110   PetscCall(TaoDestroy(&tao));
111 
112   PetscCall(PetscFinalize());
113   return 0;
114 }
115 
116 PetscErrorCode InitializeProblem(AppCtx *user) {
117   PetscViewer loader;
118   MPI_Comm    comm;
119   PetscInt    nrows, ncols, i;
120   PetscScalar one = 1.0;
121   char        filebase[128];
122   char        filename[128];
123 
124   PetscFunctionBegin;
125   comm = PETSC_COMM_WORLD;
126   PetscCall(PetscStrncpy(filebase, user->name, sizeof(filebase)));
127   PetscCall(PetscStrlcat(filebase, "/", sizeof(filebase)));
128   PetscCall(PetscStrncpy(filename, filebase, sizeof(filename)));
129   PetscCall(PetscStrlcat(filename, "f", sizeof(filename)));
130   PetscCall(PetscViewerBinaryOpen(comm, filename, FILE_MODE_READ, &loader));
131 
132   PetscCall(VecCreate(comm, &user->d));
133   PetscCall(VecLoad(user->d, loader));
134   PetscCall(PetscViewerDestroy(&loader));
135   PetscCall(VecGetSize(user->d, &nrows));
136   PetscCall(VecSetFromOptions(user->d));
137   user->n = nrows;
138 
139   PetscCall(PetscStrncpy(filename, filebase, sizeof(filename)));
140   PetscCall(PetscStrlcat(filename, "H", sizeof(filename)));
141   PetscCall(PetscViewerBinaryOpen(comm, filename, FILE_MODE_READ, &loader));
142 
143   PetscCall(MatCreate(comm, &user->H));
144   PetscCall(MatSetSizes(user->H, PETSC_DECIDE, PETSC_DECIDE, nrows, nrows));
145   PetscCall(MatLoad(user->H, loader));
146   PetscCall(PetscViewerDestroy(&loader));
147   PetscCall(MatGetSize(user->H, &nrows, &ncols));
148   PetscCheck(nrows == user->n, comm, PETSC_ERR_ARG_SIZ, "H: nrows != n");
149   PetscCheck(ncols == user->n, comm, PETSC_ERR_ARG_SIZ, "H: ncols != n");
150   PetscCall(MatSetFromOptions(user->H));
151 
152   PetscCall(PetscStrncpy(filename, filebase, sizeof(filename)));
153   PetscCall(PetscStrlcat(filename, "Aeq", sizeof(filename)));
154   PetscCall(PetscViewerBinaryOpen(comm, filename, FILE_MODE_READ, &loader));
155   PetscCall(MatCreate(comm, &user->Aeq));
156   PetscCall(MatLoad(user->Aeq, loader));
157   PetscCall(PetscViewerDestroy(&loader));
158   PetscCall(MatGetSize(user->Aeq, &nrows, &ncols));
159   PetscCheck(ncols == user->n, comm, PETSC_ERR_ARG_SIZ, "Aeq ncols != H nrows");
160   PetscCall(MatSetFromOptions(user->Aeq));
161   user->me = nrows;
162 
163   PetscCall(PetscStrncpy(filename, filebase, sizeof(filename)));
164   PetscCall(PetscStrlcat(filename, "Beq", sizeof(filename)));
165   PetscCall(PetscViewerBinaryOpen(comm, filename, FILE_MODE_READ, &loader));
166   PetscCall(VecCreate(comm, &user->beq));
167   PetscCall(VecLoad(user->beq, loader));
168   PetscCall(PetscViewerDestroy(&loader));
169   PetscCall(VecGetSize(user->beq, &nrows));
170   PetscCheck(nrows == user->me, comm, PETSC_ERR_ARG_SIZ, "Aeq nrows != Beq n");
171   PetscCall(VecSetFromOptions(user->beq));
172 
173   user->mi = user->n;
174   /* Ain = eye(n,n) */
175   PetscCall(MatCreate(comm, &user->Ain));
176   PetscCall(MatSetType(user->Ain, MATAIJ));
177   PetscCall(MatSetSizes(user->Ain, PETSC_DECIDE, PETSC_DECIDE, user->mi, user->mi));
178 
179   PetscCall(MatMPIAIJSetPreallocation(user->Ain, 1, NULL, 0, NULL));
180   PetscCall(MatSeqAIJSetPreallocation(user->Ain, 1, NULL));
181 
182   for (i = 0; i < user->mi; i++) PetscCall(MatSetValues(user->Ain, 1, &i, 1, &i, &one, INSERT_VALUES));
183   PetscCall(MatAssemblyBegin(user->Ain, MAT_FINAL_ASSEMBLY));
184   PetscCall(MatAssemblyEnd(user->Ain, MAT_FINAL_ASSEMBLY));
185   PetscCall(MatSetFromOptions(user->Ain));
186 
187   /* bin = [0,0 ... 0]' */
188   PetscCall(VecCreate(comm, &user->bin));
189   PetscCall(VecSetType(user->bin, VECMPI));
190   PetscCall(VecSetSizes(user->bin, PETSC_DECIDE, user->mi));
191   PetscCall(VecSet(user->bin, 0.0));
192   PetscCall(VecSetFromOptions(user->bin));
193   user->m = user->me + user->mi;
194   PetscFunctionReturn(0);
195 }
196 
197 PetscErrorCode DestroyProblem(AppCtx *user) {
198   PetscFunctionBegin;
199   PetscCall(MatDestroy(&user->H));
200   PetscCall(MatDestroy(&user->Aeq));
201   PetscCall(MatDestroy(&user->Ain));
202   PetscCall(VecDestroy(&user->beq));
203   PetscCall(VecDestroy(&user->bin));
204   PetscCall(VecDestroy(&user->d));
205   PetscFunctionReturn(0);
206 }
207 
208 PetscErrorCode FormFunctionGradient(Tao tao, Vec x, PetscReal *f, Vec g, void *ctx) {
209   AppCtx     *user = (AppCtx *)ctx;
210   PetscScalar xtHx;
211 
212   PetscFunctionBegin;
213   PetscCall(MatMult(user->H, x, g));
214   PetscCall(VecDot(x, g, &xtHx));
215   PetscCall(VecDot(x, user->d, f));
216   *f += 0.5 * xtHx;
217   PetscCall(VecAXPY(g, 1.0, user->d));
218   PetscFunctionReturn(0);
219 }
220 
221 PetscErrorCode FormHessian(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx) {
222   PetscFunctionBegin;
223   PetscFunctionReturn(0);
224 }
225 
226 PetscErrorCode FormInequalityConstraints(Tao tao, Vec x, Vec ci, void *ctx) {
227   AppCtx *user = (AppCtx *)ctx;
228 
229   PetscFunctionBegin;
230   PetscCall(MatMult(user->Ain, x, ci));
231   PetscFunctionReturn(0);
232 }
233 
234 PetscErrorCode FormEqualityConstraints(Tao tao, Vec x, Vec ce, void *ctx) {
235   AppCtx *user = (AppCtx *)ctx;
236 
237   PetscFunctionBegin;
238   PetscCall(MatMult(user->Aeq, x, ce));
239   PetscCall(VecAXPY(ce, -1.0, user->beq));
240   PetscFunctionReturn(0);
241 }
242 
243 PetscErrorCode FormInequalityJacobian(Tao tao, Vec x, Mat JI, Mat JIpre, void *ctx) {
244   PetscFunctionBegin;
245   PetscFunctionReturn(0);
246 }
247 
248 PetscErrorCode FormEqualityJacobian(Tao tao, Vec x, Mat JE, Mat JEpre, void *ctx) {
249   PetscFunctionBegin;
250   PetscFunctionReturn(0);
251 }
252 
253 /*TEST
254 
255    build:
256       requires: !complex
257 
258    test:
259       requires: superlu
260       localrunfiles: HS21
261 
262 TEST*/
263