xref: /libCEED/backends/ref/ceed-ref-operator.c (revision 6d69246ad8d06846a22b2f7e3720a77f7a77e7f6)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <math.h>
20 #include <stdbool.h>
21 #include <stddef.h>
22 #include <stdint.h>
23 #include "ceed-ref.h"
24 
25 //------------------------------------------------------------------------------
26 // Setup Input/Output Fields
27 //------------------------------------------------------------------------------
28 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op,
29                                        bool is_input, CeedVector *e_vecs_full,
30                                        CeedVector *e_vecs, CeedVector *q_vecs,
31                                        CeedInt start_e, CeedInt num_fields,
32                                        CeedInt Q) {
33   CeedInt ierr, num_comp, size, P;
34   Ceed ceed;
35   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
36   CeedBasis basis;
37   CeedElemRestriction elem_restr;
38   CeedOperatorField *op_fields;
39   CeedQFunctionField *qf_fields;
40   if (is_input) {
41     ierr = CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL);
42     CeedChkBackend(ierr);
43     ierr = CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL);
44     CeedChkBackend(ierr);
45   } else {
46     ierr = CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields);
47     CeedChkBackend(ierr);
48     ierr = CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields);
49     CeedChkBackend(ierr);
50   }
51 
52   // Loop over fields
53   for (CeedInt i=0; i<num_fields; i++) {
54     CeedEvalMode eval_mode;
55     ierr = CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode);
56     CeedChkBackend(ierr);
57 
58     if (eval_mode != CEED_EVAL_WEIGHT) {
59       ierr = CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_restr);
60       CeedChkBackend(ierr);
61       ierr = CeedElemRestrictionCreateVector(elem_restr, NULL,
62                                              &e_vecs_full[i+start_e]);
63       CeedChkBackend(ierr);
64     }
65 
66     switch(eval_mode) {
67     case CEED_EVAL_NONE:
68       ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr);
69       ierr = CeedVectorCreate(ceed, Q*size, &q_vecs[i]); CeedChkBackend(ierr);
70       break;
71     case CEED_EVAL_INTERP:
72     case CEED_EVAL_GRAD:
73       ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); CeedChkBackend(ierr);
74       ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr);
75       ierr = CeedBasisGetNumNodes(basis, &P); CeedChkBackend(ierr);
76       ierr = CeedBasisGetNumComponents(basis, &num_comp); CeedChkBackend(ierr);
77       ierr = CeedVectorCreate(ceed, P*num_comp, &e_vecs[i]); CeedChkBackend(ierr);
78       ierr = CeedVectorCreate(ceed, Q*size, &q_vecs[i]); CeedChkBackend(ierr);
79       break;
80     case CEED_EVAL_WEIGHT: // Only on input fields
81       ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); CeedChkBackend(ierr);
82       ierr = CeedVectorCreate(ceed, Q, &q_vecs[i]); CeedChkBackend(ierr);
83       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
84                             CEED_VECTOR_NONE, q_vecs[i]); CeedChkBackend(ierr);
85       break;
86     case CEED_EVAL_DIV:
87       break; // Not implemented
88     case CEED_EVAL_CURL:
89       break; // Not implemented
90     }
91   }
92   return CEED_ERROR_SUCCESS;
93 }
94 
95 //------------------------------------------------------------------------------
96 // Setup Operator
97 //------------------------------------------------------------------------------
98 static int CeedOperatorSetup_Ref(CeedOperator op) {
99   int ierr;
100   bool is_setup_done;
101   ierr = CeedOperatorIsSetupDone(op, &is_setup_done); CeedChkBackend(ierr);
102   if (is_setup_done) return CEED_ERROR_SUCCESS;
103   Ceed ceed;
104   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
105   CeedOperator_Ref *impl;
106   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
107   CeedQFunction qf;
108   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
109   CeedInt Q, num_input_fields, num_output_fields;
110   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
111   ierr = CeedQFunctionIsIdentity(qf, &impl->is_identity_qf); CeedChkBackend(ierr);
112   CeedOperatorField *op_input_fields, *op_output_fields;
113   ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields,
114                                &num_output_fields, &op_output_fields);
115   CeedChkBackend(ierr);
116   CeedQFunctionField *qf_input_fields, *qf_output_fields;
117   ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL,
118                                 &qf_output_fields);
119   CeedChkBackend(ierr);
120 
121   // Allocate
122   ierr = CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full);
123   CeedChkBackend(ierr);
124 
125   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->input_states); CeedChkBackend(ierr);
126   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in); CeedChkBackend(ierr);
127   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out); CeedChkBackend(ierr);
128   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in); CeedChkBackend(ierr);
129   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out); CeedChkBackend(ierr);
130 
131   impl->num_inputs = num_input_fields;
132   impl->num_outputs = num_output_fields;
133 
134   // Set up infield and outfield e_vecs and q_vecs
135   // Infields
136   ierr = CeedOperatorSetupFields_Ref(qf, op, true, impl->e_vecs_full,
137                                      impl->e_vecs_in, impl->q_vecs_in, 0,
138                                      num_input_fields, Q);
139   CeedChkBackend(ierr);
140   // Outfields
141   ierr = CeedOperatorSetupFields_Ref(qf, op, false, impl->e_vecs_full,
142                                      impl->e_vecs_out, impl->q_vecs_out,
143                                      num_input_fields, num_output_fields, Q);
144   CeedChkBackend(ierr);
145 
146   // Identity QFunctions
147   if (impl->is_identity_qf) {
148     CeedEvalMode in_mode, out_mode;
149     CeedQFunctionField *in_fields, *out_fields;
150     ierr = CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields);
151     CeedChkBackend(ierr);
152     ierr = CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode);
153     CeedChkBackend(ierr);
154     ierr = CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode);
155     CeedChkBackend(ierr);
156 
157     if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) {
158       impl->is_identity_restr_op = true;
159     } else {
160       ierr = CeedVectorDestroy(&impl->q_vecs_out[0]); CeedChkBackend(ierr);
161       impl->q_vecs_out[0] = impl->q_vecs_in[0];
162       ierr = CeedVectorAddReference(impl->q_vecs_in[0]); CeedChkBackend(ierr);
163     }
164   }
165 
166   ierr = CeedOperatorSetSetupDone(op); CeedChkBackend(ierr);
167 
168   return CEED_ERROR_SUCCESS;
169 }
170 
171 //------------------------------------------------------------------------------
172 // Setup Operator Inputs
173 //------------------------------------------------------------------------------
174 static inline int CeedOperatorSetupInputs_Ref(CeedInt num_input_fields,
175     CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
176     CeedVector in_vec, const bool skip_active,
177     CeedScalar *e_data_full[2*CEED_FIELD_MAX],
178     CeedOperator_Ref *impl, CeedRequest *request) {
179   CeedInt ierr;
180   CeedEvalMode eval_mode;
181   CeedVector vec;
182   CeedElemRestriction elem_restr;
183   uint64_t state;
184 
185   for (CeedInt i=0; i<num_input_fields; i++) {
186     // Get input vector
187     ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
188     CeedChkBackend(ierr);
189     if (vec == CEED_VECTOR_ACTIVE) {
190       if (skip_active)
191         continue;
192       else
193         vec = in_vec;
194     }
195 
196     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
197     CeedChkBackend(ierr);
198     // Restrict and Evec
199     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
200     } else {
201       // Restrict
202       ierr = CeedVectorGetState(vec, &state); CeedChkBackend(ierr);
203       // Skip restriction if input is unchanged
204       if (state != impl->input_states[i] || vec == in_vec) {
205         ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr);
206         CeedChkBackend(ierr);
207         ierr = CeedElemRestrictionApply(elem_restr, CEED_NOTRANSPOSE, vec,
208                                         impl->e_vecs_full[i], request);
209         CeedChkBackend(ierr);
210         impl->input_states[i] = state;
211       }
212       // Get evec
213       ierr = CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST,
214                                     (const CeedScalar **) &e_data_full[i]);
215       CeedChkBackend(ierr);
216     }
217   }
218   return CEED_ERROR_SUCCESS;
219 }
220 
221 //------------------------------------------------------------------------------
222 // Input Basis Action
223 //------------------------------------------------------------------------------
224 static inline int CeedOperatorInputBasis_Ref(CeedInt e, CeedInt Q,
225     CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
226     CeedInt num_input_fields, const bool skip_active,
227     CeedScalar *e_data_full[2*CEED_FIELD_MAX], CeedOperator_Ref *impl) {
228   CeedInt ierr;
229   CeedInt dim, elem_size, size;
230   CeedElemRestriction elem_restr;
231   CeedEvalMode eval_mode;
232   CeedBasis basis;
233 
234   for (CeedInt i=0; i<num_input_fields; i++) {
235     // Skip active input
236     if (skip_active) {
237       CeedVector vec;
238       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
239       CeedChkBackend(ierr);
240       if (vec == CEED_VECTOR_ACTIVE)
241         continue;
242     }
243     // Get elem_size, eval_mode, size
244     ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr);
245     CeedChkBackend(ierr);
246     ierr = CeedElemRestrictionGetElementSize(elem_restr, &elem_size);
247     CeedChkBackend(ierr);
248     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
249     CeedChkBackend(ierr);
250     ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size);
251     CeedChkBackend(ierr);
252     // Basis action
253     switch(eval_mode) {
254     case CEED_EVAL_NONE:
255       ierr = CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST,
256                                 CEED_USE_POINTER, &e_data_full[i][e*Q*size]);
257       CeedChkBackend(ierr);
258       break;
259     case CEED_EVAL_INTERP:
260       ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis);
261       CeedChkBackend(ierr);
262       ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST,
263                                 CEED_USE_POINTER, &e_data_full[i][e*elem_size*size]);
264       CeedChkBackend(ierr);
265       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_INTERP,
266                             impl->e_vecs_in[i], impl->q_vecs_in[i]); CeedChkBackend(ierr);
267       break;
268     case CEED_EVAL_GRAD:
269       ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis);
270       CeedChkBackend(ierr);
271       ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
272       ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST,
273                                 CEED_USE_POINTER, &e_data_full[i][e*elem_size*size/dim]);
274       CeedChkBackend(ierr);
275       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
276                             CEED_EVAL_GRAD, impl->e_vecs_in[i],
277                             impl->q_vecs_in[i]); CeedChkBackend(ierr);
278       break;
279     case CEED_EVAL_WEIGHT:
280       break;  // No action
281     // LCOV_EXCL_START
282     case CEED_EVAL_DIV:
283     case CEED_EVAL_CURL: {
284       ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis);
285       CeedChkBackend(ierr);
286       Ceed ceed;
287       ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
288       return CeedError(ceed, CEED_ERROR_BACKEND,
289                        "Ceed evaluation mode not implemented");
290       // LCOV_EXCL_STOP
291     }
292     }
293   }
294   return CEED_ERROR_SUCCESS;
295 }
296 
297 //------------------------------------------------------------------------------
298 // Output Basis Action
299 //------------------------------------------------------------------------------
300 static inline int CeedOperatorOutputBasis_Ref(CeedInt e, CeedInt Q,
301     CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields,
302     CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op,
303     CeedScalar *e_data_full[2*CEED_FIELD_MAX], CeedOperator_Ref *impl) {
304   CeedInt ierr;
305   CeedInt dim, elem_size, size;
306   CeedElemRestriction elem_restr;
307   CeedEvalMode eval_mode;
308   CeedBasis basis;
309 
310   for (CeedInt i=0; i<num_output_fields; i++) {
311     // Get elem_size, eval_mode, size
312     ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr);
313     CeedChkBackend(ierr);
314     ierr = CeedElemRestrictionGetElementSize(elem_restr, &elem_size);
315     CeedChkBackend(ierr);
316     ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode);
317     CeedChkBackend(ierr);
318     ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size);
319     CeedChkBackend(ierr);
320     // Basis action
321     switch(eval_mode) {
322     case CEED_EVAL_NONE:
323       break; // No action
324     case CEED_EVAL_INTERP:
325       ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis);
326       CeedChkBackend(ierr);
327       ierr = CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST,
328                                 CEED_USE_POINTER,
329                                 &e_data_full[i + num_input_fields][e*elem_size*size]);
330       CeedChkBackend(ierr);
331       ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
332                             CEED_EVAL_INTERP, impl->q_vecs_out[i],
333                             impl->e_vecs_out[i]); CeedChkBackend(ierr);
334       break;
335     case CEED_EVAL_GRAD:
336       ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis);
337       CeedChkBackend(ierr);
338       ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
339       ierr = CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST,
340                                 CEED_USE_POINTER,
341                                 &e_data_full[i + num_input_fields][e*elem_size*size/dim]);
342       CeedChkBackend(ierr);
343       ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
344                             CEED_EVAL_GRAD, impl->q_vecs_out[i],
345                             impl->e_vecs_out[i]); CeedChkBackend(ierr);
346       break;
347     // LCOV_EXCL_START
348     case CEED_EVAL_WEIGHT: {
349       Ceed ceed;
350       ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
351       return CeedError(ceed, CEED_ERROR_BACKEND,
352                        "CEED_EVAL_WEIGHT cannot be an output "
353                        "evaluation mode");
354     }
355     case CEED_EVAL_DIV:
356     case CEED_EVAL_CURL: {
357       Ceed ceed;
358       ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
359       return CeedError(ceed, CEED_ERROR_BACKEND,
360                        "Ceed evaluation mode not implemented");
361       // LCOV_EXCL_STOP
362     }
363     }
364   }
365   return CEED_ERROR_SUCCESS;
366 }
367 
368 //------------------------------------------------------------------------------
369 // Restore Input Vectors
370 //------------------------------------------------------------------------------
371 static inline int CeedOperatorRestoreInputs_Ref(CeedInt num_input_fields,
372     CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
373     const bool skip_active, CeedScalar *e_data_full[2*CEED_FIELD_MAX],
374     CeedOperator_Ref *impl) {
375   CeedInt ierr;
376   CeedEvalMode eval_mode;
377 
378   for (CeedInt i=0; i<num_input_fields; i++) {
379     // Skip active inputs
380     if (skip_active) {
381       CeedVector vec;
382       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
383       CeedChkBackend(ierr);
384       if (vec == CEED_VECTOR_ACTIVE)
385         continue;
386     }
387     // Restore input
388     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
389     CeedChkBackend(ierr);
390     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
391     } else {
392       ierr = CeedVectorRestoreArrayRead(impl->e_vecs_full[i],
393                                         (const CeedScalar **) &e_data_full[i]);
394       CeedChkBackend(ierr);
395     }
396   }
397   return CEED_ERROR_SUCCESS;
398 }
399 
400 //------------------------------------------------------------------------------
401 // Operator Apply
402 //------------------------------------------------------------------------------
403 static int CeedOperatorApplyAdd_Ref(CeedOperator op, CeedVector in_vec,
404                                     CeedVector out_vec, CeedRequest *request) {
405   int ierr;
406   CeedOperator_Ref *impl;
407   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
408   CeedQFunction qf;
409   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
410   CeedInt Q, num_elem, num_input_fields, num_output_fields, size;
411   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
412   ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr);
413   CeedOperatorField *op_input_fields, *op_output_fields;
414   ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields,
415                                &num_output_fields, &op_output_fields);
416   CeedChkBackend(ierr);
417   CeedQFunctionField *qf_input_fields, *qf_output_fields;
418   ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL,
419                                 &qf_output_fields);
420   CeedChkBackend(ierr);
421   CeedEvalMode eval_mode;
422   CeedVector vec;
423   CeedElemRestriction elem_restr;
424   CeedScalar *e_data_full[2*CEED_FIELD_MAX] = {0};
425 
426   // Setup
427   ierr = CeedOperatorSetup_Ref(op); CeedChkBackend(ierr);
428 
429   // Restriction only operator
430   if (impl->is_identity_restr_op) {
431     ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[0], &elem_restr);
432     CeedChkBackend(ierr);
433     ierr = CeedElemRestrictionApply(elem_restr, CEED_NOTRANSPOSE, in_vec,
434                                     impl->e_vecs_full[0], request);
435     CeedChkBackend(ierr);
436     ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[0], &elem_restr);
437     CeedChkBackend(ierr);
438     ierr = CeedElemRestrictionApply(elem_restr, CEED_TRANSPOSE,
439                                     impl->e_vecs_full[0],
440                                     out_vec, request); CeedChkBackend(ierr);
441     return CEED_ERROR_SUCCESS;
442   }
443 
444   // Input Evecs and Restriction
445   ierr = CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields,
446                                      op_input_fields, in_vec, false, e_data_full, impl,
447                                      request); CeedChkBackend(ierr);
448 
449   // Output Evecs
450   for (CeedInt i=0; i<num_output_fields; i++) {
451     ierr = CeedVectorGetArrayWrite(impl->e_vecs_full[i+impl->num_inputs],
452                                    CEED_MEM_HOST, &e_data_full[i + num_input_fields]);
453     CeedChkBackend(ierr);
454   }
455 
456   // Loop through elements
457   for (CeedInt e=0; e<num_elem; e++) {
458     // Output pointers
459     for (CeedInt i=0; i<num_output_fields; i++) {
460       ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode);
461       CeedChkBackend(ierr);
462       if (eval_mode == CEED_EVAL_NONE) {
463         ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size);
464         CeedChkBackend(ierr);
465         ierr = CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST,
466                                   CEED_USE_POINTER,
467                                   &e_data_full[i + num_input_fields][e*Q*size]);
468         CeedChkBackend(ierr);
469       }
470     }
471 
472     // Input basis apply
473     ierr = CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields,
474                                       num_input_fields, false, e_data_full, impl);
475     CeedChkBackend(ierr);
476 
477     // Q function
478     if (!impl->is_identity_qf) {
479       ierr = CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out);
480       CeedChkBackend(ierr);
481     }
482 
483     // Output basis apply
484     ierr = CeedOperatorOutputBasis_Ref(e, Q, qf_output_fields, op_output_fields,
485                                        num_input_fields, num_output_fields, op,
486                                        e_data_full, impl); CeedChkBackend(ierr);
487   }
488 
489   // Output restriction
490   for (CeedInt i=0; i<num_output_fields; i++) {
491     // Restore Evec
492     ierr = CeedVectorRestoreArray(impl->e_vecs_full[i+impl->num_inputs],
493                                   &e_data_full[i + num_input_fields]);
494     CeedChkBackend(ierr);
495     // Get output vector
496     ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec);
497     CeedChkBackend(ierr);
498     // Active
499     if (vec == CEED_VECTOR_ACTIVE)
500       vec = out_vec;
501     // Restrict
502     ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr);
503     CeedChkBackend(ierr);
504     ierr = CeedElemRestrictionApply(elem_restr, CEED_TRANSPOSE,
505                                     impl->e_vecs_full[i+impl->num_inputs],
506                                     vec, request); CeedChkBackend(ierr);
507   }
508 
509   // Restore input arrays
510   ierr = CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields,
511                                        op_input_fields, false, e_data_full, impl);
512   CeedChkBackend(ierr);
513 
514   return CEED_ERROR_SUCCESS;
515 }
516 
517 //------------------------------------------------------------------------------
518 // Core code for assembling linear QFunction
519 //------------------------------------------------------------------------------
520 static inline int CeedOperatorLinearAssembleQFunctionCore_Ref(CeedOperator op,
521     bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr,
522     CeedRequest *request) {
523   int ierr;
524   CeedOperator_Ref *impl;
525   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
526   CeedQFunction qf;
527   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
528   CeedInt Q, num_elem, num_input_fields, num_output_fields, size;
529   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
530   ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr);
531   CeedOperatorField *op_input_fields, *op_output_fields;
532   ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields,
533                                &num_output_fields, &op_output_fields);
534   CeedChkBackend(ierr);
535   CeedQFunctionField *qf_input_fields, *qf_output_fields;
536   ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL,
537                                 &qf_output_fields);
538   CeedChkBackend(ierr);
539   CeedVector vec;
540   CeedInt num_active_in = impl->num_active_in,
541           num_active_out = impl->num_active_out;
542   CeedVector *active_in = impl->qf_active_in;
543   CeedScalar *a, *tmp;
544   Ceed ceed, ceed_parent;
545   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
546   ierr = CeedGetOperatorFallbackParentCeed(ceed, &ceed_parent);
547   CeedChkBackend(ierr);
548   ceed_parent = ceed_parent ? ceed_parent : ceed;
549   CeedScalar *e_data_full[2*CEED_FIELD_MAX] = {0};
550 
551   // Setup
552   ierr = CeedOperatorSetup_Ref(op); CeedChkBackend(ierr);
553 
554   // Check for identity
555   if (impl->is_identity_qf)
556     // LCOV_EXCL_START
557     return CeedError(ceed, CEED_ERROR_BACKEND,
558                      "Assembling identity QFunctions not supported");
559   // LCOV_EXCL_STOP
560 
561   // Input Evecs and Restriction
562   ierr = CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields,
563                                      op_input_fields, NULL, true, e_data_full,
564                                      impl, request); CeedChkBackend(ierr);
565 
566   // Count number of active input fields
567   if (!num_active_in) {
568     for (CeedInt i=0; i<num_input_fields; i++) {
569       // Get input vector
570       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
571       CeedChkBackend(ierr);
572       // Check if active input
573       if (vec == CEED_VECTOR_ACTIVE) {
574         ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size);
575         CeedChkBackend(ierr);
576         ierr = CeedVectorSetValue(impl->q_vecs_in[i], 0.0); CeedChkBackend(ierr);
577         ierr = CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &tmp);
578         CeedChkBackend(ierr);
579         ierr = CeedRealloc(num_active_in + size, &active_in); CeedChkBackend(ierr);
580         for (CeedInt field=0; field<size; field++) {
581           ierr = CeedVectorCreate(ceed, Q, &active_in[num_active_in+field]);
582           CeedChkBackend(ierr);
583           ierr = CeedVectorSetArray(active_in[num_active_in+field], CEED_MEM_HOST,
584                                     CEED_USE_POINTER, &tmp[field*Q]);
585           CeedChkBackend(ierr);
586         }
587         num_active_in += size;
588         ierr = CeedVectorRestoreArray(impl->q_vecs_in[i], &tmp); CeedChkBackend(ierr);
589       }
590     }
591     impl->num_active_in = num_active_in;
592     impl->qf_active_in = active_in;
593   }
594 
595   // Count number of active output fields
596   if (!num_active_out) {
597     for (CeedInt i=0; i<num_output_fields; i++) {
598       // Get output vector
599       ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec);
600       CeedChkBackend(ierr);
601       // Check if active output
602       if (vec == CEED_VECTOR_ACTIVE) {
603         ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size);
604         CeedChkBackend(ierr);
605         num_active_out += size;
606       }
607     }
608     impl->num_active_out = num_active_out;
609   }
610 
611   // Check sizes
612   if (!num_active_in || !num_active_out)
613     // LCOV_EXCL_START
614     return CeedError(ceed, CEED_ERROR_BACKEND,
615                      "Cannot assemble QFunction without active inputs "
616                      "and outputs");
617   // LCOV_EXCL_STOP
618 
619   // Build objects if needed
620   if (build_objects) {
621     // Create output restriction
622     CeedInt strides[3] = {1, Q, num_active_in*num_active_out*Q}; /* *NOPAD* */
623     ierr = CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q,
624                                             num_active_in*num_active_out,
625                                             num_active_in*num_active_out*num_elem*Q,
626                                             strides, rstr); CeedChkBackend(ierr);
627     // Create assembled vector
628     ierr = CeedVectorCreate(ceed_parent, num_elem*Q*num_active_in*num_active_out,
629                             assembled); CeedChkBackend(ierr);
630   }
631   // Clear output vector
632   ierr = CeedVectorSetValue(*assembled, 0.0); CeedChkBackend(ierr);
633   ierr = CeedVectorGetArray(*assembled, CEED_MEM_HOST, &a); CeedChkBackend(ierr);
634 
635   // Loop through elements
636   for (CeedInt e=0; e<num_elem; e++) {
637     // Input basis apply
638     ierr = CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields,
639                                       num_input_fields, true, e_data_full, impl);
640     CeedChkBackend(ierr);
641 
642     // Assemble QFunction
643     for (CeedInt in=0; in<num_active_in; in++) {
644       // Set Inputs
645       ierr = CeedVectorSetValue(active_in[in], 1.0); CeedChkBackend(ierr);
646       if (num_active_in > 1) {
647         ierr = CeedVectorSetValue(active_in[(in+num_active_in-1)%num_active_in],
648                                   0.0); CeedChkBackend(ierr);
649       }
650       // Set Outputs
651       for (CeedInt out=0; out<num_output_fields; out++) {
652         // Get output vector
653         ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec);
654         CeedChkBackend(ierr);
655         // Check if active output
656         if (vec == CEED_VECTOR_ACTIVE) {
657           CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST,
658                              CEED_USE_POINTER, a); CeedChkBackend(ierr);
659           ierr = CeedQFunctionFieldGetSize(qf_output_fields[out], &size);
660           CeedChkBackend(ierr);
661           a += size*Q; // Advance the pointer by the size of the output
662         }
663       }
664       // Apply QFunction
665       ierr = CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out);
666       CeedChkBackend(ierr);
667     }
668   }
669 
670   // Un-set output Qvecs to prevent accidental overwrite of Assembled
671   for (CeedInt out=0; out<num_output_fields; out++) {
672     // Get output vector
673     ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec);
674     CeedChkBackend(ierr);
675     // Check if active output
676     if (vec == CEED_VECTOR_ACTIVE) {
677       CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL);
678       CeedChkBackend(ierr);
679     }
680   }
681 
682   // Restore input arrays
683   ierr = CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields,
684                                        op_input_fields, true, e_data_full, impl);
685   CeedChkBackend(ierr);
686 
687   // Restore output
688   ierr = CeedVectorRestoreArray(*assembled, &a); CeedChkBackend(ierr);
689 
690   return CEED_ERROR_SUCCESS;
691 }
692 
693 //------------------------------------------------------------------------------
694 // Assemble Linear QFunction
695 //------------------------------------------------------------------------------
696 static int CeedOperatorLinearAssembleQFunction_Ref(CeedOperator op,
697     CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
698   return CeedOperatorLinearAssembleQFunctionCore_Ref(op, true, assembled, rstr,
699          request);
700 }
701 
702 //------------------------------------------------------------------------------
703 // Update Assembled Linear QFunction
704 //------------------------------------------------------------------------------
705 static int CeedOperatorLinearAssembleQFunctionUpdate_Ref(CeedOperator op,
706     CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
707   return CeedOperatorLinearAssembleQFunctionCore_Ref(op, false, &assembled,
708          &rstr, request);
709 }
710 
711 //------------------------------------------------------------------------------
712 // Operator Destroy
713 //------------------------------------------------------------------------------
714 static int CeedOperatorDestroy_Ref(CeedOperator op) {
715   int ierr;
716   CeedOperator_Ref *impl;
717   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
718 
719   for (CeedInt i=0; i<impl->num_inputs+impl->num_outputs; i++) {
720     ierr = CeedVectorDestroy(&impl->e_vecs_full[i]); CeedChkBackend(ierr);
721   }
722   ierr = CeedFree(&impl->e_vecs_full); CeedChkBackend(ierr);
723   ierr = CeedFree(&impl->input_states); CeedChkBackend(ierr);
724 
725   for (CeedInt i=0; i<impl->num_inputs; i++) {
726     ierr = CeedVectorDestroy(&impl->e_vecs_in[i]); CeedChkBackend(ierr);
727     ierr = CeedVectorDestroy(&impl->q_vecs_in[i]); CeedChkBackend(ierr);
728   }
729   ierr = CeedFree(&impl->e_vecs_in); CeedChkBackend(ierr);
730   ierr = CeedFree(&impl->q_vecs_in); CeedChkBackend(ierr);
731 
732   for (CeedInt i=0; i<impl->num_outputs; i++) {
733     ierr = CeedVectorDestroy(&impl->e_vecs_out[i]); CeedChkBackend(ierr);
734     ierr = CeedVectorDestroy(&impl->q_vecs_out[i]); CeedChkBackend(ierr);
735   }
736   ierr = CeedFree(&impl->e_vecs_out); CeedChkBackend(ierr);
737   ierr = CeedFree(&impl->q_vecs_out); CeedChkBackend(ierr);
738 
739   // QFunction assembly
740   for (CeedInt i=0; i<impl->num_active_in; i++) {
741     ierr = CeedVectorDestroy(&impl->qf_active_in[i]); CeedChkBackend(ierr);
742   }
743   ierr = CeedFree(&impl->qf_active_in); CeedChkBackend(ierr);
744 
745   ierr = CeedFree(&impl); CeedChkBackend(ierr);
746   return CEED_ERROR_SUCCESS;
747 }
748 
749 //------------------------------------------------------------------------------
750 // Operator Create
751 //------------------------------------------------------------------------------
752 int CeedOperatorCreate_Ref(CeedOperator op) {
753   int ierr;
754   Ceed ceed;
755   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
756   CeedOperator_Ref *impl;
757 
758   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
759   ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr);
760 
761   ierr = CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction",
762                                 CeedOperatorLinearAssembleQFunction_Ref);
763   CeedChkBackend(ierr);
764   ierr = CeedSetBackendFunction(ceed, "Operator", op,
765                                 "LinearAssembleQFunctionUpdate",
766                                 CeedOperatorLinearAssembleQFunctionUpdate_Ref);
767   CeedChkBackend(ierr);
768   ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd",
769                                 CeedOperatorApplyAdd_Ref); CeedChkBackend(ierr);
770   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
771                                 CeedOperatorDestroy_Ref); CeedChkBackend(ierr);
772   return CEED_ERROR_SUCCESS;
773 }
774