xref: /libCEED/backends/opt/ceed-opt-operator.c (revision 77d1c127eaba12da4c1761ef74a16ca3fc16e493)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <stdbool.h>
11 #include <stdint.h>
12 #include <string.h>
13 
14 #include "ceed-opt.h"
15 
16 //------------------------------------------------------------------------------
17 // Setup Input/Output Fields
18 //------------------------------------------------------------------------------
19 static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op, bool is_input, const CeedInt blk_size, CeedElemRestriction *blk_restr,
20                                        CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields,
21                                        CeedInt Q) {
22   CeedInt  num_comp, size, P;
23   CeedSize e_size, q_size;
24   Ceed     ceed;
25   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
26   CeedBasis           basis;
27   CeedElemRestriction r;
28   CeedOperatorField  *op_fields;
29   CeedQFunctionField *qf_fields;
30   if (is_input) {
31     CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL));
32     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
33   } else {
34     CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
35     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
36   }
37 
38   // Loop over fields
39   for (CeedInt i = 0; i < num_fields; i++) {
40     CeedEvalMode eval_mode;
41     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
42 
43     if (eval_mode != CEED_EVAL_WEIGHT) {
44       CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &r));
45       Ceed ceed;
46       CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
47       CeedSize l_size;
48       CeedInt  num_elem, elem_size, comp_stride;
49       CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
50       CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
51       CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size));
52       CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
53 
54       bool strided;
55       CeedCallBackend(CeedElemRestrictionIsStrided(r, &strided));
56       if (strided) {
57         CeedInt strides[3];
58         CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides));
59         CeedCallBackend(
60             CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, elem_size, blk_size, num_comp, l_size, strides, &blk_restr[i + start_e]));
61       } else {
62         const CeedInt *offsets      = NULL;
63         const bool    *orients      = NULL;
64         const CeedInt *curl_orients = NULL;
65         CeedCallBackend(CeedElemRestrictionGetOffsets(r, CEED_MEM_HOST, &offsets));
66         CeedCallBackend(CeedElemRestrictionGetOrientations(r, CEED_MEM_HOST, &orients));
67         CeedCallBackend(CeedElemRestrictionGetCurlOrientations(r, CEED_MEM_HOST, &curl_orients));
68         CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride));
69         if (!orients && !curl_orients) {
70           CeedCallBackend(CeedElemRestrictionCreateBlocked(ceed, num_elem, elem_size, blk_size, num_comp, comp_stride, l_size, CEED_MEM_HOST,
71                                                            CEED_COPY_VALUES, offsets, &blk_restr[i + start_e]));
72         } else if (!curl_orients) {
73           CeedCallBackend(CeedElemRestrictionCreateBlockedOriented(ceed, num_elem, elem_size, blk_size, num_comp, comp_stride, l_size, CEED_MEM_HOST,
74                                                                    CEED_COPY_VALUES, offsets, orients, &blk_restr[i + start_e]));
75         } else {
76           CeedCallBackend(CeedElemRestrictionCreateBlockedCurlOriented(ceed, num_elem, elem_size, blk_size, num_comp, comp_stride, l_size,
77                                                                        CEED_MEM_HOST, CEED_COPY_VALUES, offsets, curl_orients,
78                                                                        &blk_restr[i + start_e]));
79         }
80         CeedCallBackend(CeedElemRestrictionRestoreOffsets(r, &offsets));
81         CeedCallBackend(CeedElemRestrictionRestoreOrientations(r, &orients));
82         CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(r, &curl_orients));
83       }
84       CeedCallBackend(CeedElemRestrictionCreateVector(blk_restr[i + start_e], NULL, &e_vecs_full[i + start_e]));
85     }
86 
87     switch (eval_mode) {
88       case CEED_EVAL_NONE:
89         CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
90         e_size = (CeedSize)Q * size * blk_size;
91         CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i]));
92         q_size = (CeedSize)Q * size * blk_size;
93         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
94         break;
95       case CEED_EVAL_INTERP:
96       case CEED_EVAL_GRAD:
97       case CEED_EVAL_DIV:
98       case CEED_EVAL_CURL:
99         CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
100         CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
101         CeedCallBackend(CeedBasisGetNumNodes(basis, &P));
102         CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
103         e_size = (CeedSize)P * num_comp * blk_size;
104         CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i]));
105         q_size = (CeedSize)Q * size * blk_size;
106         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
107         break;
108       case CEED_EVAL_WEIGHT:  // Only on input fields
109         CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
110         q_size = (CeedSize)Q * blk_size;
111         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
112         CeedCallBackend(CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]));
113         break;
114     }
115     if (is_input && !!e_vecs[i]) {
116       CeedCallBackend(CeedVectorSetArray(e_vecs[i], CEED_MEM_HOST, CEED_COPY_VALUES, NULL));
117     }
118   }
119   return CEED_ERROR_SUCCESS;
120 }
121 
122 //------------------------------------------------------------------------------
123 // Setup Operator
124 //------------------------------------------------------------------------------
125 static int CeedOperatorSetup_Opt(CeedOperator op) {
126   bool is_setup_done;
127   CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done));
128   if (is_setup_done) return CEED_ERROR_SUCCESS;
129   Ceed ceed;
130   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
131   Ceed_Opt *ceed_impl;
132   CeedCallBackend(CeedGetData(ceed, &ceed_impl));
133   const CeedInt     blk_size = ceed_impl->blk_size;
134   CeedOperator_Opt *impl;
135   CeedCallBackend(CeedOperatorGetData(op, &impl));
136   CeedQFunction qf;
137   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
138   CeedInt Q, num_input_fields, num_output_fields;
139   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
140   CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf));
141   CeedOperatorField *op_input_fields, *op_output_fields;
142   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
143   CeedQFunctionField *qf_input_fields, *qf_output_fields;
144   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
145 
146   // Allocate
147   CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->blk_restr));
148   CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full));
149 
150   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
151   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in));
152   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out));
153   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in));
154   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out));
155 
156   impl->num_inputs  = num_input_fields;
157   impl->num_outputs = num_output_fields;
158 
159   // Set up infield and outfield pointer arrays
160   // Infields
161   CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, true, blk_size, impl->blk_restr, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0,
162                                               num_input_fields, Q));
163   // Outfields
164   CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, false, blk_size, impl->blk_restr, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out,
165                                               num_input_fields, num_output_fields, Q));
166 
167   // Identity QFunctions
168   if (impl->is_identity_qf) {
169     CeedEvalMode        in_mode, out_mode;
170     CeedQFunctionField *in_fields, *out_fields;
171     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields));
172     CeedCallBackend(CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode));
173     CeedCallBackend(CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode));
174 
175     if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) {
176       impl->is_identity_restr_op = true;
177     } else {
178       CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0]));
179     }
180   }
181 
182   CeedCallBackend(CeedOperatorSetSetupDone(op));
183 
184   return CEED_ERROR_SUCCESS;
185 }
186 
187 //------------------------------------------------------------------------------
188 // Setup Input Fields
189 //------------------------------------------------------------------------------
190 static inline int CeedOperatorSetupInputs_Opt(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
191                                               CeedVector in_vec, CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Opt *impl,
192                                               CeedRequest *request) {
193   CeedEvalMode eval_mode;
194   CeedVector   vec;
195   uint64_t     state;
196 
197   for (CeedInt i = 0; i < num_input_fields; i++) {
198     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
199     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
200     } else {
201       // Get input vector
202       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
203       if (vec != CEED_VECTOR_ACTIVE) {
204         // Restrict
205         CeedCallBackend(CeedVectorGetState(vec, &state));
206         if (state != impl->input_states[i]) {
207           CeedCallBackend(CeedElemRestrictionApply(impl->blk_restr[i], CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request));
208           impl->input_states[i] = state;
209         }
210         // Get evec
211         CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data[i]));
212       } else {
213         // Set Qvec for CEED_EVAL_NONE
214         if (eval_mode == CEED_EVAL_NONE) {
215           CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_in[i], CEED_MEM_HOST, (const CeedScalar **)&e_data[i]));
216           CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, e_data[i]));
217           CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_in[i], (const CeedScalar **)&e_data[i]));
218         }
219       }
220     }
221   }
222   return CEED_ERROR_SUCCESS;
223 }
224 
225 //------------------------------------------------------------------------------
226 // Input Basis Action
227 //------------------------------------------------------------------------------
228 static inline int CeedOperatorInputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
229                                              CeedInt num_input_fields, CeedInt blk_size, CeedVector in_vec, bool skip_active,
230                                              CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Opt *impl, CeedRequest *request) {
231   CeedInt             elem_size, size, num_comp;
232   CeedElemRestriction elem_restr;
233   CeedEvalMode        eval_mode;
234   CeedBasis           basis;
235   CeedVector          vec;
236 
237   for (CeedInt i = 0; i < num_input_fields; i++) {
238     CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
239     // Skip active input
240     if (skip_active) {
241       if (vec == CEED_VECTOR_ACTIVE) continue;
242     }
243 
244     CeedInt active_in = 0;
245     // Get elem_size, eval_mode, size
246     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_restr));
247     CeedCallBackend(CeedElemRestrictionGetElementSize(elem_restr, &elem_size));
248     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
249     CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size));
250     // Restrict block active input
251     if (vec == CEED_VECTOR_ACTIVE) {
252       CeedCallBackend(CeedElemRestrictionApplyBlock(impl->blk_restr[i], e / blk_size, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request));
253       active_in = 1;
254     }
255     // Basis action
256     switch (eval_mode) {
257       case CEED_EVAL_NONE:
258         if (!active_in) {
259           CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][e * Q * size]));
260         }
261         break;
262       case CEED_EVAL_INTERP:
263       case CEED_EVAL_GRAD:
264       case CEED_EVAL_DIV:
265       case CEED_EVAL_CURL:
266         CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
267         if (!active_in) {
268           CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
269           CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data[i][e * elem_size * num_comp]));
270         }
271         CeedCallBackend(CeedBasisApply(basis, blk_size, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs_in[i], impl->q_vecs_in[i]));
272         break;
273       case CEED_EVAL_WEIGHT:
274         break;  // No action
275     }
276   }
277   return CEED_ERROR_SUCCESS;
278 }
279 
280 //------------------------------------------------------------------------------
281 // Output Basis Action
282 //------------------------------------------------------------------------------
283 static inline int CeedOperatorOutputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields,
284                                               CeedInt blk_size, CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op,
285                                               CeedVector out_vec, CeedOperator_Opt *impl, CeedRequest *request) {
286   CeedElemRestriction elem_restr;
287   CeedEvalMode        eval_mode;
288   CeedBasis           basis;
289   CeedVector          vec;
290 
291   for (CeedInt i = 0; i < num_output_fields; i++) {
292     // Get elem_size, eval_mode, size
293     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_restr));
294     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
295     // Basis action
296     switch (eval_mode) {
297       case CEED_EVAL_NONE:
298         break;  // No action
299       case CEED_EVAL_INTERP:
300       case CEED_EVAL_GRAD:
301       case CEED_EVAL_DIV:
302       case CEED_EVAL_CURL:
303         CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
304         CeedCallBackend(CeedBasisApply(basis, blk_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
305         break;
306       // LCOV_EXCL_START
307       case CEED_EVAL_WEIGHT: {
308         Ceed ceed;
309         CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
310         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
311         // LCOV_EXCL_STOP
312       }
313     }
314     // Restrict output block
315     // Get output vector
316     CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
317     if (vec == CEED_VECTOR_ACTIVE) vec = out_vec;
318     // Restrict
319     CeedCallBackend(
320         CeedElemRestrictionApplyBlock(impl->blk_restr[i + impl->num_inputs], e / blk_size, CEED_TRANSPOSE, impl->e_vecs_out[i], vec, request));
321   }
322   return CEED_ERROR_SUCCESS;
323 }
324 
325 //------------------------------------------------------------------------------
326 // Restore Input Vectors
327 //------------------------------------------------------------------------------
328 static inline int CeedOperatorRestoreInputs_Opt(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
329                                                 CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Opt *impl) {
330   for (CeedInt i = 0; i < num_input_fields; i++) {
331     CeedEvalMode eval_mode;
332     CeedVector   vec;
333     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
334     CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
335     if (eval_mode != CEED_EVAL_WEIGHT && vec != CEED_VECTOR_ACTIVE) {
336       CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_full[i], (const CeedScalar **)&e_data[i]));
337     }
338   }
339   return CEED_ERROR_SUCCESS;
340 }
341 
342 //------------------------------------------------------------------------------
343 // Operator Apply
344 //------------------------------------------------------------------------------
345 static int CeedOperatorApplyAdd_Opt(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) {
346   Ceed ceed;
347   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
348   Ceed_Opt *ceed_impl;
349   CeedCallBackend(CeedGetData(ceed, &ceed_impl));
350   CeedInt           blk_size = ceed_impl->blk_size;
351   CeedOperator_Opt *impl;
352   CeedCallBackend(CeedOperatorGetData(op, &impl));
353   CeedInt Q, num_input_fields, num_output_fields, num_elem;
354   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
355   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
356   CeedInt       num_blks = (num_elem / blk_size) + !!(num_elem % blk_size);
357   CeedQFunction qf;
358   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
359   CeedOperatorField *op_input_fields, *op_output_fields;
360   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
361   CeedQFunctionField *qf_input_fields, *qf_output_fields;
362   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
363   CeedEvalMode eval_mode;
364   CeedScalar  *e_data[2 * CEED_FIELD_MAX] = {0};
365 
366   // Setup
367   CeedCallBackend(CeedOperatorSetup_Opt(op));
368 
369   // Restriction only operator
370   if (impl->is_identity_restr_op) {
371     for (CeedInt b = 0; b < num_blks; b++) {
372       CeedCallBackend(CeedElemRestrictionApplyBlock(impl->blk_restr[0], b, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[0], request));
373       CeedCallBackend(CeedElemRestrictionApplyBlock(impl->blk_restr[1], b, CEED_TRANSPOSE, impl->e_vecs_in[0], out_vec, request));
374     }
375     return CEED_ERROR_SUCCESS;
376   }
377 
378   // Input Evecs and Restriction
379   CeedCallBackend(CeedOperatorSetupInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, in_vec, e_data, impl, request));
380 
381   // Output Lvecs, Evecs, and Qvecs
382   for (CeedInt i = 0; i < num_output_fields; i++) {
383     // Set Qvec if needed
384     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
385     if (eval_mode == CEED_EVAL_NONE) {
386       // Set qvec to single block evec
387       CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_HOST, &e_data[i + num_input_fields]));
388       CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, e_data[i + num_input_fields]));
389       CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_out[i], &e_data[i + num_input_fields]));
390     }
391   }
392 
393   // Loop through elements
394   for (CeedInt e = 0; e < num_blks * blk_size; e += blk_size) {
395     // Input basis apply
396     CeedCallBackend(
397         CeedOperatorInputBasis_Opt(e, Q, qf_input_fields, op_input_fields, num_input_fields, blk_size, in_vec, false, e_data, impl, request));
398 
399     // Q function
400     if (!impl->is_identity_qf) {
401       CeedCallBackend(CeedQFunctionApply(qf, Q * blk_size, impl->q_vecs_in, impl->q_vecs_out));
402     }
403 
404     // Output basis apply and restrict
405     CeedCallBackend(CeedOperatorOutputBasis_Opt(e, Q, qf_output_fields, op_output_fields, blk_size, num_input_fields, num_output_fields, op, out_vec,
406                                                 impl, request));
407   }
408 
409   // Restore input arrays
410   CeedCallBackend(CeedOperatorRestoreInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, e_data, impl));
411 
412   return CEED_ERROR_SUCCESS;
413 }
414 
415 //------------------------------------------------------------------------------
416 // Core code for linear QFunction assembly
417 //------------------------------------------------------------------------------
418 static inline int CeedOperatorLinearAssembleQFunctionCore_Opt(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr,
419                                                               CeedRequest *request) {
420   Ceed ceed;
421   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
422   Ceed_Opt *ceed_impl;
423   CeedCallBackend(CeedGetData(ceed, &ceed_impl));
424   const CeedInt     blk_size = ceed_impl->blk_size;
425   CeedSize          q_size;
426   CeedOperator_Opt *impl;
427   CeedCallBackend(CeedOperatorGetData(op, &impl));
428   CeedInt Q, num_input_fields, num_output_fields, num_elem, size;
429   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
430   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
431   CeedInt       num_blks = (num_elem / blk_size) + !!(num_elem % blk_size);
432   CeedQFunction qf;
433   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
434   CeedOperatorField *op_input_fields, *op_output_fields;
435   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
436   CeedQFunctionField *qf_input_fields, *qf_output_fields;
437   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
438   CeedVector  vec, l_vec = impl->qf_l_vec;
439   CeedInt     num_active_in = impl->num_active_in, num_active_out = impl->num_active_out;
440   CeedVector *active_in = impl->qf_active_in;
441   CeedScalar *a, *tmp;
442   CeedScalar *e_data[2 * CEED_FIELD_MAX] = {0};
443 
444   // Setup
445   CeedCallBackend(CeedOperatorSetup_Opt(op));
446 
447   // Check for identity
448   CeedCheck(!impl->is_identity_qf, ceed, CEED_ERROR_BACKEND, "Assembling identity qfunctions not supported");
449 
450   // Input Evecs and Restriction
451   CeedCallBackend(CeedOperatorSetupInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, NULL, e_data, impl, request));
452 
453   // Count number of active input fields
454   if (!num_active_in) {
455     for (CeedInt i = 0; i < num_input_fields; i++) {
456       // Get input vector
457       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
458       // Check if active input
459       if (vec == CEED_VECTOR_ACTIVE) {
460         CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size));
461         CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
462         CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &tmp));
463         CeedCallBackend(CeedRealloc(num_active_in + size, &active_in));
464         for (CeedInt field = 0; field < size; field++) {
465           q_size = (CeedSize)Q * blk_size;
466           CeedCallBackend(CeedVectorCreate(ceed, q_size, &active_in[num_active_in + field]));
467           CeedCallBackend(CeedVectorSetArray(active_in[num_active_in + field], CEED_MEM_HOST, CEED_USE_POINTER, &tmp[field * Q * blk_size]));
468         }
469         num_active_in += size;
470         CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &tmp));
471       }
472     }
473     impl->num_active_in = num_active_in;
474     impl->qf_active_in  = active_in;
475   }
476 
477   // Count number of active output fields
478   if (!num_active_out) {
479     for (CeedInt i = 0; i < num_output_fields; i++) {
480       // Get output vector
481       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
482       // Check if active output
483       if (vec == CEED_VECTOR_ACTIVE) {
484         CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size));
485         num_active_out += size;
486       }
487     }
488     impl->num_active_out = num_active_out;
489   }
490 
491   // Check sizes
492   CeedCheck(num_active_in > 0 && num_active_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
493 
494   // Setup l_vec
495   if (!l_vec) {
496     CeedSize l_size = (CeedSize)blk_size * Q * num_active_in * num_active_out;
497     CeedCallBackend(CeedVectorCreate(ceed, l_size, &l_vec));
498     CeedCallBackend(CeedVectorSetValue(l_vec, 0.0));
499     impl->qf_l_vec = l_vec;
500   }
501 
502   // Build objects if needed
503   CeedInt strides[3] = {1, Q, num_active_in * num_active_out * Q};
504   if (build_objects) {
505     // Create output restriction
506     CeedCallBackend(CeedElemRestrictionCreateStrided(ceed, num_elem, Q, num_active_in * num_active_out, num_active_in * num_active_out * num_elem * Q,
507                                                      strides, rstr));
508     // Create assembled vector
509     CeedSize l_size = (CeedSize)num_elem * Q * num_active_in * num_active_out;
510     CeedCallBackend(CeedVectorCreate(ceed, l_size, assembled));
511   }
512 
513   // Output blocked restriction
514   CeedElemRestriction blk_rstr = impl->qf_blk_rstr;
515   if (!blk_rstr) {
516     CeedCallBackend(CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, Q, blk_size, num_active_in * num_active_out,
517                                                             num_active_in * num_active_out * num_elem * Q, strides, &blk_rstr));
518     impl->qf_blk_rstr = blk_rstr;
519   }
520 
521   // Loop through elements
522   CeedCallBackend(CeedVectorSetValue(*assembled, 0.0));
523   for (CeedInt e = 0; e < num_blks * blk_size; e += blk_size) {
524     CeedCallBackend(CeedVectorGetArray(l_vec, CEED_MEM_HOST, &a));
525 
526     // Input basis apply
527     CeedCallBackend(
528         CeedOperatorInputBasis_Opt(e, Q, qf_input_fields, op_input_fields, num_input_fields, blk_size, NULL, true, e_data, impl, request));
529 
530     // Assemble QFunction
531     for (CeedInt in = 0; in < num_active_in; in++) {
532       // Set Inputs
533       CeedCallBackend(CeedVectorSetValue(active_in[in], 1.0));
534       if (num_active_in > 1) {
535         CeedCallBackend(CeedVectorSetValue(active_in[(in + num_active_in - 1) % num_active_in], 0.0));
536       }
537       // Set Outputs
538       for (CeedInt out = 0; out < num_output_fields; out++) {
539         // Get output vector
540         CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
541         // Check if active output
542         if (vec == CEED_VECTOR_ACTIVE) {
543           CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, a));
544           CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &size));
545           a += size * Q * blk_size;  // Advance the pointer by the size of the output
546         }
547       }
548       // Apply QFunction
549       CeedCallBackend(CeedQFunctionApply(qf, Q * blk_size, impl->q_vecs_in, impl->q_vecs_out));
550     }
551 
552     // Assemble into assembled vector
553     CeedCallBackend(CeedVectorRestoreArray(l_vec, &a));
554     CeedCallBackend(CeedElemRestrictionApplyBlock(blk_rstr, e / blk_size, CEED_TRANSPOSE, l_vec, *assembled, request));
555   }
556 
557   // Un-set output Qvecs to prevent accidental overwrite of Assembled
558   for (CeedInt out = 0; out < num_output_fields; out++) {
559     // Get output vector
560     CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
561     // Check if active output
562     if (vec == CEED_VECTOR_ACTIVE) {
563       CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_COPY_VALUES, NULL));
564     }
565   }
566 
567   // Restore input arrays
568   CeedCallBackend(CeedOperatorRestoreInputs_Opt(num_input_fields, qf_input_fields, op_input_fields, e_data, impl));
569 
570   return CEED_ERROR_SUCCESS;
571 }
572 
573 //------------------------------------------------------------------------------
574 // Assemble Linear QFunction
575 //------------------------------------------------------------------------------
576 static int CeedOperatorLinearAssembleQFunction_Opt(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
577   return CeedOperatorLinearAssembleQFunctionCore_Opt(op, true, assembled, rstr, request);
578 }
579 
580 //------------------------------------------------------------------------------
581 // Update Assembled Linear QFunction
582 //------------------------------------------------------------------------------
583 static int CeedOperatorLinearAssembleQFunctionUpdate_Opt(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
584   return CeedOperatorLinearAssembleQFunctionCore_Opt(op, false, &assembled, &rstr, request);
585 }
586 
587 //------------------------------------------------------------------------------
588 // Operator Destroy
589 //------------------------------------------------------------------------------
590 static int CeedOperatorDestroy_Opt(CeedOperator op) {
591   CeedOperator_Opt *impl;
592   CeedCallBackend(CeedOperatorGetData(op, &impl));
593 
594   for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) {
595     CeedCallBackend(CeedElemRestrictionDestroy(&impl->blk_restr[i]));
596     CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i]));
597   }
598   CeedCallBackend(CeedFree(&impl->blk_restr));
599   CeedCallBackend(CeedFree(&impl->e_vecs_full));
600   CeedCallBackend(CeedFree(&impl->input_states));
601 
602   for (CeedInt i = 0; i < impl->num_inputs; i++) {
603     CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i]));
604     CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i]));
605   }
606   CeedCallBackend(CeedFree(&impl->e_vecs_in));
607   CeedCallBackend(CeedFree(&impl->q_vecs_in));
608 
609   for (CeedInt i = 0; i < impl->num_outputs; i++) {
610     CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i]));
611     CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i]));
612   }
613   CeedCallBackend(CeedFree(&impl->e_vecs_out));
614   CeedCallBackend(CeedFree(&impl->q_vecs_out));
615 
616   // QFunction assembly data
617   for (CeedInt i = 0; i < impl->num_active_in; i++) {
618     CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i]));
619   }
620   CeedCallBackend(CeedFree(&impl->qf_active_in));
621   CeedCallBackend(CeedVectorDestroy(&impl->qf_l_vec));
622   CeedCallBackend(CeedElemRestrictionDestroy(&impl->qf_blk_rstr));
623 
624   CeedCallBackend(CeedFree(&impl));
625   return CEED_ERROR_SUCCESS;
626 }
627 
628 //------------------------------------------------------------------------------
629 // Operator Create
630 //------------------------------------------------------------------------------
631 int CeedOperatorCreate_Opt(CeedOperator op) {
632   Ceed ceed;
633   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
634   Ceed_Opt *ceed_impl;
635   CeedCallBackend(CeedGetData(ceed, &ceed_impl));
636   CeedInt           blk_size = ceed_impl->blk_size;
637   CeedOperator_Opt *impl;
638 
639   CeedCallBackend(CeedCalloc(1, &impl));
640   CeedCallBackend(CeedOperatorSetData(op, impl));
641 
642   CeedCheck(blk_size == 1 || blk_size == 8, ceed, CEED_ERROR_BACKEND, "Opt backend cannot use blocksize: %" CeedInt_FMT, blk_size);
643 
644   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Opt));
645   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Opt));
646   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Opt));
647   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Opt));
648   return CEED_ERROR_SUCCESS;
649 }
650 
651 //------------------------------------------------------------------------------
652