xref: /libCEED/backends/opt/ceed-opt-operator.c (revision b13efd58b277efef1db70d6f06eaaf4d415a7642)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <stdbool.h>
11 #include <stdint.h>
12 #include <string.h>
13 
14 #include "ceed-opt.h"
15 
16 //------------------------------------------------------------------------------
17 // Setup Input/Output Fields
18 //------------------------------------------------------------------------------
19 static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, bool *apply_add_basis,
20                                        const CeedInt block_size, CeedElemRestriction *block_rstr, CeedVector *e_vecs_full, CeedVector *e_vecs,
21                                        CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) {
22   Ceed                ceed;
23   CeedSize            e_size, q_size;
24   CeedInt             num_comp, size, P;
25   CeedQFunctionField *qf_fields;
26   CeedOperatorField  *op_fields;
27 
28   {
29     Ceed ceed_parent;
30 
31     CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
32     CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
33     if (ceed_parent) ceed = ceed_parent;
34   }
35   if (is_input) {
36     CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL));
37     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
38   } else {
39     CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
40     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
41   }
42 
43   // Loop over fields
44   for (CeedInt i = 0; i < num_fields; i++) {
45     CeedEvalMode eval_mode;
46     CeedBasis    basis;
47 
48     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
49     if (eval_mode != CEED_EVAL_WEIGHT) {
50       Ceed                ceed_rstr;
51       CeedSize            l_size;
52       CeedInt             num_elem, elem_size, comp_stride;
53       CeedRestrictionType rstr_type;
54       CeedElemRestriction rstr;
55 
56       CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr));
57       CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed_rstr));
58       CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
59       CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
60       CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size));
61       CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
62       CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride));
63 
64       CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
65       switch (rstr_type) {
66         case CEED_RESTRICTION_STANDARD: {
67           const CeedInt *offsets = NULL;
68 
69           CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
70           CeedCallBackend(CeedElemRestrictionCreateBlocked(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size, CEED_MEM_HOST,
71                                                            CEED_COPY_VALUES, offsets, &block_rstr[i + start_e]));
72           CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
73         } break;
74         case CEED_RESTRICTION_ORIENTED: {
75           const bool    *orients = NULL;
76           const CeedInt *offsets = NULL;
77 
78           CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
79           CeedCallBackend(CeedElemRestrictionGetOrientations(rstr, CEED_MEM_HOST, &orients));
80           CeedCallBackend(CeedElemRestrictionCreateBlockedOriented(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size,
81                                                                    CEED_MEM_HOST, CEED_COPY_VALUES, offsets, orients, &block_rstr[i + start_e]));
82           CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
83           CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr, &orients));
84         } break;
85         case CEED_RESTRICTION_CURL_ORIENTED: {
86           const CeedInt8 *curl_orients = NULL;
87           const CeedInt  *offsets      = NULL;
88 
89           CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
90           CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr, CEED_MEM_HOST, &curl_orients));
91           CeedCallBackend(CeedElemRestrictionCreateBlockedCurlOriented(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size,
92                                                                        CEED_MEM_HOST, CEED_COPY_VALUES, offsets, curl_orients,
93                                                                        &block_rstr[i + start_e]));
94           CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
95           CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr, &curl_orients));
96         } break;
97         case CEED_RESTRICTION_STRIDED: {
98           CeedInt strides[3];
99 
100           CeedCallBackend(CeedElemRestrictionGetStrides(rstr, strides));
101           CeedCallBackend(CeedElemRestrictionCreateBlockedStrided(ceed_rstr, num_elem, elem_size, block_size, num_comp, l_size, strides,
102                                                                   &block_rstr[i + start_e]));
103         } break;
104         case CEED_RESTRICTION_POINTS:
105           // Empty case - won't occur
106           break;
107       }
108       CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
109       CeedCallBackend(CeedElemRestrictionCreateVector(block_rstr[i + start_e], NULL, &e_vecs_full[i + start_e]));
110     }
111 
112     switch (eval_mode) {
113       case CEED_EVAL_NONE:
114         CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
115         e_size = (CeedSize)Q * size * block_size;
116         CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i]));
117         q_size = (CeedSize)Q * size * block_size;
118         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
119         break;
120       case CEED_EVAL_INTERP:
121       case CEED_EVAL_GRAD:
122       case CEED_EVAL_DIV:
123       case CEED_EVAL_CURL:
124         CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
125         CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
126         CeedCallBackend(CeedBasisGetNumNodes(basis, &P));
127         CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
128         CeedCallBackend(CeedBasisDestroy(&basis));
129         e_size = (CeedSize)P * num_comp * block_size;
130         CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i]));
131         q_size = (CeedSize)Q * size * block_size;
132         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
133         break;
134       case CEED_EVAL_WEIGHT:  // Only on input fields
135         CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
136         q_size = (CeedSize)Q * block_size;
137         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
138         CeedCallBackend(CeedBasisApply(basis, block_size, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]));
139         CeedCallBackend(CeedBasisDestroy(&basis));
140         break;
141     }
142     // Initialize E-vec arrays
143     if (e_vecs[i]) CeedCallBackend(CeedVectorSetValue(e_vecs[i], 0.0));
144   }
145   // Drop duplicate restrictions
146   if (is_input) {
147     for (CeedInt i = 0; i < num_fields; i++) {
148       CeedVector          vec_i;
149       CeedElemRestriction rstr_i;
150 
151       CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
152       CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
153       for (CeedInt j = i + 1; j < num_fields; j++) {
154         CeedVector          vec_j;
155         CeedElemRestriction rstr_j;
156 
157         CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
158         CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
159         if (vec_i == vec_j && rstr_i == rstr_j) {
160           CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
161           CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e]));
162           skip_rstr[j] = true;
163         }
164         CeedCallBackend(CeedVectorDestroy(&vec_j));
165         CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j));
166       }
167       CeedCallBackend(CeedVectorDestroy(&vec_i));
168       CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
169     }
170   } else {
171     for (CeedInt i = num_fields - 1; i >= 0; i--) {
172       CeedVector          vec_i;
173       CeedElemRestriction rstr_i;
174 
175       CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
176       CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
177       for (CeedInt j = i - 1; j >= 0; j--) {
178         CeedVector          vec_j;
179         CeedElemRestriction rstr_j;
180 
181         CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
182         CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
183         if (vec_i == vec_j && rstr_i == rstr_j) {
184           CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
185           CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e]));
186           skip_rstr[j]       = true;
187           apply_add_basis[i] = true;
188         }
189         CeedCallBackend(CeedVectorDestroy(&vec_j));
190         CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j));
191       }
192       CeedCallBackend(CeedVectorDestroy(&vec_i));
193       CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
194     }
195   }
196   return CEED_ERROR_SUCCESS;
197 }
198 
199 //------------------------------------------------------------------------------
200 // Setup Operator
201 //------------------------------------------------------------------------------
202 static int CeedOperatorSetup_Opt(CeedOperator op) {
203   bool                is_setup_done;
204   Ceed                ceed;
205   Ceed_Opt           *ceed_impl;
206   CeedInt             Q, num_input_fields, num_output_fields;
207   CeedQFunctionField *qf_input_fields, *qf_output_fields;
208   CeedQFunction       qf;
209   CeedOperatorField  *op_input_fields, *op_output_fields;
210   CeedOperator_Opt   *impl;
211 
212   CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done));
213   if (is_setup_done) return CEED_ERROR_SUCCESS;
214 
215   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
216   CeedCallBackend(CeedGetData(ceed, &ceed_impl));
217   CeedCallBackend(CeedOperatorGetData(op, &impl));
218   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
219   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
220   CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf));
221   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
222   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
223   const CeedInt block_size = ceed_impl->block_size;
224 
225   // Allocate
226   CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->block_rstr));
227   CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full));
228 
229   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in));
230   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out));
231   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out));
232   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
233   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in));
234   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out));
235   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in));
236   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out));
237 
238   impl->num_inputs  = num_input_fields;
239   impl->num_outputs = num_output_fields;
240 
241   // Set up infield and outfield pointer arrays
242   // Infields
243   CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, true, impl->skip_rstr_in, NULL, block_size, impl->block_rstr, impl->e_vecs_full,
244                                               impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q));
245   // Outfields
246   CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, false, impl->skip_rstr_out, impl->apply_add_basis_out, block_size, impl->block_rstr,
247                                               impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, num_output_fields, Q));
248 
249   // Identity QFunctions
250   if (impl->is_identity_qf) {
251     CeedEvalMode        in_mode, out_mode;
252     CeedQFunctionField *in_fields, *out_fields;
253 
254     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields));
255     CeedCallBackend(CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode));
256     CeedCallBackend(CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode));
257 
258     if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) {
259       impl->is_identity_rstr_op = true;
260     } else {
261       CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0]));
262     }
263   }
264 
265   CeedCallBackend(CeedOperatorSetSetupDone(op));
266   return CEED_ERROR_SUCCESS;
267 }
268 
269 //------------------------------------------------------------------------------
270 // Setup Input Fields
271 //------------------------------------------------------------------------------
272 static inline int CeedOperatorSetupInputs_Opt(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
273                                               CeedVector in_vec, CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Opt *impl,
274                                               CeedRequest *request) {
275   for (CeedInt i = 0; i < num_input_fields; i++) {
276     CeedEvalMode eval_mode;
277 
278     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
279     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
280     } else {
281       uint64_t   state;
282       CeedVector vec;
283 
284       // Get input vector
285       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
286       if (vec != CEED_VECTOR_ACTIVE) {
287         // Restrict
288         CeedCallBackend(CeedVectorGetState(vec, &state));
289         if (state != impl->input_states[i] && impl->block_rstr[i] && !impl->skip_rstr_in[i]) {
290           CeedCallBackend(CeedElemRestrictionApply(impl->block_rstr[i], CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request));
291         }
292         impl->input_states[i] = state;
293         // Get evec
294         CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data[i]));
295       } else {
296         // Set Qvec for CEED_EVAL_NONE
297         if (eval_mode == CEED_EVAL_NONE) {
298           CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_in[i], CEED_MEM_HOST, (const CeedScalar **)&e_data[i]));
299           CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, e_data[i]));
300           CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_in[i], (const CeedScalar **)&e_data[i]));
301         }
302       }
303       CeedCallBackend(CeedVectorDestroy(&vec));
304     }
305   }
306   return CEED_ERROR_SUCCESS;
307 }
308 
309 //------------------------------------------------------------------------------
310 // Input Basis Action
311 //------------------------------------------------------------------------------
312 static inline int CeedOperatorInputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
313                                              CeedInt num_input_fields, CeedInt block_size, CeedVector in_vec, bool skip_active,
314                                              CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Opt *impl, CeedRequest *request) {
315   for (CeedInt i = 0; i < num_input_fields; i++) {
316     bool                is_active;
317     CeedInt             elem_size, size, num_comp;
318     CeedEvalMode        eval_mode;
319     CeedVector          vec;
320     CeedElemRestriction elem_rstr;
321     CeedBasis           basis;
322 
323     // Skip active input
324     CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
325     is_active = vec == CEED_VECTOR_ACTIVE;
326     CeedCallBackend(CeedVectorDestroy(&vec));
327     if (skip_active && is_active) continue;
328 
329     // Get elem_size, eval_mode, size
330     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
331     CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
332     CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
333     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
334     CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size));
335     // Restrict block active input
336     if (is_active && impl->block_rstr[i]) {
337       CeedCallBackend(CeedElemRestrictionApplyBlock(impl->block_rstr[i], e / block_size, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request));
338     }
339     // Basis action
340     switch (eval_mode) {
341       case CEED_EVAL_NONE:
342         if (!is_active) {
343           CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][(CeedSize)e * Q * size]));
344         }
345         break;
346       case CEED_EVAL_INTERP:
347       case CEED_EVAL_GRAD:
348       case CEED_EVAL_DIV:
349       case CEED_EVAL_CURL:
350         CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
351         if (!is_active) {
352           CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
353           CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][(CeedSize)e * elem_size * num_comp]));
354         }
355         CeedCallBackend(CeedBasisApply(basis, block_size, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs_in[i], impl->q_vecs_in[i]));
356         CeedCallBackend(CeedBasisDestroy(&basis));
357         break;
358       case CEED_EVAL_WEIGHT:
359         break;  // No action
360     }
361   }
362   return CEED_ERROR_SUCCESS;
363 }
364 
365 //------------------------------------------------------------------------------
366 // Output Basis Action
367 //------------------------------------------------------------------------------
368 static inline int CeedOperatorOutputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields,
369                                               CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, bool *apply_add_basis,
370                                               bool *skip_rstr, CeedOperator op, CeedVector out_vec, CeedOperator_Opt *impl, CeedRequest *request) {
371   for (CeedInt i = 0; i < num_output_fields; i++) {
372     bool         is_active;
373     CeedEvalMode eval_mode;
374     CeedVector   vec;
375     CeedBasis    basis;
376 
377     // Get eval_mode
378     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
379     // Basis action
380     switch (eval_mode) {
381       case CEED_EVAL_NONE:
382         break;  // No action
383       case CEED_EVAL_INTERP:
384       case CEED_EVAL_GRAD:
385       case CEED_EVAL_DIV:
386       case CEED_EVAL_CURL:
387         CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
388         if (apply_add_basis[i]) {
389           CeedCallBackend(CeedBasisApplyAdd(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
390         } else {
391           CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
392         }
393         CeedCallBackend(CeedBasisDestroy(&basis));
394         break;
395       // LCOV_EXCL_START
396       case CEED_EVAL_WEIGHT: {
397         return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
398         // LCOV_EXCL_STOP
399       }
400     }
401     // Restrict output block
402     if (skip_rstr[i]) continue;
403     // Get output vector
404     CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
405     is_active = vec == CEED_VECTOR_ACTIVE;
406     if (is_active) vec = out_vec;
407     // Restrict
408     CeedCallBackend(
409         CeedElemRestrictionApplyBlock(impl->block_rstr[i + impl->num_inputs], e / block_size, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request));
410     if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
411   }
412   return CEED_ERROR_SUCCESS;
413 }
414 
415 //------------------------------------------------------------------------------
416 // Restore Input Vectors
417 //------------------------------------------------------------------------------
418 static inline int CeedOperatorRestoreInputs_Opt(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
419                                                 CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Opt *impl) {
420   for (CeedInt i = 0; i < num_input_fields; i++) {
421     CeedEvalMode eval_mode;
422     CeedVector   vec;
423 
424     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
425     CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
426     if (eval_mode != CEED_EVAL_WEIGHT && vec != CEED_VECTOR_ACTIVE) {
427       CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_full[i], (const CeedScalar **)&e_data[i]));
428     }
429     CeedCallBackend(CeedVectorDestroy(&vec));
430   }
431   return CEED_ERROR_SUCCESS;
432 }
433 
434 //------------------------------------------------------------------------------
435 // Operator Apply
436 //------------------------------------------------------------------------------
437 static int CeedOperatorApplyAdd_Opt(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) {
438   Ceed                ceed;
439   Ceed_Opt           *ceed_impl;
440   CeedInt             Q, num_input_fields, num_output_fields, num_elem;
441   CeedEvalMode        eval_mode;
442   CeedScalar         *e_data[2 * CEED_FIELD_MAX] = {0};
443   CeedQFunctionField *qf_input_fields, *qf_output_fields;
444   CeedQFunction       qf;
445   CeedOperatorField  *op_input_fields, *op_output_fields;
446   CeedOperator_Opt   *impl;
447 
448   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
449   CeedCallBackend(CeedGetData(ceed, &ceed_impl));
450   CeedCallBackend(CeedOperatorGetData(op, &impl));
451   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
452   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
453   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
454   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
455   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
456   const CeedInt block_size = ceed_impl->block_size;
457   const CeedInt num_blocks = (num_elem / block_size) + !!(num_elem % block_size);
458 
459   // Setup
460   CeedCallBackend(CeedOperatorSetup_Opt(op));
461 
462   // Restriction only operator
463   if (impl->is_identity_rstr_op) {
464     for (CeedInt b = 0; b < num_blocks; b++) {
465       CeedCallBackend(CeedElemRestrictionApplyBlock(impl->block_rstr[0], b, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[0], request));
466       CeedCallBackend(CeedElemRestrictionApplyBlock(impl->block_rstr[1], b, CEED_TRANSPOSE, impl->e_vecs_in[0], out_vec, request));
467     }
468     return CEED_ERROR_SUCCESS;
469   }
470 
471   // Input Evecs and Restriction
472   CeedCallBackend(CeedOperatorSetupInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, in_vec, e_data, impl, request));
473 
474   // Output Lvecs, Evecs, and Qvecs
475   for (CeedInt i = 0; i < num_output_fields; i++) {
476     // Set Qvec if needed
477     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
478     if (eval_mode == CEED_EVAL_NONE) {
479       // Set qvec to single block evec
480       CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_HOST, &e_data[i + num_input_fields]));
481       CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, e_data[i + num_input_fields]));
482       CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_out[i], &e_data[i + num_input_fields]));
483     }
484   }
485 
486   // Loop through elements
487   for (CeedInt e = 0; e < num_blocks * block_size; e += block_size) {
488     // Input basis apply
489     CeedCallBackend(
490         CeedOperatorInputBasis_Opt(e, Q, qf_input_fields, op_input_fields, num_input_fields, block_size, in_vec, false, e_data, impl, request));
491 
492     // Q function
493     if (!impl->is_identity_qf) {
494       CeedCallBackend(CeedQFunctionApply(qf, Q * block_size, impl->q_vecs_in, impl->q_vecs_out));
495     }
496 
497     // Output basis apply and restriction
498     CeedCallBackend(CeedOperatorOutputBasis_Opt(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields,
499                                                 impl->apply_add_basis_out, impl->skip_rstr_out, op, out_vec, impl, request));
500   }
501 
502   // Restore input arrays
503   CeedCallBackend(CeedOperatorRestoreInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, e_data, impl));
504   return CEED_ERROR_SUCCESS;
505 }
506 
507 //------------------------------------------------------------------------------
508 // Core code for linear QFunction assembly
509 //------------------------------------------------------------------------------
510 static inline int CeedOperatorLinearAssembleQFunctionCore_Opt(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr,
511                                                               CeedRequest *request) {
512   Ceed                ceed;
513   Ceed_Opt           *ceed_impl;
514   CeedInt             qf_size_in, qf_size_out, Q, num_input_fields, num_output_fields, num_elem;
515   CeedScalar         *l_vec_array, *e_data[2 * CEED_FIELD_MAX] = {0};
516   CeedQFunctionField *qf_input_fields, *qf_output_fields;
517   CeedQFunction       qf;
518   CeedOperatorField  *op_input_fields, *op_output_fields;
519   CeedOperator_Opt   *impl;
520 
521   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
522   CeedCallBackend(CeedGetData(ceed, &ceed_impl));
523   CeedCallBackend(CeedOperatorGetData(op, &impl));
524   qf_size_in  = impl->qf_size_in;
525   qf_size_out = impl->qf_size_out;
526 
527   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
528   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
529   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
530   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
531   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
532   const CeedInt       block_size = ceed_impl->block_size;
533   const CeedInt       num_blocks = (num_elem / block_size) + !!(num_elem % block_size);
534   CeedVector          l_vec      = impl->qf_l_vec;
535   CeedElemRestriction block_rstr = impl->qf_block_rstr;
536 
537   // Setup
538   CeedCallBackend(CeedOperatorSetup_Opt(op));
539 
540   // Check for restriction only operator
541   CeedCheck(!impl->is_identity_rstr_op, ceed, CEED_ERROR_BACKEND, "Assembling restriction only operators is not supported");
542 
543   // Input Evecs and Restriction
544   CeedCallBackend(CeedOperatorSetupInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, NULL, e_data, impl, request));
545 
546   // Count number of active input fields
547   if (qf_size_in == 0) {
548     for (CeedInt i = 0; i < num_input_fields; i++) {
549       CeedInt    field_size;
550       CeedVector vec;
551 
552       // Check if active input
553       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
554       if (vec == CEED_VECTOR_ACTIVE) {
555         CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size));
556         CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
557         qf_size_in += field_size;
558       }
559       CeedCallBackend(CeedVectorDestroy(&vec));
560     }
561     CeedCheck(qf_size_in > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
562     impl->qf_size_in = qf_size_in;
563   }
564 
565   // Count number of active output fields
566   if (qf_size_out == 0) {
567     for (CeedInt i = 0; i < num_output_fields; i++) {
568       CeedInt    field_size;
569       CeedVector vec;
570 
571       // Check if active output
572       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
573       if (vec == CEED_VECTOR_ACTIVE) {
574         CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size));
575         qf_size_out += field_size;
576       }
577       CeedCallBackend(CeedVectorDestroy(&vec));
578     }
579     CeedCheck(qf_size_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
580     impl->qf_size_out = qf_size_out;
581   }
582 
583   // Setup l_vec
584   if (!l_vec) {
585     const CeedSize l_size = (CeedSize)block_size * Q * qf_size_in * qf_size_out;
586 
587     CeedCallBackend(CeedVectorCreate(ceed, l_size, &l_vec));
588     CeedCallBackend(CeedVectorSetValue(l_vec, 0.0));
589     impl->qf_l_vec = l_vec;
590   }
591 
592   // Output blocked restriction
593   if (!block_rstr) {
594     CeedInt strides[3] = {1, Q, qf_size_in * qf_size_out * Q};
595 
596     CeedCallBackend(CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, Q, block_size, qf_size_in * qf_size_out,
597                                                             qf_size_in * qf_size_out * num_elem * Q, strides, &block_rstr));
598     impl->qf_block_rstr = block_rstr;
599   }
600 
601   // Build objects if needed
602   if (build_objects) {
603     const CeedSize l_size     = (CeedSize)num_elem * Q * qf_size_in * qf_size_out;
604     CeedInt        strides[3] = {1, Q, qf_size_in * qf_size_out * Q};
605 
606     // Create output restriction
607     CeedCallBackend(CeedElemRestrictionCreateStrided(ceed, num_elem, Q, qf_size_in * qf_size_out,
608                                                      (CeedSize)qf_size_in * (CeedSize)qf_size_out * (CeedSize)num_elem * (CeedSize)Q, strides, rstr));
609     // Create assembled vector
610     CeedCallBackend(CeedVectorCreate(ceed, l_size, assembled));
611   }
612 
613   // Loop through elements
614   CeedCallBackend(CeedVectorSetValue(*assembled, 0.0));
615   for (CeedInt e = 0; e < num_blocks * block_size; e += block_size) {
616     CeedCallBackend(CeedVectorGetArray(l_vec, CEED_MEM_HOST, &l_vec_array));
617 
618     // Input basis apply
619     CeedCallBackend(
620         CeedOperatorInputBasis_Opt(e, Q, qf_input_fields, op_input_fields, num_input_fields, block_size, NULL, true, e_data, impl, request));
621 
622     // Assemble QFunction
623     for (CeedInt i = 0; i < num_input_fields; i++) {
624       bool       is_active;
625       CeedInt    field_size;
626       CeedVector vec;
627 
628       // Check if active input
629       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
630       is_active = vec == CEED_VECTOR_ACTIVE;
631       CeedCallBackend(CeedVectorDestroy(&vec));
632       if (!is_active) continue;
633       CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size));
634       for (CeedInt field = 0; field < field_size; field++) {
635         // Set current portion of input to 1.0
636         {
637           CeedScalar *array;
638 
639           CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &array));
640           for (CeedInt j = 0; j < Q * block_size; j++) array[field * Q * block_size + j] = 1.0;
641           CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &array));
642         }
643 
644         if (!impl->is_identity_qf) {
645           // Set Outputs
646           for (CeedInt out = 0; out < num_output_fields; out++) {
647             CeedVector vec;
648 
649             // Check if active output
650             CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
651             if (vec == CEED_VECTOR_ACTIVE) {
652               CeedInt field_size;
653 
654               CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, l_vec_array));
655               CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &field_size));
656               l_vec_array += field_size * Q * block_size;  // Advance the pointer by the size of the output
657             }
658             CeedCallBackend(CeedVectorDestroy(&vec));
659           }
660           // Apply QFunction
661           CeedCallBackend(CeedQFunctionApply(qf, Q * block_size, impl->q_vecs_in, impl->q_vecs_out));
662         } else {
663           CeedInt           field_size;
664           const CeedScalar *array;
665 
666           // Copy Identity Outputs
667           CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[0], &field_size));
668           CeedCallBackend(CeedVectorGetArrayRead(impl->q_vecs_out[0], CEED_MEM_HOST, &array));
669           for (CeedInt j = 0; j < field_size * Q * block_size; j++) l_vec_array[j] = array[j];
670           CeedCallBackend(CeedVectorRestoreArrayRead(impl->q_vecs_out[0], &array));
671           l_vec_array += field_size * Q * block_size;
672         }
673         // Reset input to 0.0
674         {
675           CeedScalar *array;
676 
677           CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &array));
678           for (CeedInt j = 0; j < Q * block_size; j++) array[field * Q * block_size + j] = 0.0;
679           CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &array));
680         }
681       }
682     }
683 
684     // Un-set output Qvecs to prevent accidental overwrite of Assembled
685     if (!impl->is_identity_qf) {
686       for (CeedInt out = 0; out < num_output_fields; out++) {
687         CeedVector vec;
688 
689         // Check if active output
690         CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
691         if (vec == CEED_VECTOR_ACTIVE && num_elem > 0) {
692           CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL));
693         }
694         CeedCallBackend(CeedVectorDestroy(&vec));
695       }
696     }
697 
698     // Assemble into assembled vector
699     CeedCallBackend(CeedVectorRestoreArray(l_vec, &l_vec_array));
700     CeedCallBackend(CeedElemRestrictionApplyBlock(block_rstr, e / block_size, CEED_TRANSPOSE, l_vec, *assembled, request));
701   }
702 
703   // Reset output Qvecs
704   for (CeedInt out = 0; out < num_output_fields; out++) {
705     CeedVector vec;
706 
707     // Initialize array if active output
708     CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
709     if (vec == CEED_VECTOR_ACTIVE) CeedCallBackend(CeedVectorSetValue(impl->q_vecs_out[out], 0.0));
710     CeedCallBackend(CeedVectorDestroy(&vec));
711   }
712 
713   // Restore input arrays
714   CeedCallBackend(CeedOperatorRestoreInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, e_data, impl));
715   return CEED_ERROR_SUCCESS;
716 }
717 
718 //------------------------------------------------------------------------------
719 // Assemble Linear QFunction
720 //------------------------------------------------------------------------------
721 static int CeedOperatorLinearAssembleQFunction_Opt(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
722   return CeedOperatorLinearAssembleQFunctionCore_Opt(op, true, assembled, rstr, request);
723 }
724 
725 //------------------------------------------------------------------------------
726 // Update Assembled Linear QFunction
727 //------------------------------------------------------------------------------
728 static int CeedOperatorLinearAssembleQFunctionUpdate_Opt(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
729   return CeedOperatorLinearAssembleQFunctionCore_Opt(op, false, &assembled, &rstr, request);
730 }
731 
732 //------------------------------------------------------------------------------
733 // Operator Destroy
734 //------------------------------------------------------------------------------
735 static int CeedOperatorDestroy_Opt(CeedOperator op) {
736   CeedOperator_Opt *impl;
737 
738   CeedCallBackend(CeedOperatorGetData(op, &impl));
739   for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) {
740     CeedCallBackend(CeedElemRestrictionDestroy(&impl->block_rstr[i]));
741     CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i]));
742   }
743   CeedCallBackend(CeedFree(&impl->block_rstr));
744   CeedCallBackend(CeedFree(&impl->e_vecs_full));
745   CeedCallBackend(CeedFree(&impl->input_states));
746   CeedCallBackend(CeedFree(&impl->skip_rstr_in));
747   CeedCallBackend(CeedFree(&impl->skip_rstr_out));
748   CeedCallBackend(CeedFree(&impl->apply_add_basis_out));
749 
750   for (CeedInt i = 0; i < impl->num_inputs; i++) {
751     CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i]));
752     CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i]));
753   }
754   CeedCallBackend(CeedFree(&impl->e_vecs_in));
755   CeedCallBackend(CeedFree(&impl->q_vecs_in));
756 
757   for (CeedInt i = 0; i < impl->num_outputs; i++) {
758     CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i]));
759     CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i]));
760   }
761   CeedCallBackend(CeedFree(&impl->e_vecs_out));
762   CeedCallBackend(CeedFree(&impl->q_vecs_out));
763 
764   // QFunction assembly data
765   CeedCallBackend(CeedVectorDestroy(&impl->qf_l_vec));
766   CeedCallBackend(CeedElemRestrictionDestroy(&impl->qf_block_rstr));
767 
768   CeedCallBackend(CeedFree(&impl));
769   return CEED_ERROR_SUCCESS;
770 }
771 
772 //------------------------------------------------------------------------------
773 // Operator Create
774 //------------------------------------------------------------------------------
775 int CeedOperatorCreate_Opt(CeedOperator op) {
776   Ceed              ceed;
777   Ceed_Opt         *ceed_impl;
778   CeedOperator_Opt *impl;
779 
780   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
781   CeedCallBackend(CeedGetData(ceed, &ceed_impl));
782   const CeedInt block_size = ceed_impl->block_size;
783 
784   CeedCallBackend(CeedCalloc(1, &impl));
785   CeedCallBackend(CeedOperatorSetData(op, impl));
786 
787   CeedCheck(block_size == 1 || block_size == 8, ceed, CEED_ERROR_BACKEND, "Opt backend cannot use blocksize: %" CeedInt_FMT, block_size);
788 
789   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Opt));
790   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Opt));
791   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Opt));
792   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Opt));
793   return CEED_ERROR_SUCCESS;
794 }
795 
796 //------------------------------------------------------------------------------
797