xref: /libCEED/backends/hip-gen/ceed-hip-gen-operator.c (revision 0183ed61035d97ff853cf8c8e722c0fda76e54df)
1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-source/hip/hip-types.h>
11 #include <stddef.h>
12 #include <hip/hiprtc.h>
13 
14 #include "../hip/ceed-hip-common.h"
15 #include "../hip/ceed-hip-compile.h"
16 #include "ceed-hip-gen-operator-build.h"
17 #include "ceed-hip-gen.h"
18 
19 //------------------------------------------------------------------------------
20 // Destroy operator
21 //------------------------------------------------------------------------------
22 static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
23   Ceed                  ceed;
24   CeedOperator_Hip_gen *impl;
25   bool                  is_composite;
26 
27   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
28   CeedCallBackend(CeedOperatorGetData(op, &impl));
29   CeedCallBackend(CeedOperatorIsComposite(op, &is_composite));
30   if (is_composite) {
31     CeedInt num_suboperators;
32 
33     CeedCall(CeedCompositeOperatorGetNumSub(op, &num_suboperators));
34     for (CeedInt i = 0; i < num_suboperators; i++) {
35       if (impl->streams[i]) CeedCallHip(ceed, hipStreamDestroy(impl->streams[i]));
36       impl->streams[i] = NULL;
37     }
38   }
39   if (impl->module) CeedCallHip(ceed, hipModuleUnload(impl->module));
40   if (impl->module_assemble_full) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_full));
41   if (impl->module_assemble_diagonal) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_diagonal));
42   if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)impl->points.num_per_elem));
43   CeedCallBackend(CeedFree(&impl));
44   CeedCallBackend(CeedDestroy(&ceed));
45   return CEED_ERROR_SUCCESS;
46 }
47 
48 //------------------------------------------------------------------------------
49 // Apply and add to output
50 //------------------------------------------------------------------------------
51 static int CeedOperatorApplyAddCore_Hip_gen(CeedOperator op, hipStream_t stream, const CeedScalar *input_arr, CeedScalar *output_arr,
52                                             bool *is_run_good, CeedRequest *request) {
53   bool                   is_at_points, is_tensor;
54   Ceed                   ceed;
55   CeedInt                num_elem, num_input_fields, num_output_fields;
56   CeedEvalMode           eval_mode;
57   CeedQFunctionField    *qf_input_fields, *qf_output_fields;
58   CeedQFunction_Hip_gen *qf_data;
59   CeedQFunction          qf;
60   CeedOperatorField     *op_input_fields, *op_output_fields;
61   CeedOperator_Hip_gen  *data;
62 
63   // Creation of the operator
64   CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, is_run_good));
65   if (!(*is_run_good)) return CEED_ERROR_SUCCESS;
66 
67   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
68   CeedCallBackend(CeedOperatorGetData(op, &data));
69   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
70   CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
71   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
72   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
73   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
74 
75   // Input vectors
76   for (CeedInt i = 0; i < num_input_fields; i++) {
77     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
78     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
79       data->fields.inputs[i] = NULL;
80     } else {
81       bool       is_active;
82       CeedVector vec;
83 
84       // Get input vector
85       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
86       is_active = vec == CEED_VECTOR_ACTIVE;
87       if (is_active) data->fields.inputs[i] = input_arr;
88       else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
89       CeedCallBackend(CeedVectorDestroy(&vec));
90     }
91   }
92 
93   // Output vectors
94   for (CeedInt i = 0; i < num_output_fields; i++) {
95     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
96     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
97       data->fields.outputs[i] = NULL;
98     } else {
99       bool       is_active;
100       CeedVector vec;
101 
102       // Get output vector
103       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
104       is_active = vec == CEED_VECTOR_ACTIVE;
105       if (is_active) data->fields.outputs[i] = output_arr;
106       else CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i]));
107       CeedCallBackend(CeedVectorDestroy(&vec));
108     }
109   }
110 
111   // Point coordinates, if needed
112   CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points));
113   if (is_at_points) {
114     // Coords
115     CeedVector vec;
116 
117     CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
118     CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
119     CeedCallBackend(CeedVectorDestroy(&vec));
120 
121     // Points per elem
122     if (num_elem != data->points.num_elem) {
123       CeedInt            *points_per_elem;
124       const CeedInt       num_bytes   = num_elem * sizeof(CeedInt);
125       CeedElemRestriction rstr_points = NULL;
126 
127       data->points.num_elem = num_elem;
128       CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
129       CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
130       for (CeedInt e = 0; e < num_elem; e++) {
131         CeedInt num_points_elem;
132 
133         CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
134         points_per_elem[e] = num_points_elem;
135       }
136       if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem));
137       CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
138       CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
139       CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
140       CeedCallBackend(CeedFree(&points_per_elem));
141     }
142   }
143 
144   // Get context data
145   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
146 
147   // Apply operator
148   void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points};
149 
150   CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor));
151   CeedInt block_sizes[3] = {data->thread_1d, ((!is_tensor || data->dim == 1) ? 1 : data->thread_1d), -1};
152 
153   if (is_tensor) {
154     CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
155     if (is_at_points) block_sizes[2] = 1;
156   } else {
157     CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64;
158 
159     elems_per_block = elems_per_block > 0 ? elems_per_block : 1;
160     block_sizes[2]  = elems_per_block;
161   }
162   if (data->dim == 1 || !is_tensor) {
163     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
164     CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar);
165 
166     CeedCallBackend(
167         CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs));
168   } else if (data->dim == 2) {
169     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
170     CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
171 
172     CeedCallBackend(
173         CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs));
174   } else if (data->dim == 3) {
175     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
176     CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
177 
178     CeedCallBackend(
179         CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs));
180   }
181 
182   // Restore input arrays
183   for (CeedInt i = 0; i < num_input_fields; i++) {
184     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
185     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
186     } else {
187       bool       is_active;
188       CeedVector vec;
189 
190       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
191       is_active = vec == CEED_VECTOR_ACTIVE;
192       if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
193       CeedCallBackend(CeedVectorDestroy(&vec));
194     }
195   }
196 
197   // Restore output arrays
198   for (CeedInt i = 0; i < num_output_fields; i++) {
199     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
200     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
201     } else {
202       bool       is_active;
203       CeedVector vec;
204 
205       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
206       is_active = vec == CEED_VECTOR_ACTIVE;
207       if (!is_active) CeedCallBackend(CeedVectorRestoreArray(vec, &data->fields.outputs[i]));
208       CeedCallBackend(CeedVectorDestroy(&vec));
209     }
210   }
211 
212   // Restore point coordinates, if needed
213   if (is_at_points) {
214     CeedVector vec;
215 
216     CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
217     CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
218     CeedCallBackend(CeedVectorDestroy(&vec));
219   }
220 
221   // Restore context data
222   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
223 
224   // Cleanup
225   CeedCallBackend(CeedDestroy(&ceed));
226   CeedCallBackend(CeedQFunctionDestroy(&qf));
227   if (!(*is_run_good)) data->use_fallback = true;
228   return CEED_ERROR_SUCCESS;
229 }
230 
231 static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
232   bool              is_run_good = false;
233   const CeedScalar *input_arr   = NULL;
234   CeedScalar       *output_arr  = NULL;
235 
236   // Try to run kernel
237   if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr));
238   if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr));
239   CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(op, NULL, input_arr, output_arr, &is_run_good, request));
240   if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr));
241   if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr));
242 
243   // Fallback on unsuccessful run
244   if (!is_run_good) {
245     CeedOperator op_fallback;
246 
247     CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator");
248     CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
249     CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
250   }
251   return CEED_ERROR_SUCCESS;
252 }
253 
254 static int CeedOperatorApplyAddComposite_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
255   bool                  is_run_good[CEED_COMPOSITE_MAX] = {true};
256   CeedInt               num_suboperators;
257   const CeedScalar     *input_arr = NULL;
258   CeedScalar           *output_arr;
259   Ceed                  ceed;
260   CeedOperator_Hip_gen *impl;
261   CeedOperator         *sub_operators;
262 
263   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
264   CeedCallBackend(CeedOperatorGetData(op, &impl));
265   CeedCallBackend(CeedCompositeOperatorGetNumSub(op, &num_suboperators));
266   CeedCallBackend(CeedCompositeOperatorGetSubList(op, &sub_operators));
267   if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr));
268   if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr));
269   for (CeedInt i = 0; i < num_suboperators; i++) {
270     CeedInt num_elem = 0;
271 
272     CeedCallBackend(CeedOperatorGetNumElements(sub_operators[i], &num_elem));
273     if (num_elem > 0) {
274       if (!impl->streams[i]) CeedCallHip(ceed, hipStreamCreate(&impl->streams[i]));
275       CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], impl->streams[i], input_arr, output_arr, &is_run_good[i], request));
276     } else {
277       is_run_good[i] = true;
278     }
279   }
280 
281   for (CeedInt i = 0; i < num_suboperators; i++) {
282     if (impl->streams[i]) {
283       if (is_run_good[i]) CeedCallHip(ceed, hipStreamSynchronize(impl->streams[i]));
284     }
285   }
286   if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr));
287   if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr));
288   CeedCallHip(ceed, hipDeviceSynchronize());
289 
290   // Fallback on unsuccessful run
291   for (CeedInt i = 0; i < num_suboperators; i++) {
292     if (!is_run_good[i]) {
293       CeedOperator op_fallback;
294 
295       CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator");
296       CeedCallBackend(CeedOperatorGetFallback(sub_operators[i], &op_fallback));
297       CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
298     }
299   }
300   CeedCallBackend(CeedDestroy(&ceed));
301   return CEED_ERROR_SUCCESS;
302 }
303 
304 //------------------------------------------------------------------------------
305 // AtPoints diagonal assembly
306 //------------------------------------------------------------------------------
307 static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen(CeedOperator op, CeedVector assembled, CeedRequest *request) {
308   Ceed                  ceed;
309   CeedOperator_Hip_gen *data;
310 
311   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
312   CeedCallBackend(CeedOperatorGetData(op, &data));
313 
314   // Build the assembly kernel
315   if (!data->assemble_diagonal && !data->use_assembly_fallback) {
316     bool                     is_build_good = false;
317     CeedInt                  num_active_bases_in, num_active_bases_out;
318     CeedOperatorAssemblyData assembly_data;
319 
320     CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data));
321     CeedCallBackend(
322         CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL, NULL, NULL));
323     if (num_active_bases_in == num_active_bases_out) {
324       CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good));
325       if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Hip_gen(op, &is_build_good));
326     }
327     if (!is_build_good) data->use_assembly_fallback = true;
328   }
329 
330   // Try assembly
331   if (!data->use_assembly_fallback) {
332     bool                   is_run_good = true;
333     Ceed_Hip              *hip_data;
334     CeedInt                num_elem, num_input_fields, num_output_fields;
335     CeedEvalMode           eval_mode;
336     CeedScalar            *assembled_array;
337     CeedQFunctionField    *qf_input_fields, *qf_output_fields;
338     CeedQFunction_Hip_gen *qf_data;
339     CeedQFunction          qf;
340     CeedOperatorField     *op_input_fields, *op_output_fields;
341 
342     CeedCallBackend(CeedGetData(ceed, &hip_data));
343     CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
344     CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
345     CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
346     CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
347     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
348 
349     // Input vectors
350     for (CeedInt i = 0; i < num_input_fields; i++) {
351       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
352       if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
353         data->fields.inputs[i] = NULL;
354       } else {
355         bool       is_active;
356         CeedVector vec;
357 
358         // Get input vector
359         CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
360         is_active = vec == CEED_VECTOR_ACTIVE;
361         if (is_active) data->fields.inputs[i] = NULL;
362         else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
363         CeedCallBackend(CeedVectorDestroy(&vec));
364       }
365     }
366 
367     // Point coordinates
368     {
369       CeedVector vec;
370 
371       CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
372       CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
373       CeedCallBackend(CeedVectorDestroy(&vec));
374 
375       // Points per elem
376       if (num_elem != data->points.num_elem) {
377         CeedInt            *points_per_elem;
378         const CeedInt       num_bytes   = num_elem * sizeof(CeedInt);
379         CeedElemRestriction rstr_points = NULL;
380 
381         data->points.num_elem = num_elem;
382         CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
383         CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
384         for (CeedInt e = 0; e < num_elem; e++) {
385           CeedInt num_points_elem;
386 
387           CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
388           points_per_elem[e] = num_points_elem;
389         }
390         if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem));
391         CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
392         CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
393         CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
394         CeedCallBackend(CeedFree(&points_per_elem));
395       }
396     }
397 
398     // Get context data
399     CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
400 
401     // Assembly array
402     CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array));
403 
404     // Assemble diagonal
405     void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array};
406 
407     CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1};
408 
409     CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
410     block_sizes[2] = 1;
411     if (data->dim == 1) {
412       CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
413       CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar);
414 
415       CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
416                                                     sharedMem, &is_run_good, opargs));
417     } else if (data->dim == 2) {
418       CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
419       CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
420 
421       CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
422                                                     sharedMem, &is_run_good, opargs));
423     } else if (data->dim == 3) {
424       CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
425       CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
426 
427       CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
428                                                     sharedMem, &is_run_good, opargs));
429     }
430 
431     // Restore input arrays
432     for (CeedInt i = 0; i < num_input_fields; i++) {
433       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
434       if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
435       } else {
436         bool       is_active;
437         CeedVector vec;
438 
439         CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
440         is_active = vec == CEED_VECTOR_ACTIVE;
441         if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
442         CeedCallBackend(CeedVectorDestroy(&vec));
443       }
444     }
445 
446     // Restore point coordinates
447     {
448       CeedVector vec;
449 
450       CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
451       CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
452       CeedCallBackend(CeedVectorDestroy(&vec));
453     }
454 
455     // Restore context data
456     CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
457 
458     // Restore assembly array
459     CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array));
460 
461     // Cleanup
462     CeedCallBackend(CeedQFunctionDestroy(&qf));
463     if (!is_run_good) data->use_assembly_fallback = true;
464   }
465   CeedCallBackend(CeedDestroy(&ceed));
466 
467   // Fallback, if needed
468   if (data->use_assembly_fallback) {
469     CeedOperator op_fallback;
470 
471     CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator");
472     CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
473     CeedCallBackend(CeedOperatorLinearAssembleAddDiagonal(op_fallback, assembled, request));
474     return CEED_ERROR_SUCCESS;
475   }
476   return CEED_ERROR_SUCCESS;
477 }
478 
479 //------------------------------------------------------------------------------
480 // Create operator
481 //------------------------------------------------------------------------------
482 int CeedOperatorCreate_Hip_gen(CeedOperator op) {
483   bool                  is_composite, is_at_points;
484   Ceed                  ceed;
485   CeedOperator_Hip_gen *impl;
486 
487   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
488   CeedCallBackend(CeedCalloc(1, &impl));
489   CeedCallBackend(CeedOperatorSetData(op, impl));
490   CeedCall(CeedOperatorIsComposite(op, &is_composite));
491   if (is_composite) {
492     CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAddComposite", CeedOperatorApplyAddComposite_Hip_gen));
493   } else {
494     CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen));
495   }
496   CeedCall(CeedOperatorIsAtPoints(op, &is_at_points));
497   if (is_at_points) {
498     CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen));
499   }
500   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen));
501   CeedCallBackend(CeedDestroy(&ceed));
502   return CEED_ERROR_SUCCESS;
503 }
504 
505 //------------------------------------------------------------------------------
506