19ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
37d8d0e25Snbeams //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
57d8d0e25Snbeams //
63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed
77d8d0e25Snbeams
849aac155SJeremy L Thompson #include <ceed.h>
9ec3da8bcSJed Brown #include <ceed/backend.h>
1049aac155SJeremy L Thompson #include <ceed/jit-source/hip/hip-types.h>
113d576824SJeremy L Thompson #include <stddef.h>
123a2968d6SJeremy L Thompson #include <hip/hiprtc.h>
132b730f8bSJeremy L Thompson
14b2165e7aSSebastian Grimberg #include "../hip/ceed-hip-common.h"
157d8d0e25Snbeams #include "../hip/ceed-hip-compile.h"
162b730f8bSJeremy L Thompson #include "ceed-hip-gen-operator-build.h"
172b730f8bSJeremy L Thompson #include "ceed-hip-gen.h"
187d8d0e25Snbeams
197d8d0e25Snbeams //------------------------------------------------------------------------------
207d8d0e25Snbeams // Destroy operator
217d8d0e25Snbeams //------------------------------------------------------------------------------
// Release the backend resources owned by a gen operator: the per-suboperator HIP streams (composite operators only),
// every loaded HIP module, and the device array of points per element (operators at points), then the data struct itself.
static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
  bool                  is_composite;
  Ceed                  ceed;
  CeedOperator_Hip_gen *impl;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedCallBackend(CeedOperatorIsComposite(op, &is_composite));
  if (is_composite) {
    CeedInt num_suboperators;

    // Use the backend error-checking macro, as everywhere else in this backend
    CeedCallBackend(CeedOperatorCompositeGetNumSub(op, &num_suboperators));
    // Streams are created lazily when the composite is applied, so any of them may still be NULL
    for (CeedInt i = 0; i < num_suboperators; i++) {
      if (impl->streams[i]) CeedCallHip(ceed, hipStreamDestroy(impl->streams[i]));
      impl->streams[i] = NULL;
    }
  }
  // Each module handle is checked individually because any of them may be unset
  if (impl->module) CeedCallHip(ceed, hipModuleUnload(impl->module));
  if (impl->module_assemble_full) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_full));
  if (impl->module_assemble_diagonal) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_diagonal));
  if (impl->module_assemble_qfunction) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_qfunction));
  // hipFree takes a plain device pointer (void *)
  if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void *)impl->points.num_per_elem));
  CeedCallBackend(CeedFree(&impl));
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
487d8d0e25Snbeams
497d8d0e25Snbeams //------------------------------------------------------------------------------
507d8d0e25Snbeams // Apply and add to output
517d8d0e25Snbeams //------------------------------------------------------------------------------
// Build (if needed) and launch the generated kernel for a single operator on `stream`, adding into the output.
//
// input_arr / output_arr are the device arrays used for the active input / output fields (NULL when the caller had
// no vector). Passive fields are read/written through their own vectors. On return, *is_run_good reports whether the
// kernel was built and launched; when it is false, data->use_fallback is set and the caller must apply the fallback.
static int CeedOperatorApplyAddCore_Hip_gen(CeedOperator op, hipStream_t stream, const CeedScalar *input_arr, CeedScalar *output_arr,
                                            bool *is_run_good, CeedRequest *request) {
  bool                   is_at_points, is_tensor;
  Ceed                   ceed;
  CeedInt                num_elem, num_input_fields, num_output_fields;
  CeedEvalMode           eval_mode;
  CeedQFunctionField    *qf_input_fields, *qf_output_fields;
  CeedQFunction_Hip_gen *qf_data;
  CeedQFunction          qf;
  CeedOperatorField     *op_input_fields, *op_output_fields;
  CeedOperator_Hip_gen  *data;

  // Creation of the operator
  CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, is_run_good));
  if (!(*is_run_good)) return CEED_ERROR_SUCCESS;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &data));
  CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
  CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
  CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
  CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
  CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));

  // Input vectors
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
    if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
      data->fields.inputs[i] = NULL;
    } else {
      bool       is_active;
      CeedVector vec;

      // Get input vector
      CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
      is_active = vec == CEED_VECTOR_ACTIVE;
      if (is_active) data->fields.inputs[i] = input_arr;
      else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
      CeedCallBackend(CeedVectorDestroy(&vec));
    }
  }

  // Output vectors
  for (CeedInt i = 0; i < num_output_fields; i++) {
    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
    if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
      data->fields.outputs[i] = NULL;
    } else {
      bool       is_active;
      CeedVector vec;

      // Get output vector
      CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
      is_active = vec == CEED_VECTOR_ACTIVE;
      if (is_active) data->fields.outputs[i] = output_arr;
      else CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i]));
      CeedCallBackend(CeedVectorDestroy(&vec));
    }
  }

  // Point coordinates, if needed
  CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points));
  if (is_at_points) {
    // Coords
    CeedVector vec;

    CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
    CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
    CeedCallBackend(CeedVectorDestroy(&vec));

    // Points per elem; the device copy is cached and rebuilt only when the element count changes
    if (num_elem != data->points.num_elem) {
      CeedInt            *points_per_elem;
      const size_t        num_bytes   = (size_t)num_elem * sizeof(CeedInt);  // size_t: no narrowing to CeedInt
      CeedElemRestriction rstr_points = NULL;

      CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
      CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
      for (CeedInt e = 0; e < num_elem; e++) {
        CeedInt num_points_elem;

        CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
        points_per_elem[e] = num_points_elem;
      }
      CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
      // Clear the old pointer right after freeing it, so a failed hipMalloc below cannot cause a double free in Destroy
      if (data->points.num_per_elem) {
        CeedCallHip(ceed, hipFree((void *)data->points.num_per_elem));
        data->points.num_per_elem = NULL;
      }
      CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
      CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
      CeedCallBackend(CeedFree(&points_per_elem));
      // Record the cached element count only once the device copy is complete, so a failed rebuild is retried
      data->points.num_elem = num_elem;
    }
  }

  // Get context data
  CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));

  // Apply operator
  void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points};

  CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor));
  const bool is_flat     = !is_tensor || data->dim == 1;
  CeedInt    block_sizes[3] = {data->thread_1d, (is_flat ? 1 : data->thread_1d), -1};

  if (is_tensor) {
    CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
  } else {
    CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64;

    elems_per_block = elems_per_block > 0 ? elems_per_block : 1;
    block_sizes[2]  = elems_per_block;
  }
  if (!is_tensor || (data->dim >= 1 && data->dim <= 3)) {
    // Ceiling division, written so that num_elem + block_sizes[2] cannot overflow
    const CeedInt grid       = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
    // block_sizes[2] * thread_1d CeedScalars (1D or non-tensor), or block_sizes[2] * thread_1d^2 CeedScalars (2D/3D tensor)
    const CeedInt slice_size = is_flat ? data->thread_1d : data->thread_1d * data->thread_1d;
    const CeedInt shared_mem = block_sizes[2] * slice_size * sizeof(CeedScalar);

    CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], shared_mem,
                                                  is_run_good, opargs));
  } else {
    // No launch configuration for this tensor dimension: report failure instead of silently skipping the apply
    *is_run_good = false;
  }

  // Restore input arrays
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
    if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
    } else {
      bool       is_active;
      CeedVector vec;

      CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
      is_active = vec == CEED_VECTOR_ACTIVE;
      if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
      CeedCallBackend(CeedVectorDestroy(&vec));
    }
  }

  // Restore output arrays
  for (CeedInt i = 0; i < num_output_fields; i++) {
    CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
    if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
    } else {
      bool       is_active;
      CeedVector vec;

      CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
      is_active = vec == CEED_VECTOR_ACTIVE;
      if (!is_active) CeedCallBackend(CeedVectorRestoreArray(vec, &data->fields.outputs[i]));
      CeedCallBackend(CeedVectorDestroy(&vec));
    }
  }

  // Restore point coordinates, if needed
  if (is_at_points) {
    CeedVector vec;

    CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
    CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
    CeedCallBackend(CeedVectorDestroy(&vec));
  }

  // Restore context data
  CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));

  // Cleanup
  CeedCallBackend(CeedDestroy(&ceed));
  CeedCallBackend(CeedQFunctionDestroy(&qf));
  if (!(*is_run_good)) data->use_fallback = true;
  return CEED_ERROR_SUCCESS;
}
2308d12f40eSJeremy L Thompson
// ApplyAdd for a single (non-composite) operator: run the generated HIP kernel on the NULL stream, adding into
// output_vec. If the generated kernel could not be built or launched, delegate the whole apply to the fallback operator.
static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
  const bool        has_input  = input_vec != CEED_VECTOR_NONE;
  const bool        has_output = output_vec != CEED_VECTOR_NONE;
  bool              run_ok     = false;
  const CeedScalar *d_in       = NULL;
  CeedScalar       *d_out      = NULL;

  // Borrow the device arrays, attempt the generated kernel, then hand the arrays back
  if (has_input) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &d_in));
  if (has_output) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &d_out));
  CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(op, NULL, d_in, d_out, &run_ok, request));
  if (has_input) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &d_in));
  if (has_output) CeedCallBackend(CeedVectorRestoreArray(output_vec, &d_out));

  // Generated kernel handled everything
  if (run_ok) return CEED_ERROR_SUCCESS;

  // Otherwise use the reference fallback operator
  {
    CeedOperator op_fallback;

    CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for ApplyAdd\n");
    CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
    CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
  }
  return CEED_ERROR_SUCCESS;
}
2537d8d0e25Snbeams
// ApplyAdd for a composite operator: every suboperator with elements runs its generated HIP kernel, adding into the
// shared output array. Each suboperator gets its own stream, unless the composite is sequential, in which case all run
// on stream 0. After synchronization and restoring the arrays, any suboperator whose generated kernel could not be built
// or launched is applied through its fallback operator.
static int CeedOperatorApplyAddComposite_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
  bool                  is_run_good[CEED_COMPOSITE_MAX] = {false}, is_sequential;
  CeedInt               num_suboperators;
  const CeedScalar     *input_arr  = NULL;
  CeedScalar           *output_arr = NULL;
  Ceed                  ceed;
  CeedOperator_Hip_gen *impl;
  CeedOperator         *sub_operators;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedCallBackend(CeedOperatorCompositeGetNumSub(op, &num_suboperators));
  CeedCallBackend(CeedOperatorCompositeGetSubList(op, &sub_operators));
  // Use the backend error-checking macro, as everywhere else in this backend
  CeedCallBackend(CeedOperatorCompositeIsSequential(op, &is_sequential));
  if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr));
  if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr));
  for (CeedInt i = 0; i < num_suboperators; i++) {
    CeedInt       num_elem     = 0;
    const CeedInt stream_index = is_sequential ? 0 : i;

    CeedCallBackend(CeedOperatorGetNumElements(sub_operators[i], &num_elem));
    if (num_elem > 0) {
      // Streams are created lazily and reused across applies
      if (!impl->streams[stream_index]) CeedCallHip(ceed, hipStreamCreate(&impl->streams[stream_index]));
      CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], impl->streams[stream_index], input_arr, output_arr, &is_run_good[i],
                                                       request));
    } else {
      // Nothing to compute, so no fallback is needed
      is_run_good[i] = true;
    }
  }
  // Wait for the launched kernels before handing the arrays back
  if (is_sequential) {
    // Stream 0 exists only if at least one suboperator had elements
    if (impl->streams[0]) CeedCallHip(ceed, hipStreamSynchronize(impl->streams[0]));
  } else {
    for (CeedInt i = 0; i < num_suboperators; i++) {
      if (impl->streams[i] && is_run_good[i]) CeedCallHip(ceed, hipStreamSynchronize(impl->streams[i]));
    }
  }
  if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr));
  if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr));
  CeedCallHip(ceed, hipDeviceSynchronize());

  // Fallback on unsuccessful run
  for (CeedInt i = 0; i < num_suboperators; i++) {
    if (!is_run_good[i]) {
      CeedOperator op_fallback;

      CeedDebug(ceed, "\nFalling back to /gpu/hip/ref CeedOperator for ApplyAdd\n");
      CeedCallBackend(CeedOperatorGetFallback(sub_operators[i], &op_fallback));
      CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
    }
  }
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
308c99afcd8SJeremy L Thompson
3097d8d0e25Snbeams //------------------------------------------------------------------------------
3105daefc96SJeremy L Thompson // QFunction assembly
3115daefc96SJeremy L Thompson //------------------------------------------------------------------------------
CeedOperatorLinearAssembleQFunctionCore_Hip_gen(CeedOperator op,bool build_objects,CeedVector * assembled,CeedElemRestriction * rstr,CeedRequest * request)3125daefc96SJeremy L Thompson static int CeedOperatorLinearAssembleQFunctionCore_Hip_gen(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr,
3135daefc96SJeremy L Thompson CeedRequest *request) {
3145daefc96SJeremy L Thompson Ceed ceed;
3155daefc96SJeremy L Thompson CeedOperator_Hip_gen *data;
3165daefc96SJeremy L Thompson
3175daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
3185daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data));
3195daefc96SJeremy L Thompson
3205daefc96SJeremy L Thompson // Build the assembly kernel
3215daefc96SJeremy L Thompson if (!data->assemble_qfunction && !data->use_assembly_fallback) {
3225daefc96SJeremy L Thompson bool is_build_good = false;
3235daefc96SJeremy L Thompson
3245daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good));
3255daefc96SJeremy L Thompson if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelLinearAssembleQFunction_Hip_gen(op, &is_build_good));
3265daefc96SJeremy L Thompson if (!is_build_good) data->use_assembly_fallback = true;
3275daefc96SJeremy L Thompson }
3285daefc96SJeremy L Thompson
3295daefc96SJeremy L Thompson // Try assembly
3305daefc96SJeremy L Thompson if (!data->use_assembly_fallback) {
3315daefc96SJeremy L Thompson bool is_run_good = true;
3325daefc96SJeremy L Thompson Ceed_Hip *hip_data;
3335daefc96SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields;
3345daefc96SJeremy L Thompson CeedEvalMode eval_mode;
3355daefc96SJeremy L Thompson CeedScalar *assembled_array;
3365daefc96SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields;
3375daefc96SJeremy L Thompson CeedQFunction_Hip_gen *qf_data;
3385daefc96SJeremy L Thompson CeedQFunction qf;
3395daefc96SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields;
3405daefc96SJeremy L Thompson
3415daefc96SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &hip_data));
3425daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
3435daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
3445daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
3455daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
3465daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
3475daefc96SJeremy L Thompson
3485daefc96SJeremy L Thompson // Input vectors
3495daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) {
3505daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
3515daefc96SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
3525daefc96SJeremy L Thompson data->fields.inputs[i] = NULL;
3535daefc96SJeremy L Thompson } else {
3545daefc96SJeremy L Thompson bool is_active;
3555daefc96SJeremy L Thompson CeedVector vec;
3565daefc96SJeremy L Thompson
3575daefc96SJeremy L Thompson // Get input vector
3585daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
3595daefc96SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE;
3605daefc96SJeremy L Thompson if (is_active) data->fields.inputs[i] = NULL;
3615daefc96SJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
3625daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec));
3635daefc96SJeremy L Thompson }
3645daefc96SJeremy L Thompson }
3655daefc96SJeremy L Thompson
3665daefc96SJeremy L Thompson // Get context data
3675daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
3685daefc96SJeremy L Thompson
3695daefc96SJeremy L Thompson // Build objects if needed
3705daefc96SJeremy L Thompson if (build_objects) {
3715daefc96SJeremy L Thompson CeedInt qf_size_in = 0, qf_size_out = 0, Q;
3725daefc96SJeremy L Thompson
3735daefc96SJeremy L Thompson // Count number of active input fields
3745daefc96SJeremy L Thompson {
3755daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) {
3765daefc96SJeremy L Thompson CeedInt field_size;
3775daefc96SJeremy L Thompson CeedVector vec;
3785daefc96SJeremy L Thompson
3795daefc96SJeremy L Thompson // Get input vector
3805daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
3815daefc96SJeremy L Thompson // Check if active input
3825daefc96SJeremy L Thompson if (vec == CEED_VECTOR_ACTIVE) {
3835daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size));
3845daefc96SJeremy L Thompson qf_size_in += field_size;
3855daefc96SJeremy L Thompson }
3865daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec));
3875daefc96SJeremy L Thompson }
3885daefc96SJeremy L Thompson CeedCheck(qf_size_in > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
3895daefc96SJeremy L Thompson }
3905daefc96SJeremy L Thompson
3915daefc96SJeremy L Thompson // Count number of active output fields
3925daefc96SJeremy L Thompson {
3935daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_output_fields; i++) {
3945daefc96SJeremy L Thompson CeedInt field_size;
3955daefc96SJeremy L Thompson CeedVector vec;
3965daefc96SJeremy L Thompson
3975daefc96SJeremy L Thompson // Get output vector
3985daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
3995daefc96SJeremy L Thompson // Check if active output
4005daefc96SJeremy L Thompson if (vec == CEED_VECTOR_ACTIVE) {
4015daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size));
4025daefc96SJeremy L Thompson qf_size_out += field_size;
4035daefc96SJeremy L Thompson }
4045daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec));
4055daefc96SJeremy L Thompson }
4065daefc96SJeremy L Thompson CeedCheck(qf_size_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
4075daefc96SJeremy L Thompson }
4085daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q));
4095daefc96SJeremy L Thompson
4105daefc96SJeremy L Thompson // Actually build objects now
4115daefc96SJeremy L Thompson const CeedSize l_size = (CeedSize)num_elem * Q * qf_size_in * qf_size_out;
4125daefc96SJeremy L Thompson CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */
4135daefc96SJeremy L Thompson
4145daefc96SJeremy L Thompson // Create output restriction
4155daefc96SJeremy L Thompson CeedCallBackend(CeedElemRestrictionCreateStrided(ceed, num_elem, Q, qf_size_in * qf_size_out,
4165daefc96SJeremy L Thompson (CeedSize)qf_size_in * (CeedSize)qf_size_out * (CeedSize)num_elem * (CeedSize)Q, strides,
4175daefc96SJeremy L Thompson rstr));
4185daefc96SJeremy L Thompson // Create assembled vector
4195daefc96SJeremy L Thompson CeedCallBackend(CeedVectorCreate(ceed, l_size, assembled));
4205daefc96SJeremy L Thompson }
4215daefc96SJeremy L Thompson
4225daefc96SJeremy L Thompson // Assembly array
4235daefc96SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayWrite(*assembled, CEED_MEM_DEVICE, &assembled_array));
4245daefc96SJeremy L Thompson
4255daefc96SJeremy L Thompson // Assemble QFunction
4265daefc96SJeremy L Thompson bool is_tensor = false;
4275daefc96SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array};
4285daefc96SJeremy L Thompson
4295daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor));
4305daefc96SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, ((!is_tensor || data->dim == 1) ? 1 : data->thread_1d), -1};
4315daefc96SJeremy L Thompson
4325daefc96SJeremy L Thompson if (is_tensor) {
4335daefc96SJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
4345daefc96SJeremy L Thompson } else {
4355daefc96SJeremy L Thompson CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64;
4365daefc96SJeremy L Thompson
4375daefc96SJeremy L Thompson elems_per_block = elems_per_block > 0 ? elems_per_block : 1;
4385daefc96SJeremy L Thompson block_sizes[2] = elems_per_block;
4395daefc96SJeremy L Thompson }
4405daefc96SJeremy L Thompson if (data->dim == 1 || !is_tensor) {
4415daefc96SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
4425daefc96SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar);
4435daefc96SJeremy L Thompson
4445daefc96SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
4455daefc96SJeremy L Thompson sharedMem, &is_run_good, opargs));
4465daefc96SJeremy L Thompson } else if (data->dim == 2) {
4475daefc96SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
4485daefc96SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
4495daefc96SJeremy L Thompson
4505daefc96SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
4515daefc96SJeremy L Thompson sharedMem, &is_run_good, opargs));
4525daefc96SJeremy L Thompson } else if (data->dim == 3) {
4535daefc96SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
4545daefc96SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
4555daefc96SJeremy L Thompson
4565daefc96SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
4575daefc96SJeremy L Thompson sharedMem, &is_run_good, opargs));
4585daefc96SJeremy L Thompson }
4595daefc96SJeremy L Thompson
4605daefc96SJeremy L Thompson // Restore input arrays
4615daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) {
4625daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
4635daefc96SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
4645daefc96SJeremy L Thompson } else {
4655daefc96SJeremy L Thompson bool is_active;
4665daefc96SJeremy L Thompson CeedVector vec;
4675daefc96SJeremy L Thompson
4685daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
4695daefc96SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE;
4705daefc96SJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
4715daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec));
4725daefc96SJeremy L Thompson }
4735daefc96SJeremy L Thompson }
4745daefc96SJeremy L Thompson
4755daefc96SJeremy L Thompson // Restore context data
4765daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
4775daefc96SJeremy L Thompson
4785daefc96SJeremy L Thompson // Restore assembly array
4795daefc96SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array));
4805daefc96SJeremy L Thompson
4815daefc96SJeremy L Thompson // Cleanup
4825daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf));
4835daefc96SJeremy L Thompson if (!is_run_good) {
4845daefc96SJeremy L Thompson data->use_assembly_fallback = true;
4855daefc96SJeremy L Thompson if (build_objects) {
4865daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(assembled));
4875daefc96SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(rstr));
4885daefc96SJeremy L Thompson }
4895daefc96SJeremy L Thompson }
4905daefc96SJeremy L Thompson }
4915daefc96SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
4925daefc96SJeremy L Thompson
4935daefc96SJeremy L Thompson // Fallback, if needed
4945daefc96SJeremy L Thompson if (data->use_assembly_fallback) {
4955daefc96SJeremy L Thompson CeedOperator op_fallback;
4965daefc96SJeremy L Thompson
497ca38d01dSJeremy L Thompson CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for LineearAssembleQFunction\n");
4985daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
499ed094490SJeremy L Thompson CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdateFallback(op_fallback, assembled, rstr, request));
5005daefc96SJeremy L Thompson return CEED_ERROR_SUCCESS;
5015daefc96SJeremy L Thompson }
5025daefc96SJeremy L Thompson return CEED_ERROR_SUCCESS;
5035daefc96SJeremy L Thompson }
5045daefc96SJeremy L Thompson
CeedOperatorLinearAssembleQFunction_Hip_gen(CeedOperator op,CeedVector * assembled,CeedElemRestriction * rstr,CeedRequest * request)5055daefc96SJeremy L Thompson static int CeedOperatorLinearAssembleQFunction_Hip_gen(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
5065daefc96SJeremy L Thompson return CeedOperatorLinearAssembleQFunctionCore_Hip_gen(op, true, assembled, rstr, request);
5075daefc96SJeremy L Thompson }
5085daefc96SJeremy L Thompson
CeedOperatorLinearAssembleQFunctionUpdate_Hip_gen(CeedOperator op,CeedVector assembled,CeedElemRestriction rstr,CeedRequest * request)5095daefc96SJeremy L Thompson static int CeedOperatorLinearAssembleQFunctionUpdate_Hip_gen(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
5105daefc96SJeremy L Thompson return CeedOperatorLinearAssembleQFunctionCore_Hip_gen(op, false, &assembled, &rstr, request);
5115daefc96SJeremy L Thompson }
5125daefc96SJeremy L Thompson
5135daefc96SJeremy L Thompson //------------------------------------------------------------------------------
5140183ed61SJeremy L Thompson // AtPoints diagonal assembly
5150183ed61SJeremy L Thompson //------------------------------------------------------------------------------
CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen(CeedOperator op,CeedVector assembled,CeedRequest * request)5160183ed61SJeremy L Thompson static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen(CeedOperator op, CeedVector assembled, CeedRequest *request) {
5170183ed61SJeremy L Thompson Ceed ceed;
5180183ed61SJeremy L Thompson CeedOperator_Hip_gen *data;
5190183ed61SJeremy L Thompson
5200183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
5210183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data));
5220183ed61SJeremy L Thompson
5230183ed61SJeremy L Thompson // Build the assembly kernel
5240183ed61SJeremy L Thompson if (!data->assemble_diagonal && !data->use_assembly_fallback) {
5250183ed61SJeremy L Thompson bool is_build_good = false;
5260183ed61SJeremy L Thompson CeedInt num_active_bases_in, num_active_bases_out;
5270183ed61SJeremy L Thompson CeedOperatorAssemblyData assembly_data;
5280183ed61SJeremy L Thompson
5290183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data));
5301a8516d0SJames Wright CeedCallBackend(CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL,
5311a8516d0SJames Wright NULL, NULL));
5320183ed61SJeremy L Thompson if (num_active_bases_in == num_active_bases_out) {
5330183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good));
5340183ed61SJeremy L Thompson if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Hip_gen(op, &is_build_good));
5350183ed61SJeremy L Thompson }
5360183ed61SJeremy L Thompson if (!is_build_good) data->use_assembly_fallback = true;
5370183ed61SJeremy L Thompson }
5380183ed61SJeremy L Thompson
5390183ed61SJeremy L Thompson // Try assembly
5400183ed61SJeremy L Thompson if (!data->use_assembly_fallback) {
5410183ed61SJeremy L Thompson bool is_run_good = true;
5420183ed61SJeremy L Thompson Ceed_Hip *hip_data;
5430183ed61SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields;
5440183ed61SJeremy L Thompson CeedEvalMode eval_mode;
5450183ed61SJeremy L Thompson CeedScalar *assembled_array;
5460183ed61SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields;
5470183ed61SJeremy L Thompson CeedQFunction_Hip_gen *qf_data;
5480183ed61SJeremy L Thompson CeedQFunction qf;
5490183ed61SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields;
5500183ed61SJeremy L Thompson
5510183ed61SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &hip_data));
5520183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
5530183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
5540183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
5550183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
5560183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
5570183ed61SJeremy L Thompson
5580183ed61SJeremy L Thompson // Input vectors
5590183ed61SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) {
5600183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
5610183ed61SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
5620183ed61SJeremy L Thompson data->fields.inputs[i] = NULL;
5630183ed61SJeremy L Thompson } else {
5640183ed61SJeremy L Thompson bool is_active;
5650183ed61SJeremy L Thompson CeedVector vec;
5660183ed61SJeremy L Thompson
5670183ed61SJeremy L Thompson // Get input vector
5680183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
5690183ed61SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE;
5700183ed61SJeremy L Thompson if (is_active) data->fields.inputs[i] = NULL;
5710183ed61SJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
5720183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec));
5730183ed61SJeremy L Thompson }
5740183ed61SJeremy L Thompson }
5750183ed61SJeremy L Thompson
5760183ed61SJeremy L Thompson // Point coordinates
5770183ed61SJeremy L Thompson {
5780183ed61SJeremy L Thompson CeedVector vec;
5790183ed61SJeremy L Thompson
5800183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
5810183ed61SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
5820183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec));
5830183ed61SJeremy L Thompson
5840183ed61SJeremy L Thompson // Points per elem
5850183ed61SJeremy L Thompson if (num_elem != data->points.num_elem) {
5860183ed61SJeremy L Thompson CeedInt *points_per_elem;
5870183ed61SJeremy L Thompson const CeedInt num_bytes = num_elem * sizeof(CeedInt);
5880183ed61SJeremy L Thompson CeedElemRestriction rstr_points = NULL;
5890183ed61SJeremy L Thompson
5900183ed61SJeremy L Thompson data->points.num_elem = num_elem;
5910183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
5920183ed61SJeremy L Thompson CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
5930183ed61SJeremy L Thompson for (CeedInt e = 0; e < num_elem; e++) {
5940183ed61SJeremy L Thompson CeedInt num_points_elem;
5950183ed61SJeremy L Thompson
5960183ed61SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
5970183ed61SJeremy L Thompson points_per_elem[e] = num_points_elem;
5980183ed61SJeremy L Thompson }
5990183ed61SJeremy L Thompson if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem));
6000183ed61SJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
6010183ed61SJeremy L Thompson CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
6020183ed61SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
6030183ed61SJeremy L Thompson CeedCallBackend(CeedFree(&points_per_elem));
6040183ed61SJeremy L Thompson }
6050183ed61SJeremy L Thompson }
6060183ed61SJeremy L Thompson
6070183ed61SJeremy L Thompson // Get context data
6080183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
6090183ed61SJeremy L Thompson
6100183ed61SJeremy L Thompson // Assembly array
6110183ed61SJeremy L Thompson CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array));
6120183ed61SJeremy L Thompson
6130183ed61SJeremy L Thompson // Assemble diagonal
6140183ed61SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array};
6150183ed61SJeremy L Thompson
6160183ed61SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1};
6170183ed61SJeremy L Thompson
6180183ed61SJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
6190183ed61SJeremy L Thompson block_sizes[2] = 1;
6200183ed61SJeremy L Thompson if (data->dim == 1) {
6210183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
6220183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar);
6230183ed61SJeremy L Thompson
6240183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
6250183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs));
6260183ed61SJeremy L Thompson } else if (data->dim == 2) {
6270183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
6280183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
6290183ed61SJeremy L Thompson
6300183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
6310183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs));
6320183ed61SJeremy L Thompson } else if (data->dim == 3) {
6330183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
6340183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
6350183ed61SJeremy L Thompson
6360183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2],
6370183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs));
6380183ed61SJeremy L Thompson }
639692716b7SZach Atkins CeedCallHip(ceed, hipDeviceSynchronize());
6400183ed61SJeremy L Thompson
6410183ed61SJeremy L Thompson // Restore input arrays
6420183ed61SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) {
6430183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
6440183ed61SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
6450183ed61SJeremy L Thompson } else {
6460183ed61SJeremy L Thompson bool is_active;
6470183ed61SJeremy L Thompson CeedVector vec;
6480183ed61SJeremy L Thompson
6490183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
6500183ed61SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE;
6510183ed61SJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
6520183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec));
6530183ed61SJeremy L Thompson }
6540183ed61SJeremy L Thompson }
6550183ed61SJeremy L Thompson
6560183ed61SJeremy L Thompson // Restore point coordinates
6570183ed61SJeremy L Thompson {
6580183ed61SJeremy L Thompson CeedVector vec;
6590183ed61SJeremy L Thompson
6600183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
6610183ed61SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
6620183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec));
6630183ed61SJeremy L Thompson }
6640183ed61SJeremy L Thompson
6650183ed61SJeremy L Thompson // Restore context data
6660183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
6670183ed61SJeremy L Thompson
6680183ed61SJeremy L Thompson // Restore assembly array
6690183ed61SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array));
6700183ed61SJeremy L Thompson
6710183ed61SJeremy L Thompson // Cleanup
6720183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf));
6730183ed61SJeremy L Thompson if (!is_run_good) data->use_assembly_fallback = true;
6740183ed61SJeremy L Thompson }
6750183ed61SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
6760183ed61SJeremy L Thompson
6770183ed61SJeremy L Thompson // Fallback, if needed
6780183ed61SJeremy L Thompson if (data->use_assembly_fallback) {
6790183ed61SJeremy L Thompson CeedOperator op_fallback;
6800183ed61SJeremy L Thompson
681ca38d01dSJeremy L Thompson CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for AtPoints LinearAssembleAddDiagonal\n");
6820183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
6830183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorLinearAssembleAddDiagonal(op_fallback, assembled, request));
6840183ed61SJeremy L Thompson return CEED_ERROR_SUCCESS;
6850183ed61SJeremy L Thompson }
6860183ed61SJeremy L Thompson return CEED_ERROR_SUCCESS;
6870183ed61SJeremy L Thompson }
6880183ed61SJeremy L Thompson
6890183ed61SJeremy L Thompson //------------------------------------------------------------------------------
690692716b7SZach Atkins // AtPoints full assembly
691692716b7SZach Atkins //------------------------------------------------------------------------------
CeedOperatorAssembleSingleAtPoints_Hip_gen(CeedOperator op,CeedInt offset,CeedVector assembled)692ed094490SJeremy L Thompson static int CeedOperatorAssembleSingleAtPoints_Hip_gen(CeedOperator op, CeedInt offset, CeedVector assembled) {
693692716b7SZach Atkins Ceed ceed;
694692716b7SZach Atkins CeedOperator_Hip_gen *data;
695692716b7SZach Atkins
696692716b7SZach Atkins CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
697692716b7SZach Atkins CeedCallBackend(CeedOperatorGetData(op, &data));
698692716b7SZach Atkins
699692716b7SZach Atkins // Build the assembly kernel
700692716b7SZach Atkins if (!data->assemble_full && !data->use_assembly_fallback) {
701692716b7SZach Atkins bool is_build_good = false;
702692716b7SZach Atkins CeedInt num_active_bases_in, num_active_bases_out;
703692716b7SZach Atkins CeedOperatorAssemblyData assembly_data;
704692716b7SZach Atkins
705692716b7SZach Atkins CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data));
7061a8516d0SJames Wright CeedCallBackend(CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL,
7071a8516d0SJames Wright NULL, NULL));
708692716b7SZach Atkins if (num_active_bases_in == num_active_bases_out) {
709692716b7SZach Atkins CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good));
710692716b7SZach Atkins if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelFullAssemblyAtPoints_Hip_gen(op, &is_build_good));
711692716b7SZach Atkins }
712692716b7SZach Atkins if (!is_build_good) {
713692716b7SZach Atkins CeedDebug(ceed, "Single Operator Assemble at Points compile failed, using fallback\n");
714692716b7SZach Atkins data->use_assembly_fallback = true;
715692716b7SZach Atkins }
716692716b7SZach Atkins }
717692716b7SZach Atkins
718692716b7SZach Atkins // Try assembly
719692716b7SZach Atkins if (!data->use_assembly_fallback) {
720692716b7SZach Atkins bool is_run_good = true;
721692716b7SZach Atkins Ceed_Hip *Hip_data;
722692716b7SZach Atkins CeedInt num_elem, num_input_fields, num_output_fields;
723692716b7SZach Atkins CeedEvalMode eval_mode;
724692716b7SZach Atkins CeedScalar *assembled_array;
725692716b7SZach Atkins CeedQFunctionField *qf_input_fields, *qf_output_fields;
726692716b7SZach Atkins CeedQFunction_Hip_gen *qf_data;
727692716b7SZach Atkins CeedQFunction qf;
728692716b7SZach Atkins CeedOperatorField *op_input_fields, *op_output_fields;
729692716b7SZach Atkins
730692716b7SZach Atkins CeedCallBackend(CeedGetData(ceed, &Hip_data));
731692716b7SZach Atkins CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
732692716b7SZach Atkins CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
733692716b7SZach Atkins CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
734692716b7SZach Atkins CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
735692716b7SZach Atkins CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
736692716b7SZach Atkins CeedDebug(ceed, "Running single operator assemble for /gpu/hip/gen\n");
737692716b7SZach Atkins
738692716b7SZach Atkins // Input vectors
739692716b7SZach Atkins for (CeedInt i = 0; i < num_input_fields; i++) {
740692716b7SZach Atkins CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
741692716b7SZach Atkins if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
742692716b7SZach Atkins data->fields.inputs[i] = NULL;
743692716b7SZach Atkins } else {
744692716b7SZach Atkins bool is_active;
745692716b7SZach Atkins CeedVector vec;
746692716b7SZach Atkins
747692716b7SZach Atkins // Get input vector
748692716b7SZach Atkins CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
749692716b7SZach Atkins is_active = vec == CEED_VECTOR_ACTIVE;
750692716b7SZach Atkins if (is_active) data->fields.inputs[i] = NULL;
751692716b7SZach Atkins else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
752692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec));
753692716b7SZach Atkins }
754692716b7SZach Atkins }
755692716b7SZach Atkins
756692716b7SZach Atkins // Point coordinates
757692716b7SZach Atkins {
758692716b7SZach Atkins CeedVector vec;
759692716b7SZach Atkins
760692716b7SZach Atkins CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
761692716b7SZach Atkins CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
762692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec));
763692716b7SZach Atkins
764692716b7SZach Atkins // Points per elem
765692716b7SZach Atkins if (num_elem != data->points.num_elem) {
766692716b7SZach Atkins CeedInt *points_per_elem;
767692716b7SZach Atkins const CeedInt num_bytes = num_elem * sizeof(CeedInt);
768692716b7SZach Atkins CeedElemRestriction rstr_points = NULL;
769692716b7SZach Atkins
770692716b7SZach Atkins data->points.num_elem = num_elem;
771692716b7SZach Atkins CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
772692716b7SZach Atkins CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
773692716b7SZach Atkins for (CeedInt e = 0; e < num_elem; e++) {
774692716b7SZach Atkins CeedInt num_points_elem;
775692716b7SZach Atkins
776692716b7SZach Atkins CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
777692716b7SZach Atkins points_per_elem[e] = num_points_elem;
778692716b7SZach Atkins }
779692716b7SZach Atkins if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem));
780692716b7SZach Atkins CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
781692716b7SZach Atkins CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
782692716b7SZach Atkins CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
783692716b7SZach Atkins CeedCallBackend(CeedFree(&points_per_elem));
784692716b7SZach Atkins }
785692716b7SZach Atkins }
786692716b7SZach Atkins
787692716b7SZach Atkins // Get context data
788692716b7SZach Atkins CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
789692716b7SZach Atkins
790692716b7SZach Atkins // Assembly array
791692716b7SZach Atkins CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array));
792692716b7SZach Atkins CeedScalar *assembled_offset_array = &assembled_array[offset];
793692716b7SZach Atkins
794692716b7SZach Atkins // Assemble diagonal
795692716b7SZach Atkins void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B,
796692716b7SZach Atkins &data->G, &data->W, &data->points, &assembled_offset_array};
797692716b7SZach Atkins
798692716b7SZach Atkins CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1};
799692716b7SZach Atkins
800692716b7SZach Atkins CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
801692716b7SZach Atkins block_sizes[2] = 1;
802692716b7SZach Atkins if (data->dim == 1) {
803692716b7SZach Atkins CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
804692716b7SZach Atkins CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar);
805692716b7SZach Atkins
806692716b7SZach Atkins CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_full, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem,
807692716b7SZach Atkins &is_run_good, opargs));
808692716b7SZach Atkins } else if (data->dim == 2) {
809692716b7SZach Atkins CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
810692716b7SZach Atkins CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
811692716b7SZach Atkins
812692716b7SZach Atkins CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_full, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem,
813692716b7SZach Atkins &is_run_good, opargs));
814692716b7SZach Atkins } else if (data->dim == 3) {
815692716b7SZach Atkins CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
816692716b7SZach Atkins CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
817692716b7SZach Atkins
818692716b7SZach Atkins CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_full, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem,
819692716b7SZach Atkins &is_run_good, opargs));
820692716b7SZach Atkins }
821692716b7SZach Atkins CeedCallHip(ceed, hipDeviceSynchronize());
822692716b7SZach Atkins
823692716b7SZach Atkins // Restore input arrays
824692716b7SZach Atkins for (CeedInt i = 0; i < num_input_fields; i++) {
825692716b7SZach Atkins CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
826692716b7SZach Atkins if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
827692716b7SZach Atkins } else {
828692716b7SZach Atkins bool is_active;
829692716b7SZach Atkins CeedVector vec;
830692716b7SZach Atkins
831692716b7SZach Atkins CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
832692716b7SZach Atkins is_active = vec == CEED_VECTOR_ACTIVE;
833692716b7SZach Atkins if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
834692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec));
835692716b7SZach Atkins }
836692716b7SZach Atkins }
837692716b7SZach Atkins
838692716b7SZach Atkins // Restore point coordinates
839692716b7SZach Atkins {
840692716b7SZach Atkins CeedVector vec;
841692716b7SZach Atkins
842692716b7SZach Atkins CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
843692716b7SZach Atkins CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
844692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec));
845692716b7SZach Atkins }
846692716b7SZach Atkins
847692716b7SZach Atkins // Restore context data
848692716b7SZach Atkins CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
849692716b7SZach Atkins
850692716b7SZach Atkins // Restore assembly array
851692716b7SZach Atkins CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array));
852692716b7SZach Atkins
853692716b7SZach Atkins // Cleanup
854692716b7SZach Atkins CeedCallBackend(CeedQFunctionDestroy(&qf));
855692716b7SZach Atkins if (!is_run_good) {
856692716b7SZach Atkins CeedDebug(ceed, "Single Operator Assemble at Points run failed, using fallback\n");
857692716b7SZach Atkins data->use_assembly_fallback = true;
858692716b7SZach Atkins }
859692716b7SZach Atkins }
860692716b7SZach Atkins CeedCallBackend(CeedDestroy(&ceed));
861692716b7SZach Atkins
862692716b7SZach Atkins // Fallback, if needed
863692716b7SZach Atkins if (data->use_assembly_fallback) {
864692716b7SZach Atkins CeedOperator op_fallback;
865692716b7SZach Atkins
866ca38d01dSJeremy L Thompson CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for AtPoints SingleOperatorAssemble\n");
867692716b7SZach Atkins CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
868ed094490SJeremy L Thompson CeedCallBackend(CeedOperatorAssembleSingle(op_fallback, offset, assembled));
869692716b7SZach Atkins return CEED_ERROR_SUCCESS;
870692716b7SZach Atkins }
871692716b7SZach Atkins return CEED_ERROR_SUCCESS;
872692716b7SZach Atkins }
873692716b7SZach Atkins
874692716b7SZach Atkins //------------------------------------------------------------------------------
8757d8d0e25Snbeams // Create operator
8767d8d0e25Snbeams //------------------------------------------------------------------------------
CeedOperatorCreate_Hip_gen(CeedOperator op)8777d8d0e25Snbeams int CeedOperatorCreate_Hip_gen(CeedOperator op) {
8780183ed61SJeremy L Thompson bool is_composite, is_at_points;
8797d8d0e25Snbeams Ceed ceed;
8807d8d0e25Snbeams CeedOperator_Hip_gen *impl;
8817d8d0e25Snbeams
882b7453713SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
8832b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl));
8842b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorSetData(op, impl));
885c99afcd8SJeremy L Thompson CeedCall(CeedOperatorIsComposite(op, &is_composite));
886c99afcd8SJeremy L Thompson if (is_composite) {
887c99afcd8SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAddComposite", CeedOperatorApplyAddComposite_Hip_gen));
888c99afcd8SJeremy L Thompson } else {
8892b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen));
890c99afcd8SJeremy L Thompson }
8910183ed61SJeremy L Thompson CeedCall(CeedOperatorIsAtPoints(op, &is_at_points));
8920183ed61SJeremy L Thompson if (is_at_points) {
8930183ed61SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen));
894ed094490SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedOperatorAssembleSingleAtPoints_Hip_gen));
8950183ed61SJeremy L Thompson }
8965daefc96SJeremy L Thompson if (!is_at_points) {
8975daefc96SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Hip_gen));
8985daefc96SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip_gen));
8995daefc96SJeremy L Thompson }
9002b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen));
9019bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
902e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS;
9037d8d0e25Snbeams }
9042a86cc9dSSebastian Grimberg
9057d8d0e25Snbeams //------------------------------------------------------------------------------
906