xref: /libCEED/backends/blocked/ceed-blocked-operator.c (revision db2becc9f302fe8eb3a32ace50ce3f3a5d42e6c4)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <stdbool.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 
14 #include "ceed-blocked.h"
15 
16 //------------------------------------------------------------------------------
17 // Setup Input/Output Fields
18 //------------------------------------------------------------------------------
19 static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, const CeedInt block_size,
20                                            CeedElemRestriction *block_rstr, CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs,
21                                            CeedInt start_e, CeedInt num_fields, CeedInt Q) {
22   Ceed                ceed;
23   CeedSize            e_size, q_size;
24   CeedInt             num_comp, size, P;
25   CeedQFunctionField *qf_fields;
26   CeedOperatorField  *op_fields;
27 
28   {
29     Ceed ceed_parent;
30 
31     CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
32     CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
33     if (ceed_parent) ceed = ceed_parent;
34   }
35   if (is_input) {
36     CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL));
37     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
38   } else {
39     CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
40     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
41   }
42 
43   // Loop over fields
44   for (CeedInt i = 0; i < num_fields; i++) {
45     CeedEvalMode eval_mode;
46     CeedBasis    basis;
47 
48     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
49     if (eval_mode != CEED_EVAL_WEIGHT) {
50       Ceed                ceed_rstr;
51       CeedSize            l_size;
52       CeedInt             num_elem, elem_size, comp_stride;
53       CeedRestrictionType rstr_type;
54       CeedElemRestriction rstr;
55 
56       CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr));
57       CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed_rstr));
58       CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
59       CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
60       CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size));
61       CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
62       CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride));
63 
64       CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
65       switch (rstr_type) {
66         case CEED_RESTRICTION_STANDARD: {
67           const CeedInt *offsets = NULL;
68 
69           CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
70           CeedCallBackend(CeedElemRestrictionCreateBlocked(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size, CEED_MEM_HOST,
71                                                            CEED_COPY_VALUES, offsets, &block_rstr[i + start_e]));
72           CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
73         } break;
74         case CEED_RESTRICTION_ORIENTED: {
75           const bool    *orients = NULL;
76           const CeedInt *offsets = NULL;
77 
78           CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
79           CeedCallBackend(CeedElemRestrictionGetOrientations(rstr, CEED_MEM_HOST, &orients));
80           CeedCallBackend(CeedElemRestrictionCreateBlockedOriented(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size,
81                                                                    CEED_MEM_HOST, CEED_COPY_VALUES, offsets, orients, &block_rstr[i + start_e]));
82           CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
83           CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr, &orients));
84         } break;
85         case CEED_RESTRICTION_CURL_ORIENTED: {
86           const CeedInt8 *curl_orients = NULL;
87           const CeedInt  *offsets      = NULL;
88 
89           CeedCallBackend(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
90           CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr, CEED_MEM_HOST, &curl_orients));
91           CeedCallBackend(CeedElemRestrictionCreateBlockedCurlOriented(ceed_rstr, num_elem, elem_size, block_size, num_comp, comp_stride, l_size,
92                                                                        CEED_MEM_HOST, CEED_COPY_VALUES, offsets, curl_orients,
93                                                                        &block_rstr[i + start_e]));
94           CeedCallBackend(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
95           CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr, &curl_orients));
96         } break;
97         case CEED_RESTRICTION_STRIDED: {
98           CeedInt strides[3];
99 
100           CeedCallBackend(CeedElemRestrictionGetStrides(rstr, strides));
101           CeedCallBackend(CeedElemRestrictionCreateBlockedStrided(ceed_rstr, num_elem, elem_size, block_size, num_comp, l_size, strides,
102                                                                   &block_rstr[i + start_e]));
103         } break;
104         case CEED_RESTRICTION_POINTS:
105           // Empty case - won't occur
106           break;
107       }
108       CeedCallBackend(CeedElemRestrictionCreateVector(block_rstr[i + start_e], NULL, &e_vecs_full[i + start_e]));
109     }
110 
111     switch (eval_mode) {
112       case CEED_EVAL_NONE:
113         CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
114         q_size = (CeedSize)Q * size * block_size;
115         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
116         break;
117       case CEED_EVAL_INTERP:
118       case CEED_EVAL_GRAD:
119       case CEED_EVAL_DIV:
120       case CEED_EVAL_CURL:
121         CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
122         CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
123         CeedCallBackend(CeedBasisGetNumNodes(basis, &P));
124         CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
125         e_size = (CeedSize)P * num_comp * block_size;
126         CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i]));
127         q_size = (CeedSize)Q * size * block_size;
128         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
129         break;
130       case CEED_EVAL_WEIGHT:  // Only on input fields
131         CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
132         q_size = (CeedSize)Q * block_size;
133         CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
134         CeedCallBackend(CeedBasisApply(basis, block_size, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]));
135         break;
136     }
137   }
138   // Drop duplicate input restrictions
139   if (is_input) {
140     for (CeedInt i = 0; i < num_fields; i++) {
141       CeedVector          vec_i;
142       CeedElemRestriction rstr_i;
143 
144       CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
145       CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
146       for (CeedInt j = i + 1; j < num_fields; j++) {
147         CeedVector          vec_j;
148         CeedElemRestriction rstr_j;
149 
150         CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
151         CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
152         if (vec_i == vec_j && rstr_i == rstr_j) {
153           CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
154           CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i], &e_vecs_full[j]));
155           skip_rstr[j] = true;
156         }
157       }
158     }
159   }
160   return CEED_ERROR_SUCCESS;
161 }
162 
163 //------------------------------------------------------------------------------
164 // Setup Operator
165 //------------------------------------------------------------------------------
166 static int CeedOperatorSetup_Blocked(CeedOperator op) {
167   bool                  is_setup_done;
168   Ceed                  ceed;
169   CeedInt               Q, num_input_fields, num_output_fields;
170   const CeedInt         block_size = 8;
171   CeedQFunctionField   *qf_input_fields, *qf_output_fields;
172   CeedQFunction         qf;
173   CeedOperatorField    *op_input_fields, *op_output_fields;
174   CeedOperator_Blocked *impl;
175 
176   CeedCallBackend(CeedOperatorIsSetupDone(op, &is_setup_done));
177   if (is_setup_done) return CEED_ERROR_SUCCESS;
178 
179   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
180   CeedCallBackend(CeedOperatorGetData(op, &impl));
181   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
182   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
183   CeedCallBackend(CeedQFunctionIsIdentity(qf, &impl->is_identity_qf));
184   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
185   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
186 
187   // Allocate
188   CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->block_rstr));
189   CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full));
190 
191   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in));
192   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
193   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in));
194   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out));
195   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in));
196   CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out));
197 
198   impl->num_inputs  = num_input_fields;
199   impl->num_outputs = num_output_fields;
200 
201   // Set up infield and outfield pointer arrays
202   // Infields
203   CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, true, impl->skip_rstr_in, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_in,
204                                                   impl->q_vecs_in, 0, num_input_fields, Q));
205   // Outfields
206   CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, false, NULL, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_out,
207                                                   impl->q_vecs_out, num_input_fields, num_output_fields, Q));
208 
209   // Identity QFunctions
210   if (impl->is_identity_qf) {
211     CeedEvalMode        in_mode, out_mode;
212     CeedQFunctionField *in_fields, *out_fields;
213 
214     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &in_fields, NULL, &out_fields));
215     CeedCallBackend(CeedQFunctionFieldGetEvalMode(in_fields[0], &in_mode));
216     CeedCallBackend(CeedQFunctionFieldGetEvalMode(out_fields[0], &out_mode));
217 
218     if (in_mode == CEED_EVAL_NONE && out_mode == CEED_EVAL_NONE) {
219       impl->is_identity_rstr_op = true;
220     } else {
221       CeedCallBackend(CeedVectorReferenceCopy(impl->q_vecs_in[0], &impl->q_vecs_out[0]));
222     }
223   }
224 
225   CeedCallBackend(CeedOperatorSetSetupDone(op));
226   return CEED_ERROR_SUCCESS;
227 }
228 
229 //------------------------------------------------------------------------------
230 // Setup Operator Inputs
231 //------------------------------------------------------------------------------
232 static inline int CeedOperatorSetupInputs_Blocked(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
233                                                   CeedVector in_vec, bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX],
234                                                   CeedOperator_Blocked *impl, CeedRequest *request) {
235   for (CeedInt i = 0; i < num_input_fields; i++) {
236     uint64_t     state;
237     CeedEvalMode eval_mode;
238     CeedVector   vec;
239 
240     // Get input vector
241     CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
242     if (vec == CEED_VECTOR_ACTIVE) {
243       if (skip_active) continue;
244       else vec = in_vec;
245     }
246 
247     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
248     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
249     } else {
250       // Restrict
251       CeedCallBackend(CeedVectorGetState(vec, &state));
252       if ((state != impl->input_states[i] || vec == in_vec) && !impl->skip_rstr_in[i]) {
253         CeedCallBackend(CeedElemRestrictionApply(impl->block_rstr[i], CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request));
254       }
255       impl->input_states[i] = state;
256       // Get evec
257       CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data_full[i]));
258     }
259   }
260   return CEED_ERROR_SUCCESS;
261 }
262 
263 //------------------------------------------------------------------------------
264 // Input Basis Action
265 //------------------------------------------------------------------------------
266 static inline int CeedOperatorInputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
267                                                  CeedInt num_input_fields, CeedInt block_size, bool skip_active,
268                                                  CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) {
269   for (CeedInt i = 0; i < num_input_fields; i++) {
270     CeedInt             elem_size, size, num_comp;
271     CeedEvalMode        eval_mode;
272     CeedElemRestriction elem_rstr;
273     CeedBasis           basis;
274 
275     // Skip active input
276     if (skip_active) {
277       CeedVector vec;
278 
279       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
280       if (vec == CEED_VECTOR_ACTIVE) continue;
281     }
282 
283     // Get elem_size, eval_mode, size
284     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
285     CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
286     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
287     CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size));
288     // Basis action
289     switch (eval_mode) {
290       case CEED_EVAL_NONE:
291         CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][(CeedSize)e * Q * size]));
292         break;
293       case CEED_EVAL_INTERP:
294       case CEED_EVAL_GRAD:
295       case CEED_EVAL_DIV:
296       case CEED_EVAL_CURL:
297         CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
298         CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
299         CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][(CeedSize)e * elem_size * num_comp]));
300         CeedCallBackend(CeedBasisApply(basis, block_size, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs_in[i], impl->q_vecs_in[i]));
301         break;
302       case CEED_EVAL_WEIGHT:
303         break;  // No action
304     }
305   }
306   return CEED_ERROR_SUCCESS;
307 }
308 
309 //------------------------------------------------------------------------------
310 // Output Basis Action
311 //------------------------------------------------------------------------------
312 static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields,
313                                                   CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op,
314                                                   CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) {
315   for (CeedInt i = 0; i < num_output_fields; i++) {
316     CeedInt             elem_size, num_comp;
317     CeedEvalMode        eval_mode;
318     CeedElemRestriction elem_rstr;
319     CeedBasis           basis;
320 
321     // Get elem_size, eval_mode, size
322     CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
323     CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
324     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
325     // Basis action
326     switch (eval_mode) {
327       case CEED_EVAL_NONE:
328         break;  // No action
329       case CEED_EVAL_INTERP:
330       case CEED_EVAL_GRAD:
331       case CEED_EVAL_DIV:
332       case CEED_EVAL_CURL:
333         CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
334         CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
335         CeedCallBackend(CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER,
336                                            &e_data_full[i + num_input_fields][(CeedSize)e * elem_size * num_comp]));
337         CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
338         break;
339       // LCOV_EXCL_START
340       case CEED_EVAL_WEIGHT: {
341         return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
342         // LCOV_EXCL_STOP
343       }
344     }
345   }
346   return CEED_ERROR_SUCCESS;
347 }
348 
349 //------------------------------------------------------------------------------
350 // Restore Input Vectors
351 //------------------------------------------------------------------------------
352 static inline int CeedOperatorRestoreInputs_Blocked(CeedInt num_input_fields, CeedQFunctionField *qf_input_fields, CeedOperatorField *op_input_fields,
353                                                     bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) {
354   for (CeedInt i = 0; i < num_input_fields; i++) {
355     CeedEvalMode eval_mode;
356 
357     // Skip active inputs
358     if (skip_active) {
359       CeedVector vec;
360 
361       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
362       if (vec == CEED_VECTOR_ACTIVE) continue;
363     }
364     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
365     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
366     } else {
367       CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_full[i], (const CeedScalar **)&e_data_full[i]));
368     }
369   }
370   return CEED_ERROR_SUCCESS;
371 }
372 
373 //------------------------------------------------------------------------------
374 // Operator Apply
375 //------------------------------------------------------------------------------
376 static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) {
377   CeedInt               Q, num_input_fields, num_output_fields, num_elem, size;
378   const CeedInt         block_size = 8;
379   CeedEvalMode          eval_mode;
380   CeedScalar           *e_data_full[2 * CEED_FIELD_MAX] = {0};
381   CeedQFunctionField   *qf_input_fields, *qf_output_fields;
382   CeedQFunction         qf;
383   CeedOperatorField    *op_input_fields, *op_output_fields;
384   CeedOperator_Blocked *impl;
385 
386   CeedCallBackend(CeedOperatorGetData(op, &impl));
387   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
388   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
389   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
390   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
391   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
392   const CeedInt num_blocks = (num_elem / block_size) + !!(num_elem % block_size);
393 
394   // Setup
395   CeedCallBackend(CeedOperatorSetup_Blocked(op));
396 
397   // Restriction only operator
398   if (impl->is_identity_rstr_op) {
399     CeedCallBackend(CeedElemRestrictionApply(impl->block_rstr[0], CEED_NOTRANSPOSE, in_vec, impl->e_vecs_full[0], request));
400     CeedCallBackend(CeedElemRestrictionApply(impl->block_rstr[1], CEED_TRANSPOSE, impl->e_vecs_full[0], out_vec, request));
401     return CEED_ERROR_SUCCESS;
402   }
403 
404   // Input Evecs and Restriction
405   CeedCallBackend(CeedOperatorSetupInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data_full, impl, request));
406 
407   // Output Evecs
408   for (CeedInt i = 0; i < num_output_fields; i++) {
409     CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields]));
410   }
411 
412   // Loop through elements
413   for (CeedInt e = 0; e < num_blocks * block_size; e += block_size) {
414     // Output pointers
415     for (CeedInt i = 0; i < num_output_fields; i++) {
416       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
417       if (eval_mode == CEED_EVAL_NONE) {
418         CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size));
419         CeedCallBackend(
420             CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][(CeedSize)e * Q * size]));
421       }
422     }
423 
424     // Input basis apply
425     CeedCallBackend(CeedOperatorInputBasis_Blocked(e, Q, qf_input_fields, op_input_fields, num_input_fields, block_size, false, e_data_full, impl));
426 
427     // Q function
428     if (!impl->is_identity_qf) {
429       CeedCallBackend(CeedQFunctionApply(qf, Q * block_size, impl->q_vecs_in, impl->q_vecs_out));
430     }
431 
432     // Output basis apply
433     CeedCallBackend(CeedOperatorOutputBasis_Blocked(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields, op,
434                                                     e_data_full, impl));
435   }
436 
437   // Output restriction
438   for (CeedInt i = 0; i < num_output_fields; i++) {
439     CeedVector vec;
440 
441     // Restore evec
442     CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields]));
443     // Get output vector
444     CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
445     // Active
446     if (vec == CEED_VECTOR_ACTIVE) vec = out_vec;
447     // Restrict
448     CeedCallBackend(
449         CeedElemRestrictionApply(impl->block_rstr[i + impl->num_inputs], CEED_TRANSPOSE, impl->e_vecs_full[i + impl->num_inputs], vec, request));
450   }
451 
452   // Restore input arrays
453   CeedCallBackend(CeedOperatorRestoreInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, false, e_data_full, impl));
454   return CEED_ERROR_SUCCESS;
455 }
456 
457 //------------------------------------------------------------------------------
458 // Core code for assembling linear QFunction
459 //------------------------------------------------------------------------------
460 static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator op, bool build_objects, CeedVector *assembled,
461                                                                   CeedElemRestriction *rstr, CeedRequest *request) {
462   Ceed                  ceed;
463   CeedInt               qf_size_in, qf_size_out, Q, num_input_fields, num_output_fields, num_elem;
464   const CeedInt         block_size = 8;
465   CeedScalar           *l_vec_array;
466   CeedScalar           *e_data_full[2 * CEED_FIELD_MAX] = {0};
467   CeedQFunctionField   *qf_input_fields, *qf_output_fields;
468   CeedQFunction         qf;
469   CeedOperatorField    *op_input_fields, *op_output_fields;
470   CeedOperator_Blocked *impl;
471 
472   CeedCallBackend(CeedOperatorGetData(op, &impl));
473   qf_size_in                     = impl->qf_size_in;
474   qf_size_out                    = impl->qf_size_out;
475   CeedVector          l_vec      = impl->qf_l_vec;
476   CeedElemRestriction block_rstr = impl->qf_block_rstr;
477 
478   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
479   CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
480   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
481   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
482   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
483   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
484   const CeedInt num_blocks = (num_elem / block_size) + !!(num_elem % block_size);
485 
486   // Setup
487   CeedCallBackend(CeedOperatorSetup_Blocked(op));
488 
489   // Check for restriction only operator
490   CeedCheck(!impl->is_identity_rstr_op, ceed, CEED_ERROR_BACKEND, "Assembling restriction only operators is not supported");
491 
492   // Input Evecs and Restriction
493   CeedCallBackend(CeedOperatorSetupInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data_full, impl, request));
494 
495   // Count number of active input fields
496   if (qf_size_in == 0) {
497     for (CeedInt i = 0; i < num_input_fields; i++) {
498       CeedInt    field_size;
499       CeedVector vec;
500 
501       // Get input vector
502       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
503       // Check if active input
504       if (vec == CEED_VECTOR_ACTIVE) {
505         CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size));
506         CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
507         qf_size_in += field_size;
508       }
509     }
510     CeedCheck(qf_size_in > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
511     impl->qf_size_in = qf_size_in;
512   }
513 
514   // Count number of active output fields
515   if (qf_size_out == 0) {
516     for (CeedInt i = 0; i < num_output_fields; i++) {
517       CeedInt    field_size;
518       CeedVector vec;
519 
520       // Get output vector
521       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
522       // Check if active output
523       if (vec == CEED_VECTOR_ACTIVE) {
524         CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size));
525         qf_size_out += field_size;
526       }
527     }
528     CeedCheck(qf_size_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
529     impl->qf_size_out = qf_size_out;
530   }
531 
532   // Setup Lvec
533   if (!l_vec) {
534     const CeedSize l_size = (CeedSize)num_blocks * block_size * Q * qf_size_in * qf_size_out;
535 
536     CeedCallBackend(CeedVectorCreate(ceed, l_size, &l_vec));
537     impl->qf_l_vec = l_vec;
538   }
539   CeedCallBackend(CeedVectorGetArrayWrite(l_vec, CEED_MEM_HOST, &l_vec_array));
540 
541   // Setup block restriction
542   if (!block_rstr) {
543     const CeedInt strides[3] = {1, Q, qf_size_in * qf_size_out * Q};
544 
545     CeedCallBackend(CeedElemRestrictionCreateBlockedStrided(ceed, num_elem, Q, block_size, qf_size_in * qf_size_out,
546                                                             qf_size_in * qf_size_out * num_elem * Q, strides, &block_rstr));
547     impl->qf_block_rstr = block_rstr;
548   }
549 
550   // Build objects if needed
551   if (build_objects) {
552     const CeedSize l_size     = (CeedSize)num_elem * Q * qf_size_in * qf_size_out;
553     const CeedInt  strides[3] = {1, Q, qf_size_in * qf_size_out * Q};
554 
555     // Create output restriction
556     CeedCallBackend(CeedElemRestrictionCreateStrided(ceed, num_elem, Q, qf_size_in * qf_size_out,
557                                                      (CeedSize)qf_size_in * (CeedSize)qf_size_out * (CeedSize)num_elem * (CeedSize)Q, strides, rstr));
558     // Create assembled vector
559     CeedCallBackend(CeedVectorCreate(ceed, l_size, assembled));
560   }
561 
562   // Loop through elements
563   for (CeedInt e = 0; e < num_blocks * block_size; e += block_size) {
564     // Input basis apply
565     CeedCallBackend(CeedOperatorInputBasis_Blocked(e, Q, qf_input_fields, op_input_fields, num_input_fields, block_size, true, e_data_full, impl));
566 
567     // Assemble QFunction
568     for (CeedInt i = 0; i < num_input_fields; i++) {
569       CeedInt    field_size;
570       CeedVector vec;
571 
572       // Get input vector
573       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
574       // Check if active input
575       if (vec != CEED_VECTOR_ACTIVE) continue;
576       CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size));
577       for (CeedInt field = 0; field < field_size; field++) {
578         // Set current portion of input to 1.0
579         {
580           CeedScalar *array;
581 
582           CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &array));
583           for (CeedInt j = 0; j < Q * block_size; j++) array[field * Q * block_size + j] = 1.0;
584           CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &array));
585         }
586 
587         if (!impl->is_identity_qf) {
588           // Set Outputs
589           for (CeedInt out = 0; out < num_output_fields; out++) {
590             CeedInt    field_size;
591             CeedVector vec;
592 
593             // Get output vector
594             CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
595             // Check if active output
596             if (vec == CEED_VECTOR_ACTIVE) {
597               CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[out], CEED_MEM_HOST, CEED_USE_POINTER, l_vec_array));
598               CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &field_size));
599               l_vec_array += field_size * Q * block_size;  // Advance the pointer by the size of the output
600             }
601           }
602           // Apply QFunction
603           CeedCallBackend(CeedQFunctionApply(qf, Q * block_size, impl->q_vecs_in, impl->q_vecs_out));
604         } else {
605           CeedInt           field_size;
606           const CeedScalar *array;
607 
608           // Copy Identity Outputs
609           CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[0], &field_size));
610           CeedCallBackend(CeedVectorGetArrayRead(impl->q_vecs_out[0], CEED_MEM_HOST, &array));
611           for (CeedInt j = 0; j < field_size * Q * block_size; j++) l_vec_array[j] = array[j];
612           CeedCallBackend(CeedVectorRestoreArrayRead(impl->q_vecs_out[0], &array));
613           l_vec_array += field_size * Q * block_size;
614         }
615         // Reset input to 0.0
616         {
617           CeedScalar *array;
618 
619           CeedCallBackend(CeedVectorGetArray(impl->q_vecs_in[i], CEED_MEM_HOST, &array));
620           for (CeedInt j = 0; j < Q * block_size; j++) array[field * Q * block_size + j] = 0.0;
621           CeedCallBackend(CeedVectorRestoreArray(impl->q_vecs_in[i], &array));
622         }
623       }
624     }
625   }
626 
627   // Un-set output Qvecs to prevent accidental overwrite of Assembled
628   if (!impl->is_identity_qf) {
629     for (CeedInt out = 0; out < num_output_fields; out++) {
630       CeedVector vec;
631 
632       // Get output vector
633       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
634       // Check if active output
635       if (vec == CEED_VECTOR_ACTIVE) {
636         CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL));
637       }
638     }
639   }
640 
641   // Restore input arrays
642   CeedCallBackend(CeedOperatorRestoreInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, true, e_data_full, impl));
643 
644   // Output blocked restriction
645   CeedCallBackend(CeedVectorRestoreArray(l_vec, &l_vec_array));
646   CeedCallBackend(CeedVectorSetValue(*assembled, 0.0));
647   CeedCallBackend(CeedElemRestrictionApply(block_rstr, CEED_TRANSPOSE, l_vec, *assembled, request));
648   return CEED_ERROR_SUCCESS;
649 }
650 
651 //------------------------------------------------------------------------------
652 // Assemble Linear QFunction
653 //------------------------------------------------------------------------------
654 static int CeedOperatorLinearAssembleQFunction_Blocked(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
655   return CeedOperatorLinearAssembleQFunctionCore_Blocked(op, true, assembled, rstr, request);
656 }
657 
658 //------------------------------------------------------------------------------
659 // Update Assembled Linear QFunction
660 //------------------------------------------------------------------------------
661 static int CeedOperatorLinearAssembleQFunctionUpdate_Blocked(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
662   return CeedOperatorLinearAssembleQFunctionCore_Blocked(op, false, &assembled, &rstr, request);
663 }
664 
665 //------------------------------------------------------------------------------
666 // Operator Destroy
667 //------------------------------------------------------------------------------
668 static int CeedOperatorDestroy_Blocked(CeedOperator op) {
669   CeedOperator_Blocked *impl;
670 
671   CeedCallBackend(CeedOperatorGetData(op, &impl));
672 
673   CeedCallBackend(CeedFree(&impl->skip_rstr_in));
674   for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) {
675     CeedCallBackend(CeedElemRestrictionDestroy(&impl->block_rstr[i]));
676     CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i]));
677   }
678   CeedCallBackend(CeedFree(&impl->block_rstr));
679   CeedCallBackend(CeedFree(&impl->e_vecs_full));
680   CeedCallBackend(CeedFree(&impl->input_states));
681 
682   for (CeedInt i = 0; i < impl->num_inputs; i++) {
683     CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i]));
684     CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_in[i]));
685   }
686   CeedCallBackend(CeedFree(&impl->e_vecs_in));
687   CeedCallBackend(CeedFree(&impl->q_vecs_in));
688 
689   for (CeedInt i = 0; i < impl->num_outputs; i++) {
690     CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i]));
691     CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i]));
692   }
693   CeedCallBackend(CeedFree(&impl->e_vecs_out));
694   CeedCallBackend(CeedFree(&impl->q_vecs_out));
695 
696   // QFunction assembly data
697   CeedCallBackend(CeedVectorDestroy(&impl->qf_l_vec));
698   CeedCallBackend(CeedElemRestrictionDestroy(&impl->qf_block_rstr));
699 
700   CeedCallBackend(CeedFree(&impl));
701   return CEED_ERROR_SUCCESS;
702 }
703 
704 //------------------------------------------------------------------------------
705 // Operator Create
706 //------------------------------------------------------------------------------
707 int CeedOperatorCreate_Blocked(CeedOperator op) {
708   Ceed                  ceed;
709   CeedOperator_Blocked *impl;
710 
711   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
712   CeedCallBackend(CeedCalloc(1, &impl));
713   CeedCallBackend(CeedOperatorSetData(op, impl));
714   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Blocked));
715   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Blocked));
716   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Blocked));
717   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Blocked));
718   return CEED_ERROR_SUCCESS;
719 }
720 
721 //------------------------------------------------------------------------------
722