xref: /libCEED/backends/hip-gen/ceed-hip-gen-operator.c (revision a49e5d53e180225109bfad71df325c7cfa170c69)
// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
77d8d0e25Snbeams 
849aac155SJeremy L Thompson #include <ceed.h>
9ec3da8bcSJed Brown #include <ceed/backend.h>
1049aac155SJeremy L Thompson #include <ceed/jit-source/hip/hip-types.h>
113d576824SJeremy L Thompson #include <stddef.h>
123a2968d6SJeremy L Thompson #include <hip/hiprtc.h>
132b730f8bSJeremy L Thompson 
14b2165e7aSSebastian Grimberg #include "../hip/ceed-hip-common.h"
157d8d0e25Snbeams #include "../hip/ceed-hip-compile.h"
162b730f8bSJeremy L Thompson #include "ceed-hip-gen-operator-build.h"
172b730f8bSJeremy L Thompson #include "ceed-hip-gen.h"
187d8d0e25Snbeams 
//------------------------------------------------------------------------------
// Destroy operator
//------------------------------------------------------------------------------
// Release the HIP resources held in the operator's backend data and then free
// the backend data itself.
//
//   op - operator whose backend data is torn down
//
// Returns CEED_ERROR_SUCCESS, or the error reported by the first failing libCEED
// or HIP call.
static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
  Ceed                  ceed;
  CeedOperator_Hip_gen *impl;
  bool                  is_composite;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedCallBackend(CeedOperatorIsComposite(op, &is_composite));

  // Composite operators keep a stream array indexed by sub-operator; slots that
  // were never created are still NULL and are skipped
  if (is_composite) {
    CeedInt num_suboperators;

    CeedCallBackend(CeedOperatorCompositeGetNumSub(op, &num_suboperators));
    for (CeedInt i = 0; i < num_suboperators; i++) {
      if (impl->streams[i]) CeedCallHip(ceed, hipStreamDestroy(impl->streams[i]));
      impl->streams[i] = NULL;
    }
  }

  // Unload each module that was loaded
  if (impl->module) CeedCallHip(ceed, hipModuleUnload(impl->module));
  if (impl->module_assemble_full) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_full));
  if (impl->module_assemble_diagonal) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_diagonal));
  if (impl->module_assemble_qfunction) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_qfunction));

  // hipFree takes the device pointer value itself, so cast to void *, not void **
  if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void *)impl->points.num_per_elem));

  CeedCallBackend(CeedFree(&impl));
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
487d8d0e25Snbeams 
//------------------------------------------------------------------------------
// Apply and add to output
//------------------------------------------------------------------------------
// Build (if needed) and launch the generated operator kernel for one operator.
//
//   op          - operator to apply
//   stream      - HIP stream handed to the kernel launch
//   input_arr   - array bound to every active input field
//   output_arr  - array bound to every active output field
//   is_run_good - written by the kernel build and launch helpers; when it comes
//                 back false this function still returns success and the caller
//                 decides how to recover
//   request     - not used in this function
//
// Returns CEED_ERROR_SUCCESS, or the error reported by the first failing libCEED
// or HIP call.
static int CeedOperatorApplyAddCore_Hip_gen(CeedOperator op, hipStream_t stream, const CeedScalar *input_arr, CeedScalar *output_arr,
                                            bool *is_run_good, CeedRequest *request) {
  bool                   is_at_points, is_tensor;
  Ceed                   ceed;
  CeedInt                num_elem, num_input_fields, num_output_fields;
  CeedEvalMode           eval_mode;
  CeedQFunctionField    *qf_input_fields, *qf_output_fields;
  CeedQFunction_Hip_gen *qf_data;
  CeedQFunction          qf;
  CeedOperatorField     *op_input_fields, *op_output_fields;
  CeedOperator_Hip_gen  *data;

  // Creation of the operator; nothing has been acquired yet, so bail out directly
  CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, is_run_good));
  if (!(*is_run_good)) return CEED_ERROR_SUCCESS;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &data));
  CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
  CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
  CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
  CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
  CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));

  // Input vectors
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedVector vec;

    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
    if (eval_mode == CEED_EVAL_WEIGHT) {
      // Weight fields have no vector to bind
      data->fields.inputs[i] = NULL;
      continue;
    }
    CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
    // Active fields use the caller's array, passive fields use their own vector
    if (vec == CEED_VECTOR_ACTIVE) data->fields.inputs[i] = input_arr;
    else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
    CeedCallBackend(CeedVectorDestroy(&vec));
  }

  // Output vectors
  for (CeedInt i = 0; i < num_output_fields; i++) {
    CeedVector vec;

    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
    if (eval_mode == CEED_EVAL_WEIGHT) {
      data->fields.outputs[i] = NULL;
      continue;
    }
    CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
    if (vec == CEED_VECTOR_ACTIVE) data->fields.outputs[i] = output_arr;
    else CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i]));
    CeedCallBackend(CeedVectorDestroy(&vec));
  }

  // Point coordinates, if needed
  CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points));
  if (is_at_points) {
    // Coords
    CeedVector vec;

    CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
    CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
    CeedCallBackend(CeedVectorDestroy(&vec));

    // Points per elem; rebuilt only when the element count differs from the cached one
    if (num_elem != data->points.num_elem) {
      CeedInt            *points_per_elem;
      CeedInt            *d_points_per_elem = NULL;
      const size_t        num_bytes         = (size_t)num_elem * sizeof(CeedInt);
      CeedElemRestriction rstr_points       = NULL;

      CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
      CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
      for (CeedInt e = 0; e < num_elem; e++) {
        CeedInt num_points_elem;

        CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
        points_per_elem[e] = num_points_elem;
      }
      // Fill the new device buffer first and only then replace the cached one, so a
      // failed allocation or copy never leaves a freed pointer marked as current
      CeedCallHip(ceed, hipMalloc((void **)&d_points_per_elem, num_bytes));
      CeedCallHip(ceed, hipMemcpy(d_points_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
      if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void *)data->points.num_per_elem));
      data->points.num_per_elem = d_points_per_elem;
      data->points.num_elem     = num_elem;
      CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
      CeedCallBackend(CeedFree(&points_per_elem));
    }
  }

  // Get context data
  CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));

  // Apply operator
  void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points};

  CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor));
  // Block dimensions: x is thread_1d, y is thread_1d only for tensor bases with dim > 1,
  // z (elements per block) is filled in below
  CeedInt block_sizes[3] = {data->thread_1d, ((!is_tensor || data->dim == 1) ? 1 : data->thread_1d), -1};

  if (is_tensor) {
    CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
  } else {
    CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64;

    elems_per_block = elems_per_block > 0 ? elems_per_block : 1;
    block_sizes[2]  = elems_per_block;
  }

  // The 1D / non-tensor launch sizes shared memory with thread_1d entries per element,
  // the 2D and 3D tensor launches with thread_1d^2; any other dim launches nothing
  const bool is_single_row = (data->dim == 1 || !is_tensor);

  if (is_single_row || data->dim == 2 || data->dim == 3) {
    const CeedInt threads_per_elem = is_single_row ? data->thread_1d : data->thread_1d * data->thread_1d;
    CeedInt       grid             = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
    CeedInt       shared_mem       = block_sizes[2] * threads_per_elem * sizeof(CeedScalar);

    CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], shared_mem,
                                                  is_run_good, opargs));
  }

  // Restore input arrays
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedVector vec;

    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
    if (eval_mode == CEED_EVAL_WEIGHT) continue;
    CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
    if (vec != CEED_VECTOR_ACTIVE) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
    CeedCallBackend(CeedVectorDestroy(&vec));
  }

  // Restore output arrays
  for (CeedInt i = 0; i < num_output_fields; i++) {
    CeedVector vec;

    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
    if (eval_mode == CEED_EVAL_WEIGHT) continue;
    CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
    if (vec != CEED_VECTOR_ACTIVE) CeedCallBackend(CeedVectorRestoreArray(vec, &data->fields.outputs[i]));
    CeedCallBackend(CeedVectorDestroy(&vec));
  }

  // Restore point coordinates, if needed
  if (is_at_points) {
    CeedVector vec;

    CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
    CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
    CeedCallBackend(CeedVectorDestroy(&vec));
  }

  // Restore context data
  CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));

  // Cleanup
  CeedCallBackend(CeedDestroy(&ceed));
  CeedCallBackend(CeedQFunctionDestroy(&qf));
  // Record a failed launch in the backend data
  if (!(*is_run_good)) data->use_fallback = true;
  return CEED_ERROR_SUCCESS;
}
2308d12f40eSJeremy L Thompson 
// Apply a single operator and add into the output vector.
//
// The generated kernel is tried first on the NULL stream; if that attempt reports
// a bad run, the same apply is forwarded to the operator's fallback.
//
//   op         - operator to apply
//   input_vec  - input vector, or CEED_VECTOR_NONE
//   output_vec - output vector, or CEED_VECTOR_NONE
//   request    - forwarded to the core apply and to the fallback apply
//
// Returns CEED_ERROR_SUCCESS, or the error reported by the first failing call.
static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
  const bool        has_input  = input_vec != CEED_VECTOR_NONE;
  const bool        has_output = output_vec != CEED_VECTOR_NONE;
  bool              ran_ok     = false;
  const CeedScalar *d_input    = NULL;
  CeedScalar       *d_output   = NULL;

  // Borrow the device arrays for the duration of the kernel attempt
  if (has_input) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &d_input));
  if (has_output) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &d_output));

  CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(op, NULL, d_input, d_output, &ran_ok, request));

  if (has_input) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &d_input));
  if (has_output) CeedCallBackend(CeedVectorRestoreArray(output_vec, &d_output));

  // Kernel ran, nothing more to do
  if (ran_ok) return CEED_ERROR_SUCCESS;

  // Otherwise hand the apply over to the fallback operator
  {
    CeedOperator op_fallback;

    CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for ApplyAdd\n");
    CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
    CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
  }
  return CEED_ERROR_SUCCESS;
}
2537d8d0e25Snbeams 
// Apply every sub-operator of a composite operator and add into the output vector.
//
// Each non-empty sub-operator is launched on a stream from the backend data:
// stream 0 for all of them when the composite is sequential, otherwise the stream
// with the sub-operator's own index. Streams are created on first use. Any
// sub-operator whose run was not good is re-applied through its fallback operator
// after the device has been synchronized.
//
//   op         - composite operator to apply
//   input_vec  - input vector, or CEED_VECTOR_NONE
//   output_vec - output vector, or CEED_VECTOR_NONE
//   request    - forwarded to the core apply and to the fallback apply
//
// Returns CEED_ERROR_SUCCESS, or the error reported by the first failing libCEED
// or HIP call.
static int CeedOperatorApplyAddComposite_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
  bool                  is_run_good[CEED_COMPOSITE_MAX] = {false}, is_sequential;
  CeedInt               num_suboperators;
  const CeedScalar     *input_arr  = NULL;
  CeedScalar           *output_arr = NULL;
  Ceed                  ceed;
  CeedOperator_Hip_gen *impl;
  CeedOperator         *sub_operators;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedCallBackend(CeedOperatorCompositeGetNumSub(op, &num_suboperators));
  CeedCallBackend(CeedOperatorCompositeGetSubList(op, &sub_operators));
  CeedCallBackend(CeedOperatorCompositeIsSequential(op, &is_sequential));

  // Try to run the generated kernels
  if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr));
  if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr));
  for (CeedInt i = 0; i < num_suboperators; i++) {
    CeedInt       num_elem     = 0;
    const CeedInt stream_index = is_sequential ? 0 : i;

    CeedCallBackend(CeedOperatorGetNumElements(sub_operators[i], &num_elem));
    if (num_elem > 0) {
      if (!impl->streams[stream_index]) CeedCallHip(ceed, hipStreamCreate(&impl->streams[stream_index]));
      CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], impl->streams[stream_index], input_arr, output_arr, &is_run_good[i],
                                                       request));
    } else {
      // Nothing to launch for an empty sub-operator; count it as a good run
      is_run_good[i] = true;
    }
  }

  // Wait for the launched work
  if (is_sequential) {
    // Stream 0 only exists if at least one sub-operator was non-empty
    if (impl->streams[0]) CeedCallHip(ceed, hipStreamSynchronize(impl->streams[0]));
  } else {
    for (CeedInt i = 0; i < num_suboperators; i++) {
      if (impl->streams[i] && is_run_good[i]) CeedCallHip(ceed, hipStreamSynchronize(impl->streams[i]));
    }
  }
  if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr));
  if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr));
  CeedCallHip(ceed, hipDeviceSynchronize());

  // Fallback on unsuccessful run
  for (CeedInt i = 0; i < num_suboperators; i++) {
    if (!is_run_good[i]) {
      CeedOperator op_fallback;

      CeedDebug(ceed, "\nFalling back to /gpu/hip/ref CeedOperator for ApplyAdd\n");
      CeedCallBackend(CeedOperatorGetFallback(sub_operators[i], &op_fallback));
      CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
    }
  }
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
308c99afcd8SJeremy L Thompson 
//------------------------------------------------------------------------------
// QFunction assembly
//------------------------------------------------------------------------------
CeedOperatorLinearAssembleQFunctionCore_Hip_gen(CeedOperator op,bool build_objects,CeedVector * assembled,CeedElemRestriction * rstr,CeedRequest * request)3125daefc96SJeremy L Thompson static int CeedOperatorLinearAssembleQFunctionCore_Hip_gen(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr,
3135daefc96SJeremy L Thompson                                                            CeedRequest *request) {
3145daefc96SJeremy L Thompson   Ceed                  ceed;
3155daefc96SJeremy L Thompson   CeedOperator_Hip_gen *data;
3165daefc96SJeremy L Thompson 
3175daefc96SJeremy L Thompson   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
3185daefc96SJeremy L Thompson   CeedCallBackend(CeedOperatorGetData(op, &data));
3195daefc96SJeremy L Thompson 
3205daefc96SJeremy L Thompson   // Build the assembly kernel
3215daefc96SJeremy L Thompson   if (!data->assemble_qfunction && !data->use_assembly_fallback) {
3225daefc96SJeremy L Thompson     bool is_build_good = false;
3235daefc96SJeremy L Thompson 
3245daefc96SJeremy L Thompson     CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good));
3255daefc96SJeremy L Thompson     if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelLinearAssembleQFunction_Hip_gen(op, &is_build_good));
3265daefc96SJeremy L Thompson     if (!is_build_good) data->use_assembly_fallback = true;
3275daefc96SJeremy L Thompson   }
3285daefc96SJeremy L Thompson 
3295daefc96SJeremy L Thompson   // Try assembly
3305daefc96SJeremy L Thompson   if (!data->use_assembly_fallback) {
3315daefc96SJeremy L Thompson     bool                   is_run_good = true;
3325daefc96SJeremy L Thompson     Ceed_Hip              *hip_data;
3335daefc96SJeremy L Thompson     CeedInt                num_elem, num_input_fields, num_output_fields;
3345daefc96SJeremy L Thompson     CeedEvalMode           eval_mode;
3355daefc96SJeremy L Thompson     CeedScalar            *assembled_array;
3365daefc96SJeremy L Thompson     CeedQFunctionField    *qf_input_fields, *qf_output_fields;
3375daefc96SJeremy L Thompson     CeedQFunction_Hip_gen *qf_data;
3385daefc96SJeremy L Thompson     CeedQFunction          qf;
3395daefc96SJeremy L Thompson     CeedOperatorField     *op_input_fields, *op_output_fields;
3405daefc96SJeremy L Thompson 
3415daefc96SJeremy L Thompson     CeedCallBackend(CeedGetData(ceed, &hip_data));
3425daefc96SJeremy L Thompson     CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
3435daefc96SJeremy L Thompson     CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
3445daefc96SJeremy L Thompson     CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
3455daefc96SJeremy L Thompson     CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
3465daefc96SJeremy L Thompson     CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
3475daefc96SJeremy L Thompson 
3485daefc96SJeremy L Thompson     // Input vectors
3495daefc96SJeremy L Thompson     for (CeedInt i = 0; i < num_input_fields; i++) {
3505daefc96SJeremy L Thompson       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
3515daefc96SJeremy L Thompson       if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
3525daefc96SJeremy L Thompson         data->fields.inputs[i] = NULL;
3535daefc96SJeremy L Thompson       } else {
3545daefc96SJeremy L Thompson         bool       is_active;
3555daefc96SJeremy L Thompson         CeedVector vec;
3565daefc96SJeremy L Thompson 
3575daefc96SJeremy L Thompson         // Get input vector
3585daefc96SJeremy L Thompson         CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
3595daefc96SJeremy L Thompson         is_active = vec == CEED_VECTOR_ACTIVE;
3605daefc96SJeremy L Thompson         if (is_active) data->fields.inputs[i] = NULL;
3615daefc96SJeremy L Thompson         else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
3625daefc96SJeremy L Thompson         CeedCallBackend(CeedVectorDestroy(&vec));
3635daefc96SJeremy L Thompson       }
3645daefc96SJeremy L Thompson     }
3655daefc96SJeremy L Thompson 
3665daefc96SJeremy L Thompson     // Get context data
3675daefc96SJeremy L Thompson     CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
3685daefc96SJeremy L Thompson 
3695daefc96SJeremy L Thompson     // Build objects if needed
3705daefc96SJeremy L Thompson     if (build_objects) {
3715daefc96SJeremy L Thompson       CeedInt qf_size_in = 0, qf_size_out = 0, Q;
3725daefc96SJeremy L Thompson 
3735daefc96SJeremy L Thompson       // Count number of active input fields
3745daefc96SJeremy L Thompson       {
3755daefc96SJeremy L Thompson         for (CeedInt i = 0; i < num_input_fields; i++) {
3765daefc96SJeremy L Thompson           CeedInt    field_size;
3775daefc96SJeremy L Thompson           CeedVector vec;
3785daefc96SJeremy L Thompson 
3795daefc96SJeremy L Thompson           // Get input vector
3805daefc96SJeremy L Thompson           CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
3815daefc96SJeremy L Thompson           // Check if active input
3825daefc96SJeremy L Thompson           if (vec == CEED_VECTOR_ACTIVE) {
3835daefc96SJeremy L Thompson             CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size));
3845daefc96SJeremy L Thompson             qf_size_in += field_size;
3855daefc96SJeremy L Thompson           }
3865daefc96SJeremy L Thompson           CeedCallBackend(CeedVectorDestroy(&vec));
3875daefc96SJeremy L Thompson         }
3885daefc96SJeremy L Thompson         CeedCheck(qf_size_in > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
3895daefc96SJeremy L Thompson       }
3905daefc96SJeremy L Thompson 
3915daefc96SJeremy L Thompson       // Count number of active output fields
3925daefc96SJeremy L Thompson       {
3935daefc96SJeremy L Thompson         for (CeedInt i = 0; i < num_output_fields; i++) {
3945daefc96SJeremy L Thompson           CeedInt    field_size;
3955daefc96SJeremy L Thompson           CeedVector vec;
3965daefc96SJeremy L Thompson 
3975daefc96SJeremy L Thompson           // Get output vector
3985daefc96SJeremy L Thompson           CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
3995daefc96SJeremy L Thompson           // Check if active output
4005daefc96SJeremy L Thompson           if (vec == CEED_VECTOR_ACTIVE) {
4015daefc96SJeremy L Thompson             CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size));
4025daefc96SJeremy L Thompson             qf_size_out += field_size;
4035daefc96SJeremy L Thompson           }
4045daefc96SJeremy L Thompson           CeedCallBackend(CeedVectorDestroy(&vec));
4055daefc96SJeremy L Thompson         }
4065daefc96SJeremy L Thompson         CeedCheck(qf_size_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
4075daefc96SJeremy L Thompson       }
4085daefc96SJeremy L Thompson       CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
4095daefc96SJeremy L Thompson 
4105daefc96SJeremy L Thompson       // Actually build objects now
4115daefc96SJeremy L Thompson       const CeedSize l_size     = (CeedSize)num_elem * Q * qf_size_in * qf_size_out;
4125daefc96SJeremy L Thompson       CeedInt        strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */
4135daefc96SJeremy L Thompson 
4145daefc96SJeremy L Thompson       // Create output restriction
4155daefc96SJeremy L Thompson       CeedCallBackend(CeedElemRestrictionCreateStrided(ceed, num_elem, Q, qf_size_in * qf_size_out,
4165daefc96SJeremy L Thompson                                                        (CeedSize)qf_size_in * (CeedSize)qf_size_out * (CeedSize)num_elem * (CeedSize)Q, strides,
4175daefc96SJeremy L Thompson                                                        rstr));
4185daefc96SJeremy L Thompson       // Create assembled vector
4195daefc96SJeremy L Thompson       CeedCallBackend(CeedVectorCreate(ceed, l_size, assembled));
4205daefc96SJeremy L Thompson     }
4215daefc96SJeremy L Thompson 
4225daefc96SJeremy L Thompson     // Assembly array
4235daefc96SJeremy L Thompson     CeedCallBackend(CeedVectorGetArrayWrite(*assembled, CEED_MEM_DEVICE, &assembled_array));
4245daefc96SJeremy L Thompson 
4255daefc96SJeremy L Thompson     // Assemble QFunction
4265daefc96SJeremy L Thompson     bool  is_tensor = false;
4275daefc96SJeremy L Thompson     void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array};
4285daefc96SJeremy L Thompson 
4295daefc96SJeremy L Thompson     CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor));
4305daefc96SJeremy L Thompson     CeedInt block_sizes[3] = {data->thread_1d, ((!is_tensor || data->dim == 1) ? 1 : data->thread_1d), -1};
4315daefc96SJeremy L Thompson 
4325daefc96SJeremy L Thompson     if (is_tensor) {
4335daefc96SJeremy L Thompson       CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
4345daefc96SJeremy L Thompson     } else {
4355daefc96SJeremy L Thompson       CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64;
4365daefc96SJeremy L Thompson 
4375daefc96SJeremy L Thompson       elems_per_block = elems_per_block > 0 ? elems_per_block : 1;
4385daefc96SJeremy L Thompson       block_sizes[2]  = elems_per_block;
4395daefc96SJeremy L Thompson     }
4405daefc96SJeremy L Thompson     if (data->dim == 1 || !is_tensor) {
4415daefc96SJeremy L Thompson       CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
4425daefc96SJeremy L Thompson       CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar);
4435daefc96SJeremy L Thompson 
4445daefc96SJeremy L Thompson       CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
4455daefc96SJeremy L Thompson                                                     sharedMem, &is_run_good, opargs));
4465daefc96SJeremy L Thompson     } else if (data->dim == 2) {
4475daefc96SJeremy L Thompson       CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
4485daefc96SJeremy L Thompson       CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
4495daefc96SJeremy L Thompson 
4505daefc96SJeremy L Thompson       CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
4515daefc96SJeremy L Thompson                                                     sharedMem, &is_run_good, opargs));
4525daefc96SJeremy L Thompson     } else if (data->dim == 3) {
4535daefc96SJeremy L Thompson       CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
4545daefc96SJeremy L Thompson       CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
4555daefc96SJeremy L Thompson 
4565daefc96SJeremy L Thompson       CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
4575daefc96SJeremy L Thompson                                                     sharedMem, &is_run_good, opargs));
4585daefc96SJeremy L Thompson     }
4595daefc96SJeremy L Thompson 
4605daefc96SJeremy L Thompson     // Restore input arrays
4615daefc96SJeremy L Thompson     for (CeedInt i = 0; i < num_input_fields; i++) {
4625daefc96SJeremy L Thompson       CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
4635daefc96SJeremy L Thompson       if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
4645daefc96SJeremy L Thompson       } else {
4655daefc96SJeremy L Thompson         bool       is_active;
4665daefc96SJeremy L Thompson         CeedVector vec;
4675daefc96SJeremy L Thompson 
4685daefc96SJeremy L Thompson         CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
4695daefc96SJeremy L Thompson         is_active = vec == CEED_VECTOR_ACTIVE;
4705daefc96SJeremy L Thompson         if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
4715daefc96SJeremy L Thompson         CeedCallBackend(CeedVectorDestroy(&vec));
4725daefc96SJeremy L Thompson       }
4735daefc96SJeremy L Thompson     }
4745daefc96SJeremy L Thompson 
4755daefc96SJeremy L Thompson     // Restore context data
4765daefc96SJeremy L Thompson     CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
4775daefc96SJeremy L Thompson 
4785daefc96SJeremy L Thompson     // Restore assembly array
4795daefc96SJeremy L Thompson     CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array));
4805daefc96SJeremy L Thompson 
4815daefc96SJeremy L Thompson     // Cleanup
4825daefc96SJeremy L Thompson     CeedCallBackend(CeedQFunctionDestroy(&qf));
4835daefc96SJeremy L Thompson     if (!is_run_good) {
4845daefc96SJeremy L Thompson       data->use_assembly_fallback = true;
4855daefc96SJeremy L Thompson       if (build_objects) {
4865daefc96SJeremy L Thompson         CeedCallBackend(CeedVectorDestroy(assembled));
4875daefc96SJeremy L Thompson         CeedCallBackend(CeedElemRestrictionDestroy(rstr));
4885daefc96SJeremy L Thompson       }
4895daefc96SJeremy L Thompson     }
4905daefc96SJeremy L Thompson   }
4915daefc96SJeremy L Thompson   CeedCallBackend(CeedDestroy(&ceed));
4925daefc96SJeremy L Thompson 
4935daefc96SJeremy L Thompson   // Fallback, if needed
4945daefc96SJeremy L Thompson   if (data->use_assembly_fallback) {
4955daefc96SJeremy L Thompson     CeedOperator op_fallback;
4965daefc96SJeremy L Thompson 
497ca38d01dSJeremy L Thompson     CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for LineearAssembleQFunction\n");
4985daefc96SJeremy L Thompson     CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
499ed094490SJeremy L Thompson     CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdateFallback(op_fallback, assembled, rstr, request));
5005daefc96SJeremy L Thompson     return CEED_ERROR_SUCCESS;
5015daefc96SJeremy L Thompson   }
5025daefc96SJeremy L Thompson   return CEED_ERROR_SUCCESS;
5035daefc96SJeremy L Thompson }
5045daefc96SJeremy L Thompson 
// Assemble the QFunction, letting the shared core routine create `*assembled` and `*rstr`
static int CeedOperatorLinearAssembleQFunction_Hip_gen(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
  const bool build_objects = true;

  return CeedOperatorLinearAssembleQFunctionCore_Hip_gen(op, build_objects, assembled, rstr, request);
}
5085daefc96SJeremy L Thompson 
// Re-assemble the QFunction into an existing `assembled` vector and `rstr` restriction;
// the shared core routine is told not to create either object
static int CeedOperatorLinearAssembleQFunctionUpdate_Hip_gen(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
  const bool           build_objects = false;
  CeedVector          *assembled_ref = &assembled;
  CeedElemRestriction *rstr_ref      = &rstr;

  return CeedOperatorLinearAssembleQFunctionCore_Hip_gen(op, build_objects, assembled_ref, rstr_ref, request);
}
5125daefc96SJeremy L Thompson 
//------------------------------------------------------------------------------
// AtPoints diagonal assembly
//------------------------------------------------------------------------------
// Add the diagonal of an AtPoints operator into `assembled` using the generated HIP kernel.
//   op        - operator to assemble
//   assembled - vector handed to the kernel through CeedVectorGetArray
//   request   - only forwarded to the fallback operator
// The diagonal kernel is built on first use. If it cannot be built or launched,
// `use_assembly_fallback` is set on the operator data and the fallback operator does the work.
// Returns CEED_ERROR_SUCCESS on completion.
static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen(CeedOperator op, CeedVector assembled, CeedRequest *request) {
  Ceed                  ceed;
  CeedOperator_Hip_gen *data;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &data));

  // Build the assembly kernel
  if (!data->assemble_diagonal && !data->use_assembly_fallback) {
    bool                     is_build_good = false;
    CeedInt                  num_active_bases_in, num_active_bases_out;
    CeedOperatorAssemblyData assembly_data;

    CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data));
    CeedCallBackend(CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL,
                                                         NULL, NULL));
    // The generated kernel is only attempted when the active input and output basis counts match
    if (num_active_bases_in == num_active_bases_out) {
      CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good));
      if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Hip_gen(op, &is_build_good));
    }
    if (!is_build_good) data->use_assembly_fallback = true;
  }

  // Try assembly
  if (!data->use_assembly_fallback) {
    bool                   is_run_good = true;
    CeedInt                num_elem, num_input_fields, num_output_fields;
    CeedEvalMode           eval_mode;
    CeedScalar            *assembled_array;
    CeedQFunctionField    *qf_input_fields, *qf_output_fields;
    CeedQFunction_Hip_gen *qf_data;
    CeedQFunction          qf;
    CeedOperatorField     *op_input_fields, *op_output_fields;

    CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
    CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
    CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
    CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
    CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));

    // Input vectors: weight and active inputs get a NULL pointer, all others get a device read pointer
    for (CeedInt i = 0; i < num_input_fields; i++) {
      CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
      if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
        data->fields.inputs[i] = NULL;
      } else {
        bool       is_active;
        CeedVector vec;

        // Get input vector
        CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
        is_active = vec == CEED_VECTOR_ACTIVE;
        if (is_active) data->fields.inputs[i] = NULL;
        else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
        CeedCallBackend(CeedVectorDestroy(&vec));
      }
    }

    // Point coordinates
    {
      CeedVector vec;

      CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
      CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
      CeedCallBackend(CeedVectorDestroy(&vec));

      // Points per elem, refreshed only when the element count differs from the cached one
      // NOTE(review): the cache is keyed on num_elem alone; confirm the points cannot change while the element count stays the same
      if (num_elem != data->points.num_elem) {
        CeedInt            *points_per_elem;
        const size_t        num_bytes   = (size_t)num_elem * sizeof(CeedInt);
        CeedElemRestriction rstr_points = NULL;

        CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
        CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
        for (CeedInt e = 0; e < num_elem; e++) {
          CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &points_per_elem[e]));
        }
        if (data->points.num_per_elem) {
          CeedCallHip(ceed, hipFree((void *)data->points.num_per_elem));
          // Return the cache to its zero-initialized state so a failed allocation below cannot leave a dangling pointer
          data->points.num_per_elem = NULL;
          data->points.num_elem     = 0;
        }
        CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
        CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
        // Only mark the cache as current once the device copy is in place
        data->points.num_elem = num_elem;
        CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
        CeedCallBackend(CeedFree(&points_per_elem));
      }
    }

    // Get context data
    CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));

    // Assembly array
    CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array));

    // Assemble diagonal
    void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array};

    CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1};

    CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
    // The third block dimension is always overridden to 1 for this kernel
    block_sizes[2] = 1;
    if (data->dim >= 1 && data->dim <= 3) {
      // 1D uses thread_1d scalars of shared memory per block slice; 2D and 3D use thread_1d^2
      const CeedInt slice_size = data->dim == 1 ? data->thread_1d : data->thread_1d * data->thread_1d;
      CeedInt       grid       = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
      CeedInt       sharedMem  = block_sizes[2] * slice_size * sizeof(CeedScalar);

      CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
                                                    sharedMem, &is_run_good, opargs));
    } else {
      // No launch configuration exists for this dimension; use the fallback instead of returning success without assembling
      is_run_good = false;
    }
    CeedCallHip(ceed, hipDeviceSynchronize());

    // Restore input arrays
    for (CeedInt i = 0; i < num_input_fields; i++) {
      CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
      if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
      } else {
        bool       is_active;
        CeedVector vec;

        CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
        is_active = vec == CEED_VECTOR_ACTIVE;
        if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
        CeedCallBackend(CeedVectorDestroy(&vec));
      }
    }

    // Restore point coordinates
    {
      CeedVector vec;

      CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
      CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
      CeedCallBackend(CeedVectorDestroy(&vec));
    }

    // Restore context data
    CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));

    // Restore assembly array
    CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array));

    // Cleanup
    CeedCallBackend(CeedQFunctionDestroy(&qf));
    if (!is_run_good) data->use_assembly_fallback = true;
  }
  CeedCallBackend(CeedDestroy(&ceed));

  // Fallback, if needed
  if (data->use_assembly_fallback) {
    CeedOperator op_fallback;

    CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for AtPoints LinearAssembleAddDiagonal\n");
    CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
    CeedCallBackend(CeedOperatorLinearAssembleAddDiagonal(op_fallback, assembled, request));
  }
  return CEED_ERROR_SUCCESS;
}
6880183ed61SJeremy L Thompson 
//------------------------------------------------------------------------------
// AtPoints full assembly
//------------------------------------------------------------------------------
// Assemble a single AtPoints operator into `assembled` using the generated HIP kernel.
//   op        - operator to assemble
//   offset    - index into the device array of `assembled` at which the kernel starts writing
//   assembled - vector handed to the kernel through CeedVectorGetArray
// The full assembly kernel is built on first use. If it cannot be built or launched,
// `use_assembly_fallback` is set on the operator data and the fallback operator does the work.
// Returns CEED_ERROR_SUCCESS on completion.
static int CeedOperatorAssembleSingleAtPoints_Hip_gen(CeedOperator op, CeedInt offset, CeedVector assembled) {
  Ceed                  ceed;
  CeedOperator_Hip_gen *data;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &data));

  // Build the assembly kernel
  if (!data->assemble_full && !data->use_assembly_fallback) {
    bool                     is_build_good = false;
    CeedInt                  num_active_bases_in, num_active_bases_out;
    CeedOperatorAssemblyData assembly_data;

    CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data));
    CeedCallBackend(CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL,
                                                         NULL, NULL));
    // The generated kernel is only attempted when the active input and output basis counts match
    if (num_active_bases_in == num_active_bases_out) {
      CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good));
      if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelFullAssemblyAtPoints_Hip_gen(op, &is_build_good));
    }
    if (!is_build_good) {
      CeedDebug(ceed, "Single Operator Assemble at Points compile failed, using fallback\n");
      data->use_assembly_fallback = true;
    }
  }

  // Try assembly
  if (!data->use_assembly_fallback) {
    bool                   is_run_good = true;
    CeedInt                num_elem, num_input_fields, num_output_fields;
    CeedEvalMode           eval_mode;
    CeedScalar            *assembled_array;
    CeedQFunctionField    *qf_input_fields, *qf_output_fields;
    CeedQFunction_Hip_gen *qf_data;
    CeedQFunction          qf;
    CeedOperatorField     *op_input_fields, *op_output_fields;

    CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
    CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
    CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
    CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
    CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
    CeedDebug(ceed, "Running single operator assemble for /gpu/hip/gen\n");

    // Input vectors: weight and active inputs get a NULL pointer, all others get a device read pointer
    for (CeedInt i = 0; i < num_input_fields; i++) {
      CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
      if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
        data->fields.inputs[i] = NULL;
      } else {
        bool       is_active;
        CeedVector vec;

        // Get input vector
        CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
        is_active = vec == CEED_VECTOR_ACTIVE;
        if (is_active) data->fields.inputs[i] = NULL;
        else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
        CeedCallBackend(CeedVectorDestroy(&vec));
      }
    }

    // Point coordinates
    {
      CeedVector vec;

      CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
      CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
      CeedCallBackend(CeedVectorDestroy(&vec));

      // Points per elem, refreshed only when the element count differs from the cached one
      // NOTE(review): the cache is keyed on num_elem alone; confirm the points cannot change while the element count stays the same
      if (num_elem != data->points.num_elem) {
        CeedInt            *points_per_elem;
        const size_t        num_bytes   = (size_t)num_elem * sizeof(CeedInt);
        CeedElemRestriction rstr_points = NULL;

        CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
        CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
        for (CeedInt e = 0; e < num_elem; e++) {
          CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &points_per_elem[e]));
        }
        if (data->points.num_per_elem) {
          CeedCallHip(ceed, hipFree((void *)data->points.num_per_elem));
          // Return the cache to its zero-initialized state so a failed allocation below cannot leave a dangling pointer
          data->points.num_per_elem = NULL;
          data->points.num_elem     = 0;
        }
        CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
        CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
        // Only mark the cache as current once the device copy is in place
        data->points.num_elem = num_elem;
        CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
        CeedCallBackend(CeedFree(&points_per_elem));
      }
    }

    // Get context data
    CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));

    // Assembly array; the kernel receives a pointer advanced by `offset` entries
    CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array));
    CeedScalar *assembled_offset_array = &assembled_array[offset];

    // Assemble full operator
    void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields,          &data->B,
                      &data->G,          &data->W,      &data->points,  &assembled_offset_array};

    CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1};

    CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
    // The third block dimension is always overridden to 1 for this kernel
    block_sizes[2] = 1;
    if (data->dim >= 1 && data->dim <= 3) {
      // 1D uses thread_1d scalars of shared memory per block slice; 2D and 3D use thread_1d^2
      const CeedInt slice_size = data->dim == 1 ? data->thread_1d : data->thread_1d * data->thread_1d;
      CeedInt       grid       = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
      CeedInt       sharedMem  = block_sizes[2] * slice_size * sizeof(CeedScalar);

      CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_full, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem,
                                                    &is_run_good, opargs));
    } else {
      // No launch configuration exists for this dimension; use the fallback instead of returning success without assembling
      is_run_good = false;
    }
    CeedCallHip(ceed, hipDeviceSynchronize());

    // Restore input arrays
    for (CeedInt i = 0; i < num_input_fields; i++) {
      CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
      if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
      } else {
        bool       is_active;
        CeedVector vec;

        CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
        is_active = vec == CEED_VECTOR_ACTIVE;
        if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
        CeedCallBackend(CeedVectorDestroy(&vec));
      }
    }

    // Restore point coordinates
    {
      CeedVector vec;

      CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
      CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
      CeedCallBackend(CeedVectorDestroy(&vec));
    }

    // Restore context data
    CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));

    // Restore assembly array
    CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array));

    // Cleanup
    CeedCallBackend(CeedQFunctionDestroy(&qf));
    if (!is_run_good) {
      CeedDebug(ceed, "Single Operator Assemble at Points run failed, using fallback\n");
      data->use_assembly_fallback = true;
    }
  }
  CeedCallBackend(CeedDestroy(&ceed));

  // Fallback, if needed
  if (data->use_assembly_fallback) {
    CeedOperator op_fallback;

    CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for AtPoints SingleOperatorAssemble\n");
    CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
    CeedCallBackend(CeedOperatorAssembleSingle(op_fallback, offset, assembled));
  }
  return CEED_ERROR_SUCCESS;
}
873692716b7SZach Atkins 
874692716b7SZach Atkins //------------------------------------------------------------------------------
8757d8d0e25Snbeams // Create operator
8767d8d0e25Snbeams //------------------------------------------------------------------------------
// Set up the hip-gen backend implementation of a CeedOperator.
//
// Allocates the CeedOperator_Hip_gen backend data with CeedCalloc and attaches it to `op`,
// then registers the backend functions for this operator:
//   - "ApplyAddComposite" for composite operators, "ApplyAdd" otherwise
//   - AtPoints operators: "LinearAssembleAddDiagonal" and "LinearAssembleSingle"
//   - all other operators: "LinearAssembleQFunction" and "LinearAssembleQFunctionUpdate"
//   - "Destroy" in every case
//
// Returns CEED_ERROR_SUCCESS on success; a failing call returns early through CeedCallBackend.
int CeedOperatorCreate_Hip_gen(CeedOperator op) {
  bool                  is_composite, is_at_points;
  Ceed                  ceed;
  CeedOperator_Hip_gen *impl;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));

  // Backend data
  CeedCallBackend(CeedCalloc(1, &impl));
  CeedCallBackend(CeedOperatorSetData(op, impl));

  // Apply
  // Checked with CeedCallBackend, matching every other call in this backend file
  CeedCallBackend(CeedOperatorIsComposite(op, &is_composite));
  if (is_composite) {
    CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAddComposite", CeedOperatorApplyAddComposite_Hip_gen));
  } else {
    CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen));
  }

  // Assembly; the AtPoints and non-AtPoints function sets are mutually exclusive
  CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points));
  if (is_at_points) {
    CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen));
    CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedOperatorAssembleSingleAtPoints_Hip_gen));
  } else {
    CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Hip_gen));
    CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip_gen));
  }

  // Cleanup
  CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen));
  // Release the Ceed reference obtained from CeedOperatorGetCeed above
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
9042a86cc9dSSebastian Grimberg 
9057d8d0e25Snbeams //------------------------------------------------------------------------------
906