xref: /libCEED/backends/ref/ceed-ref-operator.c (revision d264344313546611a0e282df1f09990e3ff407ce)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <math.h>
20 #include <stdbool.h>
21 #include <stddef.h>
22 #include <stdint.h>
23 #include "ceed-ref.h"
24 
25 //------------------------------------------------------------------------------
26 // Setup Input/Output Fields
27 //------------------------------------------------------------------------------
28 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op,
29                                        bool is_input, CeedVector *e_vecs_full,
30                                        CeedVector *e_vecs, CeedVector *q_vecs,
31                                        CeedInt start_e, CeedInt num_fields,
32                                        CeedInt Q) {
33   CeedInt ierr, num_comp, size, P;
34   CeedSize e_size, q_size;
35   Ceed ceed;
36   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
37   CeedBasis basis;
38   CeedElemRestriction elem_restr;
39   CeedOperatorField *op_fields;
40   CeedQFunctionField *qf_fields;
41   if (is_input) {
42     ierr = CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL);
43     CeedChkBackend(ierr);
44     ierr = CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL);
45     CeedChkBackend(ierr);
46   } else {
47     ierr = CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields);
48     CeedChkBackend(ierr);
49     ierr = CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields);
50     CeedChkBackend(ierr);
51   }
52 
53   // Loop over fields
54   for (CeedInt i=0; i<num_fields; i++) {
55     CeedEvalMode eval_mode;
56     ierr = CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode);
57     CeedChkBackend(ierr);
58 
59     if (eval_mode != CEED_EVAL_WEIGHT) {
60       ierr = CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_restr);
61       CeedChkBackend(ierr);
62       ierr = CeedElemRestrictionCreateVector(elem_restr, NULL,
63                                              &e_vecs_full[i+start_e]);
64       CeedChkBackend(ierr);
65     }
66 
67     switch(eval_mode) {
68     case CEED_EVAL_NONE:
69       ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr);
70       q_size = (CeedSize)Q*size;
71       ierr = CeedVectorCreate(ceed, q_size, &q_vecs[i]); CeedChkBackend(ierr);
72       break;
73     case CEED_EVAL_INTERP:
74     case CEED_EVAL_GRAD:
75       ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); CeedChkBackend(ierr);
76       ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr);
77       ierr = CeedBasisGetNumNodes(basis, &P); CeedChkBackend(ierr);
78       ierr = CeedBasisGetNumComponents(basis, &num_comp); CeedChkBackend(ierr);
79       e_size = (CeedSize)P*num_comp;
80       ierr = CeedVectorCreate(ceed, e_size, &e_vecs[i]); CeedChkBackend(ierr);
81       q_size = (CeedSize)Q*size;
82       ierr = CeedVectorCreate(ceed, q_size, &q_vecs[i]); CeedChkBackend(ierr);
83       break;
84     case CEED_EVAL_WEIGHT: // Only on input fields
85       ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); CeedChkBackend(ierr);
86       q_size = (CeedSize)Q;
87       ierr = CeedVectorCreate(ceed, q_size, &q_vecs[i]); CeedChkBackend(ierr);
88       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
89                             CEED_VECTOR_NONE, q_vecs[i]); CeedChkBackend(ierr);
90       break;
91     case CEED_EVAL_DIV:
92       break; // Not implemented
93     case CEED_EVAL_CURL:
94       break; // Not implemented
95     }
96   }
97   return CEED_ERROR_SUCCESS;
98 }
99 
100 //------------------------------------------------------------------------------
101 // Setup Operator
//------------------------------------------------------------------------------
103 static int CeedOperatorSetup_Ref(CeedOperator op) {
104   int ierr;
105   bool is_setup_done;
106   ierr = CeedOperatorIsSetupDone(op, &is_setup_done); CeedChkBackend(ierr);
107   if (is_setup_done) return CEED_ERROR_SUCCESS;
108   Ceed ceed;
109   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
110   CeedOperator_Ref *impl;
111   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
112   CeedQFunction qf;
113   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
114   CeedInt Q, num_input_fields, num_output_fields;
115   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
116   ierr = CeedQFunctionIsIdentity(qf, &impl->is_identity_qf); CeedChkBackend(ierr);
117   CeedOperatorField *op_input_fields, *op_output_fields;
118   ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields,
119                                &num_output_fields, &op_output_fields);
120   CeedChkBackend(ierr);
121   CeedQFunctionField *qf_input_fields, *qf_output_fields;
122   ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL,
123                                 &qf_output_fields);
124   CeedChkBackend(ierr);
125 
126   // Allocate
127   ierr = CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full);
128   CeedChkBackend(ierr);
129 
130   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->input_states); CeedChkBackend(ierr);
131   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in); CeedChkBackend(ierr);
132   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out); CeedChkBackend(ierr);
133   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in); CeedChkBackend(ierr);
134   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out); CeedChkBackend(ierr);
135 
136   impl->num_inputs = num_input_fields;
137   impl->num_outputs = num_output_fields;
138 
139   // Set up infield and outfield e_vecs and q_vecs
140   // Infields
141   ierr = CeedOperatorSetupFields_Ref(qf, op, true, impl->e_vecs_full,
142                                      impl->e_vecs_in, impl->q_vecs_in, 0,
143                                      num_input_fields, Q);
144   CeedChkBackend(ierr);
145   // Outfields
146   ierr = CeedOperatorSetupFields_Ref(qf, op, false, impl->e_vecs_full,
147                                      impl->e_vecs_out, impl->q_vecs_out,
148                                      num_input_fields, num_output_fields, Q);
149   CeedChkBackend(ierr);
150 
151   // Identity QFunctions
152   if (impl->is_identity_qf) {
153     CeedEvalMode in_mode, out_mode;
154     CeedQFunctionField *in_fields, *out_fields;
155     ierr = CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields);
156     CeedChkBackend(ierr);
157     ierr = CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode);
158     CeedChkBackend(ierr);
159     ierr = CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode);
160     CeedChkBackend(ierr);
161 
162     if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) {
163       impl->is_identity_restr_op = true;
164     } else {
165       ierr = CeedVectorDestroy(&impl->q_vecs_out[0]); CeedChkBackend(ierr);
166       impl->q_vecs_out[0] = impl->q_vecs_in[0];
167       ierr = CeedVectorAddReference(impl->q_vecs_in[0]); CeedChkBackend(ierr);
168     }
169   }
170 
171   ierr = CeedOperatorSetSetupDone(op); CeedChkBackend(ierr);
172 
173   return CEED_ERROR_SUCCESS;
174 }
175 
176 //------------------------------------------------------------------------------
177 // Setup Operator Inputs
178 //------------------------------------------------------------------------------
179 static inline int CeedOperatorSetupInputs_Ref(CeedInt num_input_fields,
180     CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
181     CeedVector in_vec, const bool skip_active,
182     CeedScalar *e_data_full[2*CEED_FIELD_MAX],
183     CeedOperator_Ref *impl, CeedRequest *request) {
184   CeedInt ierr;
185   CeedEvalMode eval_mode;
186   CeedVector vec;
187   CeedElemRestriction elem_restr;
188   uint64_t state;
189 
190   for (CeedInt i=0; i<num_input_fields; i++) {
191     // Get input vector
192     ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
193     CeedChkBackend(ierr);
194     if (vec == CEED_VECTOR_ACTIVE) {
195       if (skip_active)
196         continue;
197       else
198         vec = in_vec;
199     }
200 
201     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
202     CeedChkBackend(ierr);
203     // Restrict and Evec
204     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
205     } else {
206       // Restrict
207       ierr = CeedVectorGetState(vec, &state); CeedChkBackend(ierr);
208       // Skip restriction if input is unchanged
209       if (state != impl->input_states[i] || vec == in_vec) {
210         ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr);
211         CeedChkBackend(ierr);
212         ierr = CeedElemRestrictionApply(elem_restr, CEED_NOTRANSPOSE, vec,
213                                         impl->e_vecs_full[i], request);
214         CeedChkBackend(ierr);
215         impl->input_states[i] = state;
216       }
217       // Get evec
218       ierr = CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST,
219                                     (const CeedScalar **) &e_data_full[i]);
220       CeedChkBackend(ierr);
221     }
222   }
223   return CEED_ERROR_SUCCESS;
224 }
225 
226 //------------------------------------------------------------------------------
227 // Input Basis Action
228 //------------------------------------------------------------------------------
229 static inline int CeedOperatorInputBasis_Ref(CeedInt e, CeedInt Q,
230     CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
231     CeedInt num_input_fields, const bool skip_active,
232     CeedScalar *e_data_full[2*CEED_FIELD_MAX], CeedOperator_Ref *impl) {
233   CeedInt ierr;
234   CeedInt dim, elem_size, size;
235   CeedElemRestriction elem_restr;
236   CeedEvalMode eval_mode;
237   CeedBasis basis;
238 
239   for (CeedInt i=0; i<num_input_fields; i++) {
240     // Skip active input
241     if (skip_active) {
242       CeedVector vec;
243       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
244       CeedChkBackend(ierr);
245       if (vec == CEED_VECTOR_ACTIVE)
246         continue;
247     }
248     // Get elem_size, eval_mode, size
249     ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr);
250     CeedChkBackend(ierr);
251     ierr = CeedElemRestrictionGetElementSize(elem_restr, &elem_size);
252     CeedChkBackend(ierr);
253     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
254     CeedChkBackend(ierr);
255     ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size);
256     CeedChkBackend(ierr);
257     // Basis action
258     switch(eval_mode) {
259     case CEED_EVAL_NONE:
260       ierr = CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST,
261                                 CEED_USE_POINTER, &e_data_full[i][e*Q*size]);
262       CeedChkBackend(ierr);
263       break;
264     case CEED_EVAL_INTERP:
265       ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis);
266       CeedChkBackend(ierr);
267       ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST,
268                                 CEED_USE_POINTER, &e_data_full[i][e*elem_size*size]);
269       CeedChkBackend(ierr);
270       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_INTERP,
271                             impl->e_vecs_in[i], impl->q_vecs_in[i]); CeedChkBackend(ierr);
272       break;
273     case CEED_EVAL_GRAD:
274       ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis);
275       CeedChkBackend(ierr);
276       ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
277       ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST,
278                                 CEED_USE_POINTER, &e_data_full[i][e*elem_size*size/dim]);
279       CeedChkBackend(ierr);
280       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
281                             CEED_EVAL_GRAD, impl->e_vecs_in[i],
282                             impl->q_vecs_in[i]); CeedChkBackend(ierr);
283       break;
284     case CEED_EVAL_WEIGHT:
285       break;  // No action
286     // LCOV_EXCL_START
287     case CEED_EVAL_DIV:
288     case CEED_EVAL_CURL: {
289       ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis);
290       CeedChkBackend(ierr);
291       Ceed ceed;
292       ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
293       return CeedError(ceed, CEED_ERROR_BACKEND,
294                        "Ceed evaluation mode not implemented");
295       // LCOV_EXCL_STOP
296     }
297     }
298   }
299   return CEED_ERROR_SUCCESS;
300 }
301 
302 //------------------------------------------------------------------------------
303 // Output Basis Action
304 //------------------------------------------------------------------------------
305 static inline int CeedOperatorOutputBasis_Ref(CeedInt e, CeedInt Q,
306     CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields,
307     CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op,
308     CeedScalar *e_data_full[2*CEED_FIELD_MAX], CeedOperator_Ref *impl) {
309   CeedInt ierr;
310   CeedInt dim, elem_size, size;
311   CeedElemRestriction elem_restr;
312   CeedEvalMode eval_mode;
313   CeedBasis basis;
314 
315   for (CeedInt i=0; i<num_output_fields; i++) {
316     // Get elem_size, eval_mode, size
317     ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr);
318     CeedChkBackend(ierr);
319     ierr = CeedElemRestrictionGetElementSize(elem_restr, &elem_size);
320     CeedChkBackend(ierr);
321     ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode);
322     CeedChkBackend(ierr);
323     ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size);
324     CeedChkBackend(ierr);
325     // Basis action
326     switch(eval_mode) {
327     case CEED_EVAL_NONE:
328       break; // No action
329     case CEED_EVAL_INTERP:
330       ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis);
331       CeedChkBackend(ierr);
332       ierr = CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST,
333                                 CEED_USE_POINTER,
334                                 &e_data_full[i + num_input_fields][e*elem_size*size]);
335       CeedChkBackend(ierr);
336       ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
337                             CEED_EVAL_INTERP, impl->q_vecs_out[i],
338                             impl->e_vecs_out[i]); CeedChkBackend(ierr);
339       break;
340     case CEED_EVAL_GRAD:
341       ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis);
342       CeedChkBackend(ierr);
343       ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
344       ierr = CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST,
345                                 CEED_USE_POINTER,
346                                 &e_data_full[i + num_input_fields][e*elem_size*size/dim]);
347       CeedChkBackend(ierr);
348       ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
349                             CEED_EVAL_GRAD, impl->q_vecs_out[i],
350                             impl->e_vecs_out[i]); CeedChkBackend(ierr);
351       break;
352     // LCOV_EXCL_START
353     case CEED_EVAL_WEIGHT: {
354       Ceed ceed;
355       ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
356       return CeedError(ceed, CEED_ERROR_BACKEND,
357                        "CEED_EVAL_WEIGHT cannot be an output "
358                        "evaluation mode");
359     }
360     case CEED_EVAL_DIV:
361     case CEED_EVAL_CURL: {
362       Ceed ceed;
363       ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
364       return CeedError(ceed, CEED_ERROR_BACKEND,
365                        "Ceed evaluation mode not implemented");
366       // LCOV_EXCL_STOP
367     }
368     }
369   }
370   return CEED_ERROR_SUCCESS;
371 }
372 
373 //------------------------------------------------------------------------------
374 // Restore Input Vectors
375 //------------------------------------------------------------------------------
376 static inline int CeedOperatorRestoreInputs_Ref(CeedInt num_input_fields,
377     CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
378     const bool skip_active, CeedScalar *e_data_full[2*CEED_FIELD_MAX],
379     CeedOperator_Ref *impl) {
380   CeedInt ierr;
381   CeedEvalMode eval_mode;
382 
383   for (CeedInt i=0; i<num_input_fields; i++) {
384     // Skip active inputs
385     if (skip_active) {
386       CeedVector vec;
387       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
388       CeedChkBackend(ierr);
389       if (vec == CEED_VECTOR_ACTIVE)
390         continue;
391     }
392     // Restore input
393     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
394     CeedChkBackend(ierr);
395     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
396     } else {
397       ierr = CeedVectorRestoreArrayRead(impl->e_vecs_full[i],
398                                         (const CeedScalar **) &e_data_full[i]);
399       CeedChkBackend(ierr);
400     }
401   }
402   return CEED_ERROR_SUCCESS;
403 }
404 
405 //------------------------------------------------------------------------------
406 // Operator Apply
407 //------------------------------------------------------------------------------
408 static int CeedOperatorApplyAdd_Ref(CeedOperator op, CeedVector in_vec,
409                                     CeedVector out_vec, CeedRequest *request) {
410   int ierr;
411   CeedOperator_Ref *impl;
412   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
413   CeedQFunction qf;
414   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
415   CeedInt Q, num_elem, num_input_fields, num_output_fields, size;
416   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
417   ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr);
418   CeedOperatorField *op_input_fields, *op_output_fields;
419   ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields,
420                                &num_output_fields, &op_output_fields);
421   CeedChkBackend(ierr);
422   CeedQFunctionField *qf_input_fields, *qf_output_fields;
423   ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL,
424                                 &qf_output_fields);
425   CeedChkBackend(ierr);
426   CeedEvalMode eval_mode;
427   CeedVector vec;
428   CeedElemRestriction elem_restr;
429   CeedScalar *e_data_full[2*CEED_FIELD_MAX] = {0};
430 
431   // Setup
432   ierr = CeedOperatorSetup_Ref(op); CeedChkBackend(ierr);
433 
434   // Restriction only operator
435   if (impl->is_identity_restr_op) {
436     ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[0], &elem_restr);
437     CeedChkBackend(ierr);
438     ierr = CeedElemRestrictionApply(elem_restr, CEED_NOTRANSPOSE, in_vec,
439                                     impl->e_vecs_full[0], request);
440     CeedChkBackend(ierr);
441     ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[0], &elem_restr);
442     CeedChkBackend(ierr);
443     ierr = CeedElemRestrictionApply(elem_restr, CEED_TRANSPOSE,
444                                     impl->e_vecs_full[0],
445                                     out_vec, request); CeedChkBackend(ierr);
446     return CEED_ERROR_SUCCESS;
447   }
448 
449   // Input Evecs and Restriction
450   ierr = CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields,
451                                      op_input_fields, in_vec, false, e_data_full, impl,
452                                      request); CeedChkBackend(ierr);
453 
454   // Output Evecs
455   for (CeedInt i=0; i<num_output_fields; i++) {
456     ierr = CeedVectorGetArrayWrite(impl->e_vecs_full[i+impl->num_inputs],
457                                    CEED_MEM_HOST, &e_data_full[i + num_input_fields]);
458     CeedChkBackend(ierr);
459   }
460 
461   // Loop through elements
462   for (CeedInt e=0; e<num_elem; e++) {
463     // Output pointers
464     for (CeedInt i=0; i<num_output_fields; i++) {
465       ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode);
466       CeedChkBackend(ierr);
467       if (eval_mode == CEED_EVAL_NONE) {
468         ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size);
469         CeedChkBackend(ierr);
470         ierr = CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST,
471                                   CEED_USE_POINTER,
472                                   &e_data_full[i + num_input_fields][e*Q*size]);
473         CeedChkBackend(ierr);
474       }
475     }
476 
477     // Input basis apply
478     ierr = CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields,
479                                       num_input_fields, false, e_data_full, impl);
480     CeedChkBackend(ierr);
481 
482     // Q function
483     if (!impl->is_identity_qf) {
484       ierr = CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out);
485       CeedChkBackend(ierr);
486     }
487 
488     // Output basis apply
489     ierr = CeedOperatorOutputBasis_Ref(e, Q, qf_output_fields, op_output_fields,
490                                        num_input_fields, num_output_fields, op,
491                                        e_data_full, impl); CeedChkBackend(ierr);
492   }
493 
494   // Output restriction
495   for (CeedInt i=0; i<num_output_fields; i++) {
496     // Restore Evec
497     ierr = CeedVectorRestoreArray(impl->e_vecs_full[i+impl->num_inputs],
498                                   &e_data_full[i + num_input_fields]);
499     CeedChkBackend(ierr);
500     // Get output vector
501     ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec);
502     CeedChkBackend(ierr);
503     // Active
504     if (vec == CEED_VECTOR_ACTIVE)
505       vec = out_vec;
506     // Restrict
507     ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr);
508     CeedChkBackend(ierr);
509     ierr = CeedElemRestrictionApply(elem_restr, CEED_TRANSPOSE,
510                                     impl->e_vecs_full[i+impl->num_inputs],
511                                     vec, request); CeedChkBackend(ierr);
512   }
513 
514   // Restore input arrays
515   ierr = CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields,
516                                        op_input_fields, false, e_data_full, impl);
517   CeedChkBackend(ierr);
518 
519   return CEED_ERROR_SUCCESS;
520 }
521 
522 //------------------------------------------------------------------------------
523 // Core code for assembling linear QFunction
524 //------------------------------------------------------------------------------
525 static inline int CeedOperatorLinearAssembleQFunctionCore_Ref(CeedOperator op,
526     bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr,
527     CeedRequest *request) {
528   int ierr;
529   CeedOperator_Ref *impl;
530   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
531   CeedQFunction qf;
532   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
533   CeedInt Q, num_elem, num_input_fields, num_output_fields, size;
534   CeedSize q_size;
535   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
536   ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr);
537   CeedOperatorField *op_input_fields, *op_output_fields;
538   ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields,
539                                &num_output_fields, &op_output_fields);
540   CeedChkBackend(ierr);
541   CeedQFunctionField *qf_input_fields, *qf_output_fields;
542   ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL,
543                                 &qf_output_fields);
544   CeedChkBackend(ierr);
545   CeedVector vec;
546   CeedInt num_active_in = impl->num_active_in,
547           num_active_out = impl->num_active_out;
548   CeedVector *active_in = impl->qf_active_in;
549   CeedScalar *a, *tmp;
550   Ceed ceed, ceed_parent;
551   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
552   ierr = CeedGetOperatorFallbackParentCeed(ceed, &ceed_parent);
553   CeedChkBackend(ierr);
554   ceed_parent = ceed_parent ? ceed_parent : ceed;
555   CeedScalar *e_data_full[2*CEED_FIELD_MAX] = {0};
556 
557   // Setup
558   ierr = CeedOperatorSetup_Ref(op); CeedChkBackend(ierr);
559 
560   // Check for identity
561   if (impl->is_identity_qf)
562     // LCOV_EXCL_START
563     return CeedError(ceed, CEED_ERROR_BACKEND,
564                      "Assembling identity QFunctions not supported");
565   // LCOV_EXCL_STOP
566 
567   // Input Evecs and Restriction
568   ierr = CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields,
569                                      op_input_fields, NULL, true, e_data_full,
570                                      impl, request); CeedChkBackend(ierr);
571 
572   // Count number of active input fields
573   if (!num_active_in) {
574     for (CeedInt i=0; i<num_input_fields; i++) {
575       // Get input vector
576       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
577       CeedChkBackend(ierr);
578       // Check if active input
579       if (vec == CEED_VECTOR_ACTIVE) {
580         ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size);
581         CeedChkBackend(ierr);
582         ierr = CeedVectorSetValue(impl->q_vecs_in[i], 0.0); CeedChkBackend(ierr);
583         ierr = CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &tmp);
584         CeedChkBackend(ierr);
585         ierr = CeedRealloc(num_active_in + size, &active_in); CeedChkBackend(ierr);
586         for (CeedInt field=0; field<size; field++) {
587           q_size = (CeedSize)Q;
588           ierr = CeedVectorCreate(ceed, q_size, &active_in[num_active_in+field]);
589           CeedChkBackend(ierr);
590           ierr = CeedVectorSetArray(active_in[num_active_in+field], CEED_MEM_HOST,
591                                     CEED_USE_POINTER, &tmp[field*Q]);
592           CeedChkBackend(ierr);
593         }
594         num_active_in += size;
595         ierr = CeedVectorRestoreArray(impl->q_vecs_in[i], &tmp); CeedChkBackend(ierr);
596       }
597     }
598     impl->num_active_in = num_active_in;
599     impl->qf_active_in = active_in;
600   }
601 
602   // Count number of active output fields
603   if (!num_active_out) {
604     for (CeedInt i=0; i<num_output_fields; i++) {
605       // Get output vector
606       ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec);
607       CeedChkBackend(ierr);
608       // Check if active output
609       if (vec == CEED_VECTOR_ACTIVE) {
610         ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size);
611         CeedChkBackend(ierr);
612         num_active_out += size;
613       }
614     }
615     impl->num_active_out = num_active_out;
616   }
617 
618   // Check sizes
619   if (!num_active_in || !num_active_out)
620     // LCOV_EXCL_START
621     return CeedError(ceed, CEED_ERROR_BACKEND,
622                      "Cannot assemble QFunction without active inputs "
623                      "and outputs");
624   // LCOV_EXCL_STOP
625 
626   // Build objects if needed
627   if (build_objects) {
628     // Create output restriction
629     CeedInt strides[3] = {1, Q, num_active_in*num_active_out*Q}; /* *NOPAD* */
630     ierr = CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q,
631                                             num_active_in*num_active_out,
632                                             num_active_in*num_active_out*num_elem*Q,
633                                             strides, rstr); CeedChkBackend(ierr);
634     // Create assembled vector
635     CeedSize l_size = (CeedSize)num_elem*Q*num_active_in*num_active_out;
636     ierr = CeedVectorCreate(ceed_parent, l_size, assembled); CeedChkBackend(ierr);
637   }
638   // Clear output vector
639   ierr = CeedVectorSetValue(*assembled, 0.0); CeedChkBackend(ierr);
640   ierr = CeedVectorGetArray(*assembled, CEED_MEM_HOST, &a); CeedChkBackend(ierr);
641 
642   // Loop through elements
643   for (CeedInt e=0; e<num_elem; e++) {
644     // Input basis apply
645     ierr = CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields,
646                                       num_input_fields, true, e_data_full, impl);
647     CeedChkBackend(ierr);
648 
649     // Assemble QFunction
650     for (CeedInt in=0; in<num_active_in; in++) {
651       // Set Inputs
652       ierr = CeedVectorSetValue(active_in[in], 1.0); CeedChkBackend(ierr);
653       if (num_active_in > 1) {
654         ierr = CeedVectorSetValue(active_in[(in+num_active_in-1)%num_active_in],
655                                   0.0); CeedChkBackend(ierr);
656       }
657       // Set Outputs
658       for (CeedInt out=0; out<num_output_fields; out++) {
659         // Get output vector
660         ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec);
661         CeedChkBackend(ierr);
662         // Check if active output
663         if (vec == CEED_VECTOR_ACTIVE) {
664           CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST,
665                              CEED_USE_POINTER, a); CeedChkBackend(ierr);
666           ierr = CeedQFunctionFieldGetSize(qf_output_fields[out], &size);
667           CeedChkBackend(ierr);
668           a += size*Q; // Advance the pointer by the size of the output
669         }
670       }
671       // Apply QFunction
672       ierr = CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out);
673       CeedChkBackend(ierr);
674     }
675   }
676 
677   // Un-set output Qvecs to prevent accidental overwrite of Assembled
678   for (CeedInt out=0; out<num_output_fields; out++) {
679     // Get output vector
680     ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec);
681     CeedChkBackend(ierr);
682     // Check if active output
683     if (vec == CEED_VECTOR_ACTIVE) {
684       CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL);
685       CeedChkBackend(ierr);
686     }
687   }
688 
689   // Restore input arrays
690   ierr = CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields,
691                                        op_input_fields, true, e_data_full, impl);
692   CeedChkBackend(ierr);
693 
694   // Restore output
695   ierr = CeedVectorRestoreArray(*assembled, &a); CeedChkBackend(ierr);
696 
697   return CEED_ERROR_SUCCESS;
698 }
699 
700 //------------------------------------------------------------------------------
701 // Assemble Linear QFunction
702 //------------------------------------------------------------------------------
703 static int CeedOperatorLinearAssembleQFunction_Ref(CeedOperator op,
704     CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
705   return CeedOperatorLinearAssembleQFunctionCore_Ref(op, true, assembled, rstr,
706          request);
707 }
708 
709 //------------------------------------------------------------------------------
710 // Update Assembled Linear QFunction
711 //------------------------------------------------------------------------------
712 static int CeedOperatorLinearAssembleQFunctionUpdate_Ref(CeedOperator op,
713     CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
714   return CeedOperatorLinearAssembleQFunctionCore_Ref(op, false, &assembled,
715          &rstr, request);
716 }
717 
718 //------------------------------------------------------------------------------
719 // Operator Destroy
720 //------------------------------------------------------------------------------
721 static int CeedOperatorDestroy_Ref(CeedOperator op) {
722   int ierr;
723   CeedOperator_Ref *impl;
724   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
725 
726   for (CeedInt i=0; i<impl->num_inputs+impl->num_outputs; i++) {
727     ierr = CeedVectorDestroy(&impl->e_vecs_full[i]); CeedChkBackend(ierr);
728   }
729   ierr = CeedFree(&impl->e_vecs_full); CeedChkBackend(ierr);
730   ierr = CeedFree(&impl->input_states); CeedChkBackend(ierr);
731 
732   for (CeedInt i=0; i<impl->num_inputs; i++) {
733     ierr = CeedVectorDestroy(&impl->e_vecs_in[i]); CeedChkBackend(ierr);
734     ierr = CeedVectorDestroy(&impl->q_vecs_in[i]); CeedChkBackend(ierr);
735   }
736   ierr = CeedFree(&impl->e_vecs_in); CeedChkBackend(ierr);
737   ierr = CeedFree(&impl->q_vecs_in); CeedChkBackend(ierr);
738 
739   for (CeedInt i=0; i<impl->num_outputs; i++) {
740     ierr = CeedVectorDestroy(&impl->e_vecs_out[i]); CeedChkBackend(ierr);
741     ierr = CeedVectorDestroy(&impl->q_vecs_out[i]); CeedChkBackend(ierr);
742   }
743   ierr = CeedFree(&impl->e_vecs_out); CeedChkBackend(ierr);
744   ierr = CeedFree(&impl->q_vecs_out); CeedChkBackend(ierr);
745 
746   // QFunction assembly
747   for (CeedInt i=0; i<impl->num_active_in; i++) {
748     ierr = CeedVectorDestroy(&impl->qf_active_in[i]); CeedChkBackend(ierr);
749   }
750   ierr = CeedFree(&impl->qf_active_in); CeedChkBackend(ierr);
751 
752   ierr = CeedFree(&impl); CeedChkBackend(ierr);
753   return CEED_ERROR_SUCCESS;
754 }
755 
756 //------------------------------------------------------------------------------
757 // Operator Create
758 //------------------------------------------------------------------------------
759 int CeedOperatorCreate_Ref(CeedOperator op) {
760   int ierr;
761   Ceed ceed;
762   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
763   CeedOperator_Ref *impl;
764 
765   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
766   ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr);
767 
768   ierr = CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction",
769                                 CeedOperatorLinearAssembleQFunction_Ref);
770   CeedChkBackend(ierr);
771   ierr = CeedSetBackendFunction(ceed, "Operator", op,
772                                 "LinearAssembleQFunctionUpdate",
773                                 CeedOperatorLinearAssembleQFunctionUpdate_Ref);
774   CeedChkBackend(ierr);
775   ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd",
776                                 CeedOperatorApplyAdd_Ref); CeedChkBackend(ierr);
777   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
778                                 CeedOperatorDestroy_Ref); CeedChkBackend(ierr);
779   return CEED_ERROR_SUCCESS;
780 }
781