xref: /libCEED/backends/ref/ceed-ref-operator.c (revision a0162de9bfaf277d9d50534bbae669b234f08437)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <math.h>
20 #include <stdbool.h>
21 #include <stddef.h>
22 #include <stdint.h>
23 #include "ceed-ref.h"
24 
25 //------------------------------------------------------------------------------
26 // Setup Input/Output Fields
27 //------------------------------------------------------------------------------
28 static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op,
29                                        bool inOrOut,
30                                        CeedVector *full_evecs, CeedVector *e_vecs,
31                                        CeedVector *q_vecs, CeedInt starte,
32                                        CeedInt num_fields, CeedInt Q) {
33   CeedInt dim, ierr, size, P;
34   Ceed ceed;
35   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
36   CeedBasis basis;
37   CeedElemRestriction elem_restr;
38   CeedOperatorField *op_fields;
39   CeedQFunctionField *qf_fields;
40   if (inOrOut) {
41     ierr = CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields);
42     CeedChkBackend(ierr);
43     ierr = CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields);
44     CeedChkBackend(ierr);
45   } else {
46     ierr = CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL);
47     CeedChkBackend(ierr);
48     ierr = CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL);
49     CeedChkBackend(ierr);
50   }
51 
52   // Loop over fields
53   for (CeedInt i=0; i<num_fields; i++) {
54     CeedEvalMode eval_mode;
55     ierr = CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode);
56     CeedChkBackend(ierr);
57 
58     if (eval_mode != CEED_EVAL_WEIGHT) {
59       ierr = CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_restr);
60       CeedChkBackend(ierr);
61       ierr = CeedElemRestrictionCreateVector(elem_restr, NULL,
62                                              &full_evecs[i+starte]);
63       CeedChkBackend(ierr);
64     }
65 
66     switch(eval_mode) {
67     case CEED_EVAL_NONE:
68       ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr);
69       ierr = CeedVectorCreate(ceed, Q*size, &q_vecs[i]); CeedChkBackend(ierr);
70       break;
71     case CEED_EVAL_INTERP:
72       ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr);
73       ierr = CeedElemRestrictionGetElementSize(elem_restr, &P);
74       CeedChkBackend(ierr);
75       ierr = CeedVectorCreate(ceed, P*size, &e_vecs[i]); CeedChkBackend(ierr);
76       ierr = CeedVectorCreate(ceed, Q*size, &q_vecs[i]); CeedChkBackend(ierr);
77       break;
78     case CEED_EVAL_GRAD:
79       ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); CeedChkBackend(ierr);
80       ierr = CeedQFunctionFieldGetSize(qf_fields[i], &size); CeedChkBackend(ierr);
81       ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
82       ierr = CeedElemRestrictionGetElementSize(elem_restr, &P);
83       CeedChkBackend(ierr);
84       ierr = CeedVectorCreate(ceed, P*size/dim, &e_vecs[i]); CeedChkBackend(ierr);
85       ierr = CeedVectorCreate(ceed, Q*size, &q_vecs[i]); CeedChkBackend(ierr);
86       break;
87     case CEED_EVAL_WEIGHT: // Only on input fields
88       ierr = CeedOperatorFieldGetBasis(op_fields[i], &basis); CeedChkBackend(ierr);
89       ierr = CeedVectorCreate(ceed, Q, &q_vecs[i]); CeedChkBackend(ierr);
90       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT,
91                             CEED_VECTOR_NONE, q_vecs[i]); CeedChkBackend(ierr);
92       break;
93     case CEED_EVAL_DIV:
94       break; // Not implemented
95     case CEED_EVAL_CURL:
96       break; // Not implemented
97     }
98   }
99   return CEED_ERROR_SUCCESS;
100 }
101 
102 //------------------------------------------------------------------------------
103 // Setup Operator
104 //------------------------------------------------------------------------------
105 static int CeedOperatorSetup_Ref(CeedOperator op) {
106   int ierr;
107   bool setup_done;
108   ierr = CeedOperatorIsSetupDone(op, &setup_done); CeedChkBackend(ierr);
109   if (setup_done) return CEED_ERROR_SUCCESS;
110   Ceed ceed;
111   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
112   CeedOperator_Ref *impl;
113   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
114   CeedQFunction qf;
115   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
116   CeedInt Q, num_input_fields, num_output_fields;
117   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
118   ierr = CeedQFunctionIsIdentity(qf, &impl->is_identity_qf); CeedChkBackend(ierr);
119   CeedOperatorField *op_input_fields, *op_output_fields;
120   ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields,
121                                &num_output_fields, &op_output_fields);
122   CeedChkBackend(ierr);
123   CeedQFunctionField *qf_input_fields, *qf_output_fields;
124   ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL,
125                                 &qf_output_fields);
126   CeedChkBackend(ierr);
127 
128   // Allocate
129   ierr = CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs);
130   CeedChkBackend(ierr);
131 
132   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->input_state); CeedChkBackend(ierr);
133   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in); CeedChkBackend(ierr);
134   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out); CeedChkBackend(ierr);
135   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in); CeedChkBackend(ierr);
136   ierr = CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out); CeedChkBackend(ierr);
137 
138   impl->num_e_vecs_in = num_input_fields;
139   impl->num_e_vecs_out = num_output_fields;
140 
141   // Set up infield and outfield e_vecs and q_vecs
142   // Infields
143   ierr = CeedOperatorSetupFields_Ref(qf, op, 0, impl->e_vecs,
144                                      impl->e_vecs_in, impl->q_vecs_in, 0,
145                                      num_input_fields, Q);
146   CeedChkBackend(ierr);
147   // Outfields
148   ierr = CeedOperatorSetupFields_Ref(qf, op, 1, impl->e_vecs,
149                                      impl->e_vecs_out, impl->q_vecs_out,
150                                      num_input_fields, num_output_fields, Q);
151   CeedChkBackend(ierr);
152 
153   // Identity QFunctions
154   if (impl->is_identity_qf) {
155     CeedEvalMode in_mode, out_mode;
156     CeedQFunctionField *in_fields, *out_fields;
157     ierr = CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields);
158     CeedChkBackend(ierr);
159     ierr = CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode);
160     CeedChkBackend(ierr);
161     ierr = CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode);
162     CeedChkBackend(ierr);
163 
164     if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) {
165       impl->is_identity_restr_op = true;
166     } else {
167       ierr = CeedVectorDestroy(&impl->q_vecs_out[0]); CeedChkBackend(ierr);
168       impl->q_vecs_out[0] = impl->q_vecs_in[0];
169       ierr = CeedVectorAddReference(impl->q_vecs_in[0]); CeedChkBackend(ierr);
170     }
171   }
172 
173   ierr = CeedOperatorSetSetupDone(op); CeedChkBackend(ierr);
174 
175   return CEED_ERROR_SUCCESS;
176 }
177 
178 //------------------------------------------------------------------------------
179 // Setup Operator Inputs
180 //------------------------------------------------------------------------------
181 static inline int CeedOperatorSetupInputs_Ref(CeedInt num_input_fields,
182     CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
183     CeedVector in_vec, const bool skip_active, CeedScalar *e_data[2*CEED_FIELD_MAX],
184     CeedOperator_Ref *impl, CeedRequest *request) {
185   CeedInt ierr;
186   CeedEvalMode eval_mode;
187   CeedVector vec;
188   CeedElemRestriction elem_restr;
189   uint64_t state;
190 
191   for (CeedInt i=0; i<num_input_fields; i++) {
192     // Get input vector
193     ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
194     CeedChkBackend(ierr);
195     if (vec == CEED_VECTOR_ACTIVE) {
196       if (skip_active)
197         continue;
198       else
199         vec = in_vec;
200     }
201 
202     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
203     CeedChkBackend(ierr);
204     // Restrict and Evec
205     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
206     } else {
207       // Restrict
208       ierr = CeedVectorGetState(vec, &state); CeedChkBackend(ierr);
209       // Skip restriction if input is unchanged
210       if (state != impl->input_state[i] || vec == in_vec) {
211         ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr);
212         CeedChkBackend(ierr);
213         ierr = CeedElemRestrictionApply(elem_restr, CEED_NOTRANSPOSE, vec,
214                                         impl->e_vecs[i], request); CeedChkBackend(ierr);
215         impl->input_state[i] = state;
216       }
217       // Get evec
218       ierr = CeedVectorGetArrayRead(impl->e_vecs[i], CEED_MEM_HOST,
219                                     (const CeedScalar **) &e_data[i]);
220       CeedChkBackend(ierr);
221     }
222   }
223   return CEED_ERROR_SUCCESS;
224 }
225 
226 //------------------------------------------------------------------------------
227 // Input Basis Action
228 //------------------------------------------------------------------------------
229 static inline int CeedOperatorInputBasis_Ref(CeedInt e, CeedInt Q,
230     CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
231     CeedInt num_input_fields, const bool skip_active,
232     CeedScalar *e_data[2*CEED_FIELD_MAX], CeedOperator_Ref *impl) {
233   CeedInt ierr;
234   CeedInt dim, elem_size, size;
235   CeedElemRestriction elem_restr;
236   CeedEvalMode eval_mode;
237   CeedBasis basis;
238 
239   for (CeedInt i=0; i<num_input_fields; i++) {
240     // Skip active input
241     if (skip_active) {
242       CeedVector vec;
243       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
244       CeedChkBackend(ierr);
245       if (vec == CEED_VECTOR_ACTIVE)
246         continue;
247     }
248     // Get elem_size, eval_mode, size
249     ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr);
250     CeedChkBackend(ierr);
251     ierr = CeedElemRestrictionGetElementSize(elem_restr, &elem_size);
252     CeedChkBackend(ierr);
253     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
254     CeedChkBackend(ierr);
255     ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size);
256     CeedChkBackend(ierr);
257     // Basis action
258     switch(eval_mode) {
259     case CEED_EVAL_NONE:
260       ierr = CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST,
261                                 CEED_USE_POINTER, &e_data[i][e*Q*size]);
262       CeedChkBackend(ierr);
263       break;
264     case CEED_EVAL_INTERP:
265       ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis);
266       CeedChkBackend(ierr);
267       ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST,
268                                 CEED_USE_POINTER, &e_data[i][e*elem_size*size]);
269       CeedChkBackend(ierr);
270       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE, CEED_EVAL_INTERP,
271                             impl->e_vecs_in[i], impl->q_vecs_in[i]); CeedChkBackend(ierr);
272       break;
273     case CEED_EVAL_GRAD:
274       ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis);
275       CeedChkBackend(ierr);
276       ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
277       ierr = CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST,
278                                 CEED_USE_POINTER, &e_data[i][e*elem_size*size/dim]);
279       CeedChkBackend(ierr);
280       ierr = CeedBasisApply(basis, 1, CEED_NOTRANSPOSE,
281                             CEED_EVAL_GRAD, impl->e_vecs_in[i],
282                             impl->q_vecs_in[i]); CeedChkBackend(ierr);
283       break;
284     case CEED_EVAL_WEIGHT:
285       break;  // No action
286     // LCOV_EXCL_START
287     case CEED_EVAL_DIV:
288     case CEED_EVAL_CURL: {
289       ierr = CeedOperatorFieldGetBasis(op_input_fields[i], &basis);
290       CeedChkBackend(ierr);
291       Ceed ceed;
292       ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
293       return CeedError(ceed, CEED_ERROR_BACKEND,
294                        "Ceed evaluation mode not implemented");
295       // LCOV_EXCL_STOP
296     }
297     }
298   }
299   return CEED_ERROR_SUCCESS;
300 }
301 
302 //------------------------------------------------------------------------------
303 // Output Basis Action
304 //------------------------------------------------------------------------------
305 static inline int CeedOperatorOutputBasis_Ref(CeedInt e, CeedInt Q,
306     CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields,
307     CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op,
308     CeedScalar *e_data[2*CEED_FIELD_MAX], CeedOperator_Ref *impl) {
309   CeedInt ierr;
310   CeedInt dim, elem_size, size;
311   CeedElemRestriction elem_restr;
312   CeedEvalMode eval_mode;
313   CeedBasis basis;
314 
315   for (CeedInt i=0; i<num_output_fields; i++) {
316     // Get elem_size, eval_mode, size
317     ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr);
318     CeedChkBackend(ierr);
319     ierr = CeedElemRestrictionGetElementSize(elem_restr, &elem_size);
320     CeedChkBackend(ierr);
321     ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode);
322     CeedChkBackend(ierr);
323     ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size);
324     CeedChkBackend(ierr);
325     // Basis action
326     switch(eval_mode) {
327     case CEED_EVAL_NONE:
328       break; // No action
329     case CEED_EVAL_INTERP:
330       ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis);
331       CeedChkBackend(ierr);
332       ierr = CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST,
333                                 CEED_USE_POINTER,
334                                 &e_data[i + num_input_fields][e*elem_size*size]);
335       CeedChkBackend(ierr);
336       ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
337                             CEED_EVAL_INTERP, impl->q_vecs_out[i],
338                             impl->e_vecs_out[i]); CeedChkBackend(ierr);
339       break;
340     case CEED_EVAL_GRAD:
341       ierr = CeedOperatorFieldGetBasis(op_output_fields[i], &basis);
342       CeedChkBackend(ierr);
343       ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
344       ierr = CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST,
345                                 CEED_USE_POINTER,
346                                 &e_data[i + num_input_fields][e*elem_size*size/dim]);
347       CeedChkBackend(ierr);
348       ierr = CeedBasisApply(basis, 1, CEED_TRANSPOSE,
349                             CEED_EVAL_GRAD, impl->q_vecs_out[i],
350                             impl->e_vecs_out[i]); CeedChkBackend(ierr);
351       break;
352     // LCOV_EXCL_START
353     case CEED_EVAL_WEIGHT: {
354       Ceed ceed;
355       ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
356       return CeedError(ceed, CEED_ERROR_BACKEND,
357                        "CEED_EVAL_WEIGHT cannot be an output "
358                        "evaluation mode");
359     }
360     case CEED_EVAL_DIV:
361     case CEED_EVAL_CURL: {
362       Ceed ceed;
363       ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
364       return CeedError(ceed, CEED_ERROR_BACKEND,
365                        "Ceed evaluation mode not implemented");
366       // LCOV_EXCL_STOP
367     }
368     }
369   }
370   return CEED_ERROR_SUCCESS;
371 }
372 
373 //------------------------------------------------------------------------------
374 // Restore Input Vectors
375 //------------------------------------------------------------------------------
376 static inline int CeedOperatorRestoreInputs_Ref(CeedInt num_input_fields,
377     CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
378     const bool skip_active, CeedScalar *e_data[2*CEED_FIELD_MAX],
379     CeedOperator_Ref *impl) {
380   CeedInt ierr;
381   CeedEvalMode eval_mode;
382 
383   for (CeedInt i=0; i<num_input_fields; i++) {
384     // Skip active inputs
385     if (skip_active) {
386       CeedVector vec;
387       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
388       CeedChkBackend(ierr);
389       if (vec == CEED_VECTOR_ACTIVE)
390         continue;
391     }
392     // Restore input
393     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
394     CeedChkBackend(ierr);
395     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
396     } else {
397       ierr = CeedVectorRestoreArrayRead(impl->e_vecs[i],
398                                         (const CeedScalar **) &e_data[i]);
399       CeedChkBackend(ierr);
400     }
401   }
402   return CEED_ERROR_SUCCESS;
403 }
404 
405 //------------------------------------------------------------------------------
406 // Operator Apply
407 //------------------------------------------------------------------------------
408 static int CeedOperatorApplyAdd_Ref(CeedOperator op, CeedVector in_vec,
409                                     CeedVector out_vec, CeedRequest *request) {
410   int ierr;
411   CeedOperator_Ref *impl;
412   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
413   CeedQFunction qf;
414   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
415   CeedInt Q, num_elem, num_input_fields, num_output_fields, size;
416   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
417   ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr);
418   CeedOperatorField *op_input_fields, *op_output_fields;
419   ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields,
420                                &num_output_fields, &op_output_fields);
421   CeedChkBackend(ierr);
422   CeedQFunctionField *qf_input_fields, *qf_output_fields;
423   ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL,
424                                 &qf_output_fields);
425   CeedChkBackend(ierr);
426   CeedEvalMode eval_mode;
427   CeedVector vec;
428   CeedElemRestriction elem_restr;
429   CeedScalar *e_data[2*CEED_FIELD_MAX] = {0};
430 
431   // Setup
432   ierr = CeedOperatorSetup_Ref(op); CeedChkBackend(ierr);
433 
434   // Restriction only operator
435   if (impl->is_identity_restr_op) {
436     ierr = CeedOperatorFieldGetElemRestriction(op_input_fields[0], &elem_restr);
437     CeedChkBackend(ierr);
438     ierr = CeedElemRestrictionApply(elem_restr, CEED_NOTRANSPOSE, in_vec,
439                                     impl->e_vecs[0], request); CeedChkBackend(ierr);
440     ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[0], &elem_restr);
441     CeedChkBackend(ierr);
442     ierr = CeedElemRestrictionApply(elem_restr, CEED_TRANSPOSE, impl->e_vecs[0],
443                                     out_vec, request); CeedChkBackend(ierr);
444     return CEED_ERROR_SUCCESS;
445   }
446 
447   // Input Evecs and Restriction
448   ierr = CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields,
449                                      op_input_fields, in_vec, false, e_data, impl,
450                                      request); CeedChkBackend(ierr);
451 
452   // Output Evecs
453   for (CeedInt i=0; i<num_output_fields; i++) {
454     ierr = CeedVectorGetArray(impl->e_vecs[i+impl->num_e_vecs_in], CEED_MEM_HOST,
455                               &e_data[i + num_input_fields]); CeedChkBackend(ierr);
456   }
457 
458   // Loop through elements
459   for (CeedInt e=0; e<num_elem; e++) {
460     // Output pointers
461     for (CeedInt i=0; i<num_output_fields; i++) {
462       ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode);
463       CeedChkBackend(ierr);
464       if (eval_mode == CEED_EVAL_NONE) {
465         ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size);
466         CeedChkBackend(ierr);
467         ierr = CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST,
468                                   CEED_USE_POINTER,
469                                   &e_data[i + num_input_fields][e*Q*size]);
470         CeedChkBackend(ierr);
471       }
472     }
473 
474     // Input basis apply
475     ierr = CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields,
476                                       num_input_fields, false, e_data, impl);
477     CeedChkBackend(ierr);
478 
479     // Q function
480     if (!impl->is_identity_qf) {
481       ierr = CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out);
482       CeedChkBackend(ierr);
483     }
484 
485     // Output basis apply
486     ierr = CeedOperatorOutputBasis_Ref(e, Q, qf_output_fields, op_output_fields,
487                                        num_input_fields, num_output_fields, op,
488                                        e_data, impl); CeedChkBackend(ierr);
489   }
490 
491   // Output restriction
492   for (CeedInt i=0; i<num_output_fields; i++) {
493     // Restore Evec
494     ierr = CeedVectorRestoreArray(impl->e_vecs[i+impl->num_e_vecs_in],
495                                   &e_data[i + num_input_fields]);
496     CeedChkBackend(ierr);
497     // Get output vector
498     ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec);
499     CeedChkBackend(ierr);
500     // Active
501     if (vec == CEED_VECTOR_ACTIVE)
502       vec = out_vec;
503     // Restrict
504     ierr = CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr);
505     CeedChkBackend(ierr);
506     ierr = CeedElemRestrictionApply(elem_restr, CEED_TRANSPOSE,
507                                     impl->e_vecs[i+impl->num_e_vecs_in], vec, request);
508     CeedChkBackend(ierr);
509   }
510 
511   // Restore input arrays
512   ierr = CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields,
513                                        op_input_fields, false, e_data, impl);
514   CeedChkBackend(ierr);
515 
516   return CEED_ERROR_SUCCESS;
517 }
518 
519 //------------------------------------------------------------------------------
520 // Core code for assembling linear QFunction
521 //------------------------------------------------------------------------------
522 static inline int CeedOperatorLinearAssembleQFunctionCore_Ref(CeedOperator op,
523     bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr,
524     CeedRequest *request) {
525   int ierr;
526   CeedOperator_Ref *impl;
527   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
528   CeedQFunction qf;
529   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
530   CeedInt Q, num_elem, num_input_fields, num_output_fields, size;
531   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChkBackend(ierr);
532   ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr);
533   CeedOperatorField *op_input_fields, *op_output_fields;
534   ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields,
535                                &num_output_fields, &op_output_fields);
536   CeedChkBackend(ierr);
537   CeedQFunctionField *qf_input_fields, *qf_output_fields;
538   ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL,
539                                 &qf_output_fields);
540   CeedChkBackend(ierr);
541   CeedVector vec;
542   CeedInt num_active_in = impl->qf_num_active_in,
543           num_active_out = impl->qf_num_active_out;
544   CeedVector *active_in = impl->qf_active_in;
545   CeedScalar *a, *tmp;
546   Ceed ceed, ceed_parent;
547   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
548   ierr = CeedGetOperatorFallbackParentCeed(ceed, &ceed_parent);
549   CeedChkBackend(ierr);
550   ceed_parent = ceed_parent ? ceed_parent : ceed;
551   CeedScalar *e_data[2*CEED_FIELD_MAX] = {0};
552 
553   // Setup
554   ierr = CeedOperatorSetup_Ref(op); CeedChkBackend(ierr);
555 
556   // Check for identity
557   if (impl->is_identity_qf)
558     // LCOV_EXCL_START
559     return CeedError(ceed, CEED_ERROR_BACKEND,
560                      "Assembling identity QFunctions not supported");
561   // LCOV_EXCL_STOP
562 
563   // Input Evecs and Restriction
564   ierr = CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields,
565                                      op_input_fields, NULL, true, e_data,
566                                      impl, request); CeedChkBackend(ierr);
567 
568   // Count number of active input fields
569   if (!num_active_in) {
570     for (CeedInt i=0; i<num_input_fields; i++) {
571       // Get input vector
572       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
573       CeedChkBackend(ierr);
574       // Check if active input
575       if (vec == CEED_VECTOR_ACTIVE) {
576         ierr = CeedQFunctionFieldGetSize(qf_input_fields[i], &size);
577         CeedChkBackend(ierr);
578         ierr = CeedVectorSetValue(impl->q_vecs_in[i], 0.0); CeedChkBackend(ierr);
579         ierr = CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &tmp);
580         CeedChkBackend(ierr);
581         ierr = CeedRealloc(num_active_in + size, &active_in); CeedChkBackend(ierr);
582         for (CeedInt field=0; field<size; field++) {
583           ierr = CeedVectorCreate(ceed, Q, &active_in[num_active_in+field]);
584           CeedChkBackend(ierr);
585           ierr = CeedVectorSetArray(active_in[num_active_in+field], CEED_MEM_HOST,
586                                     CEED_USE_POINTER, &tmp[field*Q]);
587           CeedChkBackend(ierr);
588         }
589         num_active_in += size;
590         ierr = CeedVectorRestoreArray(impl->q_vecs_in[i], &tmp); CeedChkBackend(ierr);
591       }
592     }
593     impl->qf_num_active_in = num_active_in;
594     impl->qf_active_in = active_in;
595   }
596 
597   // Count number of active output fields
598   if (!num_active_out) {
599     for (CeedInt i=0; i<num_output_fields; i++) {
600       // Get output vector
601       ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec);
602       CeedChkBackend(ierr);
603       // Check if active output
604       if (vec == CEED_VECTOR_ACTIVE) {
605         ierr = CeedQFunctionFieldGetSize(qf_output_fields[i], &size);
606         CeedChkBackend(ierr);
607         num_active_out += size;
608       }
609     }
610     impl->qf_num_active_out = num_active_out;
611   }
612 
613   // Check sizes
614   if (!num_active_in || !num_active_out)
615     // LCOV_EXCL_START
616     return CeedError(ceed, CEED_ERROR_BACKEND,
617                      "Cannot assemble QFunction without active inputs "
618                      "and outputs");
619   // LCOV_EXCL_STOP
620 
621   // Build objects if needed
622   if (build_objects) {
623     // Create output restriction
624     CeedInt strides[3] = {1, Q, num_active_in*num_active_out*Q}; /* *NOPAD* */
625     ierr = CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q,
626                                             num_active_in*num_active_out,
627                                             num_active_in*num_active_out*num_elem*Q,
628                                             strides, rstr); CeedChkBackend(ierr);
629     // Create assembled vector
630     ierr = CeedVectorCreate(ceed_parent, num_elem*Q*num_active_in*num_active_out,
631                             assembled); CeedChkBackend(ierr);
632   }
633   // Clear output vector
634   ierr = CeedVectorSetValue(*assembled, 0.0); CeedChkBackend(ierr);
635   ierr = CeedVectorGetArray(*assembled, CEED_MEM_HOST, &a); CeedChkBackend(ierr);
636 
637   // Loop through elements
638   for (CeedInt e=0; e<num_elem; e++) {
639     // Input basis apply
640     ierr = CeedOperatorInputBasis_Ref(e, Q, qf_input_fields, op_input_fields,
641                                       num_input_fields, true, e_data, impl);
642     CeedChkBackend(ierr);
643 
644     // Assemble QFunction
645     for (CeedInt in=0; in<num_active_in; in++) {
646       // Set Inputs
647       ierr = CeedVectorSetValue(active_in[in], 1.0); CeedChkBackend(ierr);
648       if (num_active_in > 1) {
649         ierr = CeedVectorSetValue(active_in[(in+num_active_in-1)%num_active_in],
650                                   0.0); CeedChkBackend(ierr);
651       }
652       // Set Outputs
653       for (CeedInt out=0; out<num_output_fields; out++) {
654         // Get output vector
655         ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec);
656         CeedChkBackend(ierr);
657         // Check if active output
658         if (vec == CEED_VECTOR_ACTIVE) {
659           CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST,
660                              CEED_USE_POINTER, a); CeedChkBackend(ierr);
661           ierr = CeedQFunctionFieldGetSize(qf_output_fields[out], &size);
662           CeedChkBackend(ierr);
663           a += size*Q; // Advance the pointer by the size of the output
664         }
665       }
666       // Apply QFunction
667       ierr = CeedQFunctionApply(qf, Q, impl->q_vecs_in, impl->q_vecs_out);
668       CeedChkBackend(ierr);
669     }
670   }
671 
672   // Un-set output Qvecs to prevent accidental overwrite of Assembled
673   for (CeedInt out=0; out<num_output_fields; out++) {
674     // Get output vector
675     ierr = CeedOperatorFieldGetVector(op_output_fields[out], &vec);
676     CeedChkBackend(ierr);
677     // Check if active output
678     if (vec == CEED_VECTOR_ACTIVE) {
679       CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL);
680       CeedChkBackend(ierr);
681     }
682   }
683 
684   // Restore input arrays
685   ierr = CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields,
686                                        op_input_fields, true, e_data, impl);
687   CeedChkBackend(ierr);
688 
689   // Restore output
690   ierr = CeedVectorRestoreArray(*assembled, &a); CeedChkBackend(ierr);
691 
692   return CEED_ERROR_SUCCESS;
693 }
694 
695 //------------------------------------------------------------------------------
696 // Assemble Linear QFunction
697 //------------------------------------------------------------------------------
698 static int CeedOperatorLinearAssembleQFunction_Ref(CeedOperator op,
699     CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
700   return CeedOperatorLinearAssembleQFunctionCore_Ref(op, true, assembled, rstr,
701          request);
702 }
703 
704 //------------------------------------------------------------------------------
705 // Update Assembled Linear QFunction
706 //------------------------------------------------------------------------------
707 static int CeedOperatorLinearAssembleQFunctionUpdate_Ref(CeedOperator op,
708     CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
709   return CeedOperatorLinearAssembleQFunctionCore_Ref(op, false, &assembled,
710          &rstr, request);
711 }
712 
713 //------------------------------------------------------------------------------
714 // Operator Destroy
715 //------------------------------------------------------------------------------
716 static int CeedOperatorDestroy_Ref(CeedOperator op) {
717   int ierr;
718   CeedOperator_Ref *impl;
719   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
720 
721   for (CeedInt i=0; i<impl->num_e_vecs_in+impl->num_e_vecs_out; i++) {
722     ierr = CeedVectorDestroy(&impl->e_vecs[i]); CeedChkBackend(ierr);
723   }
724   ierr = CeedFree(&impl->e_vecs); CeedChkBackend(ierr);
725   ierr = CeedFree(&impl->input_state); CeedChkBackend(ierr);
726 
727   for (CeedInt i=0; i<impl->num_e_vecs_in; i++) {
728     ierr = CeedVectorDestroy(&impl->e_vecs_in[i]); CeedChkBackend(ierr);
729     ierr = CeedVectorDestroy(&impl->q_vecs_in[i]); CeedChkBackend(ierr);
730   }
731   ierr = CeedFree(&impl->e_vecs_in); CeedChkBackend(ierr);
732   ierr = CeedFree(&impl->q_vecs_in); CeedChkBackend(ierr);
733 
734   for (CeedInt i=0; i<impl->num_e_vecs_out; i++) {
735     ierr = CeedVectorDestroy(&impl->e_vecs_out[i]); CeedChkBackend(ierr);
736     ierr = CeedVectorDestroy(&impl->q_vecs_out[i]); CeedChkBackend(ierr);
737   }
738   ierr = CeedFree(&impl->e_vecs_out); CeedChkBackend(ierr);
739   ierr = CeedFree(&impl->q_vecs_out); CeedChkBackend(ierr);
740 
741   // QFunction assembly
742   for (CeedInt i=0; i<impl->qf_num_active_in; i++) {
743     ierr = CeedVectorDestroy(&impl->qf_active_in[i]); CeedChkBackend(ierr);
744   }
745   ierr = CeedFree(&impl->qf_active_in); CeedChkBackend(ierr);
746 
747   ierr = CeedFree(&impl); CeedChkBackend(ierr);
748   return CEED_ERROR_SUCCESS;
749 }
750 
751 //------------------------------------------------------------------------------
752 // Operator Create
753 //------------------------------------------------------------------------------
754 int CeedOperatorCreate_Ref(CeedOperator op) {
755   int ierr;
756   Ceed ceed;
757   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
758   CeedOperator_Ref *impl;
759 
760   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
761   ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr);
762 
763   ierr = CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction",
764                                 CeedOperatorLinearAssembleQFunction_Ref);
765   CeedChkBackend(ierr);
766   ierr = CeedSetBackendFunction(ceed, "Operator", op,
767                                 "LinearAssembleQFunctionUpdate",
768                                 CeedOperatorLinearAssembleQFunctionUpdate_Ref);
769   CeedChkBackend(ierr);
770   ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd",
771                                 CeedOperatorApplyAdd_Ref); CeedChkBackend(ierr);
772   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
773                                 CeedOperatorDestroy_Ref); CeedChkBackend(ierr);
774   return CEED_ERROR_SUCCESS;
775 }
776