1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors. 23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 37d8d0e25Snbeams // 43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause 57d8d0e25Snbeams // 63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed 77d8d0e25Snbeams 849aac155SJeremy L Thompson #include <ceed.h> 9ec3da8bcSJed Brown #include <ceed/backend.h> 1049aac155SJeremy L Thompson #include <ceed/jit-source/hip/hip-types.h> 113d576824SJeremy L Thompson #include <stddef.h> 123a2968d6SJeremy L Thompson #include <hip/hiprtc.h> 132b730f8bSJeremy L Thompson 14b2165e7aSSebastian Grimberg #include "../hip/ceed-hip-common.h" 157d8d0e25Snbeams #include "../hip/ceed-hip-compile.h" 162b730f8bSJeremy L Thompson #include "ceed-hip-gen-operator-build.h" 172b730f8bSJeremy L Thompson #include "ceed-hip-gen.h" 187d8d0e25Snbeams 197d8d0e25Snbeams //------------------------------------------------------------------------------ 207d8d0e25Snbeams // Destroy operator 217d8d0e25Snbeams //------------------------------------------------------------------------------ 227d8d0e25Snbeams static int CeedOperatorDestroy_Hip_gen(CeedOperator op) { 233a2968d6SJeremy L Thompson Ceed ceed; 247d8d0e25Snbeams CeedOperator_Hip_gen *impl; 256eee1ffcSZach Atkins bool is_composite; 26b7453713SJeremy L Thompson 273a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 282b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &impl)); 296eee1ffcSZach Atkins CeedCallBackend(CeedOperatorIsComposite(op, &is_composite)); 306eee1ffcSZach Atkins if (is_composite) { 316eee1ffcSZach Atkins CeedInt num_suboperators; 326eee1ffcSZach Atkins 33ed094490SJeremy L Thompson CeedCall(CeedOperatorCompositeGetNumSub(op, &num_suboperators)); 346eee1ffcSZach Atkins for (CeedInt i = 0; i < num_suboperators; i++) { 356eee1ffcSZach Atkins if (impl->streams[i]) CeedCallHip(ceed, hipStreamDestroy(impl->streams[i])); 366eee1ffcSZach Atkins impl->streams[i] = NULL; 376eee1ffcSZach Atkins } 386eee1ffcSZach Atkins } 398b7d3340SJeremy L Thompson if (impl->module) CeedCallHip(ceed, hipModuleUnload(impl->module)); 400183ed61SJeremy L Thompson if (impl->module_assemble_full) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_full)); 410183ed61SJeremy L Thompson if (impl->module_assemble_diagonal) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_diagonal)); 425daefc96SJeremy L Thompson if (impl->module_assemble_qfunction) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_qfunction)); 433a2968d6SJeremy L Thompson if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)impl->points.num_per_elem)); 442b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&impl)); 453a2968d6SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 46e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 477d8d0e25Snbeams } 487d8d0e25Snbeams 497d8d0e25Snbeams //------------------------------------------------------------------------------ 507d8d0e25Snbeams // Apply and add to output 517d8d0e25Snbeams //------------------------------------------------------------------------------ 52e9c76bddSJeremy L Thompson static int CeedOperatorApplyAddCore_Hip_gen(CeedOperator op, hipStream_t stream, const CeedScalar *input_arr, CeedScalar *output_arr, 53e9c76bddSJeremy L Thompson bool *is_run_good, CeedRequest *request) { 54ea04d07fSJeremy L Thompson bool is_at_points, is_tensor; 557d8d0e25Snbeams Ceed ceed; 56b7453713SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields; 57b7453713SJeremy L Thompson CeedEvalMode eval_mode; 58b7453713SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields; 597d8d0e25Snbeams CeedQFunction_Hip_gen *qf_data; 60b7453713SJeremy L Thompson CeedQFunction qf; 61b7453713SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields; 62b7453713SJeremy L Thompson CeedOperator_Hip_gen *data; 63b7453713SJeremy L Thompson 648d12f40eSJeremy L Thompson // Creation of the operator 65ea04d07fSJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, is_run_good)); 66ea04d07fSJeremy L Thompson if (!(*is_run_good)) return CEED_ERROR_SUCCESS; 67f6eafd79SJeremy L Thompson 68c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 69c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data)); 70c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 71c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 72c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 738d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 74c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 75c11e12f4SJeremy L Thompson 767d8d0e25Snbeams // Input vectors 779e201c85SYohann for (CeedInt i = 0; i < num_input_fields; i++) { 782b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 799e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 809e201c85SYohann data->fields.inputs[i] = NULL; 817d8d0e25Snbeams } else { 823efc994bSJeremy L Thompson bool is_active; 83b7453713SJeremy L Thompson CeedVector vec; 84b7453713SJeremy L Thompson 857d8d0e25Snbeams // Get input vector 862b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 873efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 88ea04d07fSJeremy L Thompson if (is_active) data->fields.inputs[i] = input_arr; 89ea04d07fSJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 90ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 917d8d0e25Snbeams } 927d8d0e25Snbeams } 937d8d0e25Snbeams 947d8d0e25Snbeams // Output vectors 959e201c85SYohann for (CeedInt i = 0; i < num_output_fields; i++) { 962b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 979e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 989e201c85SYohann data->fields.outputs[i] = NULL; 997d8d0e25Snbeams } else { 1003efc994bSJeremy L Thompson bool is_active; 101b7453713SJeremy L Thompson CeedVector vec; 102b7453713SJeremy L Thompson 1037d8d0e25Snbeams // Get output vector 1042b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 1053efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 106ea04d07fSJeremy L Thompson if (is_active) data->fields.outputs[i] = output_arr; 1070c8fbeedSJeremy L Thompson else CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i])); 108ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1097d8d0e25Snbeams } 1107d8d0e25Snbeams } 1117d8d0e25Snbeams 1123a2968d6SJeremy L Thompson // Point coordinates, if needed 1133a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points)); 1143a2968d6SJeremy L Thompson if (is_at_points) { 1153a2968d6SJeremy L Thompson // Coords 1163a2968d6SJeremy L Thompson CeedVector vec; 1173a2968d6SJeremy L Thompson 1183a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 1193a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords)); 1203a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1213a2968d6SJeremy L Thompson 1223a2968d6SJeremy L Thompson // Points per elem 1233a2968d6SJeremy L Thompson if (num_elem != data->points.num_elem) { 1243a2968d6SJeremy L Thompson CeedInt *points_per_elem; 1253a2968d6SJeremy L Thompson const CeedInt num_bytes = num_elem * sizeof(CeedInt); 1263a2968d6SJeremy L Thompson CeedElemRestriction rstr_points = NULL; 1273a2968d6SJeremy L Thompson 1283a2968d6SJeremy L Thompson data->points.num_elem = num_elem; 1293a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 1303a2968d6SJeremy L Thompson CeedCallBackend(CeedCalloc(num_elem, &points_per_elem)); 1313a2968d6SJeremy L Thompson for (CeedInt e = 0; e < num_elem; e++) { 1323a2968d6SJeremy L Thompson CeedInt num_points_elem; 1333a2968d6SJeremy L Thompson 1343a2968d6SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 1353a2968d6SJeremy L Thompson points_per_elem[e] = num_points_elem; 1363a2968d6SJeremy L Thompson } 1373a2968d6SJeremy L Thompson if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem)); 1383a2968d6SJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes)); 1393a2968d6SJeremy L Thompson CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice)); 1403a2968d6SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 1413a2968d6SJeremy L Thompson CeedCallBackend(CeedFree(&points_per_elem)); 1423a2968d6SJeremy L Thompson } 1433a2968d6SJeremy L Thompson } 1443a2968d6SJeremy L Thompson 1457d8d0e25Snbeams // Get context data 1462b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 1477d8d0e25Snbeams 1487d8d0e25Snbeams // Apply operator 1493a2968d6SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points}; 150b7453713SJeremy L Thompson 1519123fb08SJeremy L Thompson CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor)); 152a61b1c91SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, ((!is_tensor || data->dim == 1) ? 1 : data->thread_1d), -1}; 153f82027a4SJeremy L Thompson 154f82027a4SJeremy L Thompson if (is_tensor) { 15574398b5aSJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 156f82027a4SJeremy L Thompson } else { 157a61b1c91SJeremy L Thompson CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64; 158f82027a4SJeremy L Thompson 159f82027a4SJeremy L Thompson elems_per_block = elems_per_block > 0 ? elems_per_block : 1; 160f82027a4SJeremy L Thompson block_sizes[2] = elems_per_block; 161f82027a4SJeremy L Thompson } 16274398b5aSJeremy L Thompson if (data->dim == 1 || !is_tensor) { 1632b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 164a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 165b7453713SJeremy L Thompson 1661a8516d0SJames Wright CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, 1671a8516d0SJames Wright is_run_good, opargs)); 16874398b5aSJeremy L Thompson } else if (data->dim == 2) { 1692b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 170a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 171b7453713SJeremy L Thompson 1721a8516d0SJames Wright CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, 1731a8516d0SJames Wright is_run_good, opargs)); 17474398b5aSJeremy L Thompson } else if (data->dim == 3) { 1752b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 176a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 177b7453713SJeremy L Thompson 1781a8516d0SJames Wright CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, 1791a8516d0SJames Wright is_run_good, opargs)); 1807d8d0e25Snbeams } 1817d8d0e25Snbeams 1827d8d0e25Snbeams // Restore input arrays 1839e201c85SYohann for (CeedInt i = 0; i < num_input_fields; i++) { 1842b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 1859e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 1867d8d0e25Snbeams } else { 1873efc994bSJeremy L Thompson bool is_active; 188b7453713SJeremy L Thompson CeedVector vec; 189b7453713SJeremy L Thompson 1902b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1913efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 192ea04d07fSJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 193ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1947d8d0e25Snbeams } 1957d8d0e25Snbeams } 1967d8d0e25Snbeams 1977d8d0e25Snbeams // Restore output arrays 1989e201c85SYohann for (CeedInt i = 0; i < num_output_fields; i++) { 1992b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 2009e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 2017d8d0e25Snbeams } else { 2023efc994bSJeremy L Thompson bool is_active; 203b7453713SJeremy L Thompson CeedVector vec; 204b7453713SJeremy L Thompson 2052b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 2063efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 207ea04d07fSJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArray(vec, &data->fields.outputs[i])); 208ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 2097d8d0e25Snbeams } 2107d8d0e25Snbeams } 2117d8d0e25Snbeams 2123a2968d6SJeremy L Thompson // Restore point coordinates, if needed 2133a2968d6SJeremy L Thompson if (is_at_points) { 2143a2968d6SJeremy L Thompson CeedVector vec; 2153a2968d6SJeremy L Thompson 2163a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 2173a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords)); 2183a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 2193a2968d6SJeremy L Thompson } 2203a2968d6SJeremy L Thompson 2217d8d0e25Snbeams // Restore context data 2222b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 2238d12f40eSJeremy L Thompson 2248d12f40eSJeremy L Thompson // Cleanup 2259bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 226c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf)); 227ea04d07fSJeremy L Thompson if (!(*is_run_good)) data->use_fallback = true; 228ea04d07fSJeremy L Thompson return CEED_ERROR_SUCCESS; 229ea04d07fSJeremy L Thompson } 2308d12f40eSJeremy L Thompson 231ea04d07fSJeremy L Thompson static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) { 232ea04d07fSJeremy L Thompson bool is_run_good = false; 233ea04d07fSJeremy L Thompson const CeedScalar *input_arr = NULL; 234ea04d07fSJeremy L Thompson CeedScalar *output_arr = NULL; 235ea04d07fSJeremy L Thompson 236ea04d07fSJeremy L Thompson // Try to run kernel 237ea04d07fSJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr)); 238ea04d07fSJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr)); 239087855afSJeremy L Thompson CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(op, NULL, input_arr, output_arr, &is_run_good, request)); 240ea04d07fSJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr)); 241ea04d07fSJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr)); 242ea04d07fSJeremy L Thompson 243ea04d07fSJeremy L Thompson // Fallback on unsuccessful run 244ea04d07fSJeremy L Thompson if (!is_run_good) { 2458d12f40eSJeremy L Thompson CeedOperator op_fallback; 2468d12f40eSJeremy L Thompson 247ca38d01dSJeremy L Thompson CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for ApplyAdd\n"); 2488d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 2498d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request)); 2508d12f40eSJeremy L Thompson } 251e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 2527d8d0e25Snbeams } 2537d8d0e25Snbeams 254c99afcd8SJeremy L Thompson static int CeedOperatorApplyAddComposite_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) { 255af34f196SJeremy L Thompson bool is_run_good[CEED_COMPOSITE_MAX] = {false}; 256c99afcd8SJeremy L Thompson CeedInt num_suboperators; 257c99afcd8SJeremy L Thompson const CeedScalar *input_arr = NULL; 258af34f196SJeremy L Thompson CeedScalar *output_arr = NULL; 259087855afSJeremy L Thompson Ceed ceed; 2606eee1ffcSZach Atkins CeedOperator_Hip_gen *impl; 261c99afcd8SJeremy L Thompson CeedOperator *sub_operators; 262c99afcd8SJeremy L Thompson 263087855afSJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 2646eee1ffcSZach Atkins CeedCallBackend(CeedOperatorGetData(op, &impl)); 265ed094490SJeremy L Thompson CeedCallBackend(CeedOperatorCompositeGetNumSub(op, &num_suboperators)); 266ed094490SJeremy L Thompson CeedCallBackend(CeedOperatorCompositeGetSubList(op, &sub_operators)); 267c99afcd8SJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr)); 268c99afcd8SJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr)); 269c99afcd8SJeremy L Thompson for (CeedInt i = 0; i < num_suboperators; i++) { 270c99afcd8SJeremy L Thompson CeedInt num_elem = 0; 271c99afcd8SJeremy L Thompson 2726eee1ffcSZach Atkins CeedCallBackend(CeedOperatorGetNumElements(sub_operators[i], &num_elem)); 273c99afcd8SJeremy L Thompson if (num_elem > 0) { 2746eee1ffcSZach Atkins if (!impl->streams[i]) CeedCallHip(ceed, hipStreamCreate(&impl->streams[i])); 2756eee1ffcSZach Atkins CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], impl->streams[i], input_arr, output_arr, &is_run_good[i], request)); 2766eee1ffcSZach Atkins } else { 2776eee1ffcSZach Atkins is_run_good[i] = true; 2786eee1ffcSZach Atkins } 2796eee1ffcSZach Atkins } 280087855afSJeremy L Thompson 2816eee1ffcSZach Atkins for (CeedInt i = 0; i < num_suboperators; i++) { 2826eee1ffcSZach Atkins if (impl->streams[i]) { 2836eee1ffcSZach Atkins if (is_run_good[i]) CeedCallHip(ceed, hipStreamSynchronize(impl->streams[i])); 284c99afcd8SJeremy L Thompson } 285c99afcd8SJeremy L Thompson } 286c99afcd8SJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr)); 287c99afcd8SJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr)); 288087855afSJeremy L Thompson CeedCallHip(ceed, hipDeviceSynchronize()); 289c99afcd8SJeremy L Thompson 290c99afcd8SJeremy L Thompson // Fallback on unsuccessful run 291c99afcd8SJeremy L Thompson for (CeedInt i = 0; i < num_suboperators; i++) { 292c99afcd8SJeremy L Thompson if (!is_run_good[i]) { 293c99afcd8SJeremy L Thompson CeedOperator op_fallback; 294c99afcd8SJeremy L Thompson 295ca38d01dSJeremy L Thompson CeedDebug(ceed, "\nFalling back to /gpu/hip/ref CeedOperator for ApplyAdd\n"); 296c99afcd8SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(sub_operators[i], &op_fallback)); 297c99afcd8SJeremy L Thompson CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request)); 298c99afcd8SJeremy L Thompson } 299c99afcd8SJeremy L Thompson } 300087855afSJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 301c99afcd8SJeremy L Thompson return CEED_ERROR_SUCCESS; 302c99afcd8SJeremy L Thompson } 303c99afcd8SJeremy L Thompson 3047d8d0e25Snbeams //------------------------------------------------------------------------------ 3055daefc96SJeremy L Thompson // QFunction assembly 3065daefc96SJeremy L Thompson //------------------------------------------------------------------------------ 3075daefc96SJeremy L Thompson static int CeedOperatorLinearAssembleQFunctionCore_Hip_gen(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 3085daefc96SJeremy L Thompson CeedRequest *request) { 3095daefc96SJeremy L Thompson Ceed ceed; 3105daefc96SJeremy L Thompson CeedOperator_Hip_gen *data; 3115daefc96SJeremy L Thompson 3125daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 3135daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data)); 3145daefc96SJeremy L Thompson 3155daefc96SJeremy L Thompson // Build the assembly kernel 3165daefc96SJeremy L Thompson if (!data->assemble_qfunction && !data->use_assembly_fallback) { 3175daefc96SJeremy L Thompson bool is_build_good = false; 3185daefc96SJeremy L Thompson 3195daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good)); 3205daefc96SJeremy L Thompson if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelLinearAssembleQFunction_Hip_gen(op, &is_build_good)); 3215daefc96SJeremy L Thompson if (!is_build_good) data->use_assembly_fallback = true; 3225daefc96SJeremy L Thompson } 3235daefc96SJeremy L Thompson 3245daefc96SJeremy L Thompson // Try assembly 3255daefc96SJeremy L Thompson if (!data->use_assembly_fallback) { 3265daefc96SJeremy L Thompson bool is_run_good = true; 3275daefc96SJeremy L Thompson Ceed_Hip *hip_data; 3285daefc96SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields; 3295daefc96SJeremy L Thompson CeedEvalMode eval_mode; 3305daefc96SJeremy L Thompson CeedScalar *assembled_array; 3315daefc96SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields; 3325daefc96SJeremy L Thompson CeedQFunction_Hip_gen *qf_data; 3335daefc96SJeremy L Thompson CeedQFunction qf; 3345daefc96SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields; 3355daefc96SJeremy L Thompson 3365daefc96SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &hip_data)); 3375daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 3385daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 3395daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 3405daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 3415daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 3425daefc96SJeremy L Thompson 3435daefc96SJeremy L Thompson // Input vectors 3445daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 3455daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 3465daefc96SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 3475daefc96SJeremy L Thompson data->fields.inputs[i] = NULL; 3485daefc96SJeremy L Thompson } else { 3495daefc96SJeremy L Thompson bool is_active; 3505daefc96SJeremy L Thompson CeedVector vec; 3515daefc96SJeremy L Thompson 3525daefc96SJeremy L Thompson // Get input vector 3535daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 3545daefc96SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 3555daefc96SJeremy L Thompson if (is_active) data->fields.inputs[i] = NULL; 3565daefc96SJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 3575daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 3585daefc96SJeremy L Thompson } 3595daefc96SJeremy L Thompson } 3605daefc96SJeremy L Thompson 3615daefc96SJeremy L Thompson // Get context data 3625daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 3635daefc96SJeremy L Thompson 3645daefc96SJeremy L Thompson // Build objects if needed 3655daefc96SJeremy L Thompson if (build_objects) { 3665daefc96SJeremy L Thompson CeedInt qf_size_in = 0, qf_size_out = 0, Q; 3675daefc96SJeremy L Thompson 3685daefc96SJeremy L Thompson // Count number of active input fields 3695daefc96SJeremy L Thompson { 3705daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 3715daefc96SJeremy L Thompson CeedInt field_size; 3725daefc96SJeremy L Thompson CeedVector vec; 3735daefc96SJeremy L Thompson 3745daefc96SJeremy L Thompson // Get input vector 3755daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 3765daefc96SJeremy L Thompson // Check if active input 3775daefc96SJeremy L Thompson if (vec == CEED_VECTOR_ACTIVE) { 3785daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size)); 3795daefc96SJeremy L Thompson qf_size_in += field_size; 3805daefc96SJeremy L Thompson } 3815daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 3825daefc96SJeremy L Thompson } 3835daefc96SJeremy L Thompson CeedCheck(qf_size_in > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 3845daefc96SJeremy L Thompson } 3855daefc96SJeremy L Thompson 3865daefc96SJeremy L Thompson // Count number of active output fields 3875daefc96SJeremy L Thompson { 3885daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_output_fields; i++) { 3895daefc96SJeremy L Thompson CeedInt field_size; 3905daefc96SJeremy L Thompson CeedVector vec; 3915daefc96SJeremy L Thompson 3925daefc96SJeremy L Thompson // Get output vector 3935daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 3945daefc96SJeremy L Thompson // Check if active output 3955daefc96SJeremy L Thompson if (vec == CEED_VECTOR_ACTIVE) { 3965daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size)); 3975daefc96SJeremy L Thompson qf_size_out += field_size; 3985daefc96SJeremy L Thompson } 3995daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 4005daefc96SJeremy L Thompson } 4015daefc96SJeremy L Thompson CeedCheck(qf_size_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 4025daefc96SJeremy L Thompson } 4035daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 4045daefc96SJeremy L Thompson 4055daefc96SJeremy L Thompson // Actually build objects now 4065daefc96SJeremy L Thompson const CeedSize l_size = (CeedSize)num_elem * Q * qf_size_in * qf_size_out; 4075daefc96SJeremy L Thompson CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */ 4085daefc96SJeremy L Thompson 4095daefc96SJeremy L Thompson // Create output restriction 4105daefc96SJeremy L Thompson CeedCallBackend(CeedElemRestrictionCreateStrided(ceed, num_elem, Q, qf_size_in * qf_size_out, 4115daefc96SJeremy L Thompson (CeedSize)qf_size_in * (CeedSize)qf_size_out * (CeedSize)num_elem * (CeedSize)Q, strides, 4125daefc96SJeremy L Thompson rstr)); 4135daefc96SJeremy L Thompson // Create assembled vector 4145daefc96SJeremy L Thompson CeedCallBackend(CeedVectorCreate(ceed, l_size, assembled)); 4155daefc96SJeremy L Thompson } 4165daefc96SJeremy L Thompson 4175daefc96SJeremy L Thompson // Assembly array 4185daefc96SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayWrite(*assembled, CEED_MEM_DEVICE, &assembled_array)); 4195daefc96SJeremy L Thompson 4205daefc96SJeremy L Thompson // Assemble QFunction 4215daefc96SJeremy L Thompson bool is_tensor = false; 4225daefc96SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array}; 4235daefc96SJeremy L Thompson 4245daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor)); 4255daefc96SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, ((!is_tensor || data->dim == 1) ? 1 : data->thread_1d), -1}; 4265daefc96SJeremy L Thompson 4275daefc96SJeremy L Thompson if (is_tensor) { 4285daefc96SJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 4295daefc96SJeremy L Thompson } else { 4305daefc96SJeremy L Thompson CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64; 4315daefc96SJeremy L Thompson 4325daefc96SJeremy L Thompson elems_per_block = elems_per_block > 0 ? elems_per_block : 1; 4335daefc96SJeremy L Thompson block_sizes[2] = elems_per_block; 4345daefc96SJeremy L Thompson } 4355daefc96SJeremy L Thompson if (data->dim == 1 || !is_tensor) { 4365daefc96SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 4375daefc96SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 4385daefc96SJeremy L Thompson 4395daefc96SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 4405daefc96SJeremy L Thompson sharedMem, &is_run_good, opargs)); 4415daefc96SJeremy L Thompson } else if (data->dim == 2) { 4425daefc96SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 4435daefc96SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 4445daefc96SJeremy L Thompson 4455daefc96SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 4465daefc96SJeremy L Thompson sharedMem, &is_run_good, opargs)); 4475daefc96SJeremy L Thompson } else if (data->dim == 3) { 4485daefc96SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 4495daefc96SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 4505daefc96SJeremy L Thompson 4515daefc96SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 4525daefc96SJeremy L Thompson sharedMem, &is_run_good, opargs)); 4535daefc96SJeremy L Thompson } 4545daefc96SJeremy L Thompson 4555daefc96SJeremy L Thompson // Restore input arrays 4565daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 4575daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 4585daefc96SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 4595daefc96SJeremy L Thompson } else { 4605daefc96SJeremy L Thompson bool is_active; 4615daefc96SJeremy L Thompson CeedVector vec; 4625daefc96SJeremy L Thompson 4635daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 4645daefc96SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 4655daefc96SJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 4665daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 4675daefc96SJeremy L Thompson } 4685daefc96SJeremy L Thompson } 4695daefc96SJeremy L Thompson 4705daefc96SJeremy L Thompson // Restore context data 4715daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 4725daefc96SJeremy L Thompson 4735daefc96SJeremy L Thompson // Restore assembly array 4745daefc96SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 4755daefc96SJeremy L Thompson 4765daefc96SJeremy L Thompson // Cleanup 4775daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf)); 4785daefc96SJeremy L Thompson if (!is_run_good) { 4795daefc96SJeremy L Thompson data->use_assembly_fallback = true; 4805daefc96SJeremy L Thompson if (build_objects) { 4815daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(assembled)); 4825daefc96SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(rstr)); 4835daefc96SJeremy L Thompson } 4845daefc96SJeremy L Thompson } 4855daefc96SJeremy L Thompson } 4865daefc96SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 4875daefc96SJeremy L Thompson 4885daefc96SJeremy L Thompson // Fallback, if needed 4895daefc96SJeremy L Thompson if (data->use_assembly_fallback) { 4905daefc96SJeremy L Thompson CeedOperator op_fallback; 4915daefc96SJeremy L Thompson 492ca38d01dSJeremy L Thompson CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for LineearAssembleQFunction\n"); 4935daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 494ed094490SJeremy L Thompson CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdateFallback(op_fallback, assembled, rstr, request)); 4955daefc96SJeremy L Thompson return CEED_ERROR_SUCCESS; 4965daefc96SJeremy L Thompson } 4975daefc96SJeremy L Thompson return CEED_ERROR_SUCCESS; 4985daefc96SJeremy L Thompson } 4995daefc96SJeremy L Thompson 5005daefc96SJeremy L Thompson static int CeedOperatorLinearAssembleQFunction_Hip_gen(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 5015daefc96SJeremy L Thompson return CeedOperatorLinearAssembleQFunctionCore_Hip_gen(op, true, assembled, rstr, request); 5025daefc96SJeremy L Thompson } 5035daefc96SJeremy L Thompson 5045daefc96SJeremy L Thompson static int CeedOperatorLinearAssembleQFunctionUpdate_Hip_gen(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 5055daefc96SJeremy L Thompson return CeedOperatorLinearAssembleQFunctionCore_Hip_gen(op, false, &assembled, &rstr, request); 5065daefc96SJeremy L Thompson } 5075daefc96SJeremy L Thompson 5085daefc96SJeremy L Thompson //------------------------------------------------------------------------------ 5090183ed61SJeremy L Thompson // AtPoints diagonal assembly 5100183ed61SJeremy L Thompson //------------------------------------------------------------------------------ 5110183ed61SJeremy L Thompson static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen(CeedOperator op, CeedVector assembled, CeedRequest *request) { 5120183ed61SJeremy L Thompson Ceed ceed; 5130183ed61SJeremy L Thompson CeedOperator_Hip_gen *data; 5140183ed61SJeremy L Thompson 5150183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 5160183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data)); 5170183ed61SJeremy L Thompson 5180183ed61SJeremy L Thompson // Build the assembly kernel 5190183ed61SJeremy L Thompson if (!data->assemble_diagonal && !data->use_assembly_fallback) { 5200183ed61SJeremy L Thompson bool is_build_good = false; 5210183ed61SJeremy L Thompson CeedInt num_active_bases_in, num_active_bases_out; 5220183ed61SJeremy L Thompson CeedOperatorAssemblyData assembly_data; 5230183ed61SJeremy L Thompson 5240183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data)); 5251a8516d0SJames Wright CeedCallBackend(CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL, 5261a8516d0SJames Wright NULL, NULL)); 5270183ed61SJeremy L Thompson if (num_active_bases_in == num_active_bases_out) { 5280183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good)); 5290183ed61SJeremy L Thompson if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Hip_gen(op, &is_build_good)); 5300183ed61SJeremy L Thompson } 5310183ed61SJeremy L Thompson if (!is_build_good) data->use_assembly_fallback = true; 5320183ed61SJeremy L Thompson } 5330183ed61SJeremy L Thompson 5340183ed61SJeremy L Thompson // Try assembly 5350183ed61SJeremy L Thompson if (!data->use_assembly_fallback) { 5360183ed61SJeremy L Thompson bool is_run_good = true; 5370183ed61SJeremy L Thompson Ceed_Hip *hip_data; 5380183ed61SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields; 5390183ed61SJeremy L Thompson CeedEvalMode eval_mode; 5400183ed61SJeremy L Thompson CeedScalar *assembled_array; 5410183ed61SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields; 5420183ed61SJeremy L Thompson CeedQFunction_Hip_gen *qf_data; 5430183ed61SJeremy L Thompson CeedQFunction qf; 5440183ed61SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields; 5450183ed61SJeremy L Thompson 5460183ed61SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &hip_data)); 5470183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 5480183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 5490183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 5500183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 5510183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 5520183ed61SJeremy L Thompson 5530183ed61SJeremy L Thompson // Input vectors 5540183ed61SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 5550183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 5560183ed61SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 5570183ed61SJeremy L Thompson data->fields.inputs[i] = NULL; 5580183ed61SJeremy L Thompson } else { 5590183ed61SJeremy L Thompson bool is_active; 5600183ed61SJeremy L Thompson CeedVector vec; 5610183ed61SJeremy L Thompson 5620183ed61SJeremy L Thompson // Get input vector 5630183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 5640183ed61SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 5650183ed61SJeremy L Thompson if (is_active) data->fields.inputs[i] = NULL; 5660183ed61SJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 5670183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 5680183ed61SJeremy L Thompson } 5690183ed61SJeremy L Thompson } 5700183ed61SJeremy L Thompson 5710183ed61SJeremy L Thompson // Point coordinates 5720183ed61SJeremy L Thompson { 5730183ed61SJeremy L Thompson CeedVector vec; 5740183ed61SJeremy L Thompson 5750183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 5760183ed61SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords)); 5770183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 5780183ed61SJeremy L Thompson 5790183ed61SJeremy L Thompson // Points per elem 5800183ed61SJeremy L Thompson if (num_elem != data->points.num_elem) { 5810183ed61SJeremy L Thompson CeedInt *points_per_elem; 5820183ed61SJeremy L Thompson const CeedInt num_bytes = num_elem * sizeof(CeedInt); 5830183ed61SJeremy L Thompson CeedElemRestriction rstr_points = NULL; 5840183ed61SJeremy L Thompson 5850183ed61SJeremy L Thompson data->points.num_elem = num_elem; 5860183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 5870183ed61SJeremy L Thompson CeedCallBackend(CeedCalloc(num_elem, &points_per_elem)); 5880183ed61SJeremy L Thompson for (CeedInt e = 0; e < num_elem; e++) { 5890183ed61SJeremy L Thompson CeedInt num_points_elem; 5900183ed61SJeremy L Thompson 5910183ed61SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 5920183ed61SJeremy L Thompson points_per_elem[e] = num_points_elem; 5930183ed61SJeremy L Thompson } 5940183ed61SJeremy L Thompson if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem)); 5950183ed61SJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes)); 5960183ed61SJeremy L Thompson CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice)); 5970183ed61SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 5980183ed61SJeremy L Thompson CeedCallBackend(CeedFree(&points_per_elem)); 5990183ed61SJeremy L Thompson } 6000183ed61SJeremy L Thompson } 6010183ed61SJeremy L Thompson 6020183ed61SJeremy L Thompson // Get context data 6030183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 6040183ed61SJeremy L Thompson 6050183ed61SJeremy L Thompson // Assembly array 6060183ed61SJeremy L Thompson CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array)); 6070183ed61SJeremy L Thompson 6080183ed61SJeremy L Thompson // Assemble diagonal 6090183ed61SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array}; 6100183ed61SJeremy L Thompson 6110183ed61SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1}; 6120183ed61SJeremy L Thompson 6130183ed61SJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 6140183ed61SJeremy L Thompson block_sizes[2] = 1; 6150183ed61SJeremy L Thompson if (data->dim == 1) { 6160183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 6170183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 6180183ed61SJeremy L Thompson 6190183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 6200183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 6210183ed61SJeremy L Thompson } else if (data->dim == 2) { 6220183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 6230183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 6240183ed61SJeremy L Thompson 6250183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 6260183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 6270183ed61SJeremy L Thompson } else if (data->dim == 3) { 6280183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 6290183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 6300183ed61SJeremy L Thompson 6310183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 6320183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 6330183ed61SJeremy L Thompson } 634692716b7SZach Atkins CeedCallHip(ceed, hipDeviceSynchronize()); 6350183ed61SJeremy L Thompson 6360183ed61SJeremy L Thompson // Restore input arrays 6370183ed61SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 6380183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 6390183ed61SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 6400183ed61SJeremy L Thompson } else { 6410183ed61SJeremy L Thompson bool is_active; 6420183ed61SJeremy L Thompson CeedVector vec; 6430183ed61SJeremy L Thompson 6440183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 6450183ed61SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 6460183ed61SJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 6470183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 6480183ed61SJeremy L Thompson } 6490183ed61SJeremy L Thompson } 6500183ed61SJeremy L Thompson 6510183ed61SJeremy L Thompson // Restore point coordinates 6520183ed61SJeremy L Thompson { 6530183ed61SJeremy L Thompson CeedVector vec; 6540183ed61SJeremy L Thompson 6550183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 6560183ed61SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords)); 6570183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 6580183ed61SJeremy L Thompson } 6590183ed61SJeremy L Thompson 6600183ed61SJeremy L Thompson // Restore context data 6610183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 6620183ed61SJeremy L Thompson 6630183ed61SJeremy L Thompson // Restore assembly array 6640183ed61SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array)); 6650183ed61SJeremy L Thompson 6660183ed61SJeremy L Thompson // Cleanup 6670183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf)); 6680183ed61SJeremy L Thompson if (!is_run_good) data->use_assembly_fallback = true; 6690183ed61SJeremy L Thompson } 6700183ed61SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 6710183ed61SJeremy L Thompson 6720183ed61SJeremy L Thompson // Fallback, if needed 6730183ed61SJeremy L Thompson if (data->use_assembly_fallback) { 6740183ed61SJeremy L Thompson CeedOperator op_fallback; 6750183ed61SJeremy L Thompson 676ca38d01dSJeremy L Thompson CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for AtPoints LinearAssembleAddDiagonal\n"); 6770183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 6780183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorLinearAssembleAddDiagonal(op_fallback, assembled, request)); 6790183ed61SJeremy L Thompson return CEED_ERROR_SUCCESS; 6800183ed61SJeremy L Thompson } 6810183ed61SJeremy L Thompson return CEED_ERROR_SUCCESS; 6820183ed61SJeremy L Thompson } 6830183ed61SJeremy L Thompson 6840183ed61SJeremy L Thompson //------------------------------------------------------------------------------ 685692716b7SZach Atkins // AtPoints full assembly 686692716b7SZach Atkins //------------------------------------------------------------------------------ 687ed094490SJeremy L Thompson static int CeedOperatorAssembleSingleAtPoints_Hip_gen(CeedOperator op, CeedInt offset, CeedVector assembled) { 688692716b7SZach Atkins Ceed ceed; 689692716b7SZach Atkins CeedOperator_Hip_gen *data; 690692716b7SZach Atkins 691692716b7SZach Atkins CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 692692716b7SZach Atkins CeedCallBackend(CeedOperatorGetData(op, &data)); 693692716b7SZach Atkins 694692716b7SZach Atkins // Build the assembly kernel 695692716b7SZach Atkins if (!data->assemble_full && !data->use_assembly_fallback) { 696692716b7SZach Atkins bool is_build_good = false; 697692716b7SZach Atkins CeedInt num_active_bases_in, num_active_bases_out; 698692716b7SZach Atkins CeedOperatorAssemblyData assembly_data; 699692716b7SZach Atkins 700692716b7SZach Atkins CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data)); 7011a8516d0SJames Wright CeedCallBackend(CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL, 7021a8516d0SJames Wright NULL, NULL)); 703692716b7SZach Atkins if (num_active_bases_in == num_active_bases_out) { 704692716b7SZach Atkins CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good)); 705692716b7SZach Atkins if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelFullAssemblyAtPoints_Hip_gen(op, &is_build_good)); 706692716b7SZach Atkins } 707692716b7SZach Atkins if (!is_build_good) { 708692716b7SZach Atkins CeedDebug(ceed, "Single Operator Assemble at Points compile failed, using fallback\n"); 709692716b7SZach Atkins data->use_assembly_fallback = true; 710692716b7SZach Atkins } 711692716b7SZach Atkins } 712692716b7SZach Atkins 713692716b7SZach Atkins // Try assembly 714692716b7SZach Atkins if (!data->use_assembly_fallback) { 715692716b7SZach Atkins bool is_run_good = true; 716692716b7SZach Atkins Ceed_Hip *Hip_data; 717692716b7SZach Atkins CeedInt num_elem, num_input_fields, num_output_fields; 718692716b7SZach Atkins CeedEvalMode eval_mode; 719692716b7SZach Atkins CeedScalar *assembled_array; 720692716b7SZach Atkins CeedQFunctionField *qf_input_fields, *qf_output_fields; 721692716b7SZach Atkins CeedQFunction_Hip_gen *qf_data; 722692716b7SZach Atkins CeedQFunction qf; 723692716b7SZach Atkins CeedOperatorField *op_input_fields, *op_output_fields; 724692716b7SZach Atkins 725692716b7SZach Atkins CeedCallBackend(CeedGetData(ceed, &Hip_data)); 726692716b7SZach Atkins CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 727692716b7SZach Atkins CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 728692716b7SZach Atkins CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 729692716b7SZach Atkins CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 730692716b7SZach Atkins CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 731692716b7SZach Atkins CeedDebug(ceed, "Running single operator assemble for /gpu/hip/gen\n"); 732692716b7SZach Atkins 733692716b7SZach Atkins // Input vectors 734692716b7SZach Atkins for (CeedInt i = 0; i < num_input_fields; i++) { 735692716b7SZach Atkins CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 736692716b7SZach Atkins if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 737692716b7SZach Atkins data->fields.inputs[i] = NULL; 738692716b7SZach Atkins } else { 739692716b7SZach Atkins bool is_active; 740692716b7SZach Atkins CeedVector vec; 741692716b7SZach Atkins 742692716b7SZach Atkins // Get input vector 743692716b7SZach Atkins CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 744692716b7SZach Atkins is_active = vec == CEED_VECTOR_ACTIVE; 745692716b7SZach Atkins if (is_active) data->fields.inputs[i] = NULL; 746692716b7SZach Atkins else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 747692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec)); 748692716b7SZach Atkins } 749692716b7SZach Atkins } 750692716b7SZach Atkins 751692716b7SZach Atkins // Point coordinates 752692716b7SZach Atkins { 753692716b7SZach Atkins CeedVector vec; 754692716b7SZach Atkins 755692716b7SZach Atkins CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 756692716b7SZach Atkins CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords)); 757692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec)); 758692716b7SZach Atkins 759692716b7SZach Atkins // Points per elem 760692716b7SZach Atkins if (num_elem != data->points.num_elem) { 761692716b7SZach Atkins CeedInt *points_per_elem; 762692716b7SZach Atkins const CeedInt num_bytes = num_elem * sizeof(CeedInt); 763692716b7SZach Atkins CeedElemRestriction rstr_points = NULL; 764692716b7SZach Atkins 765692716b7SZach Atkins data->points.num_elem = num_elem; 766692716b7SZach Atkins CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 767692716b7SZach Atkins CeedCallBackend(CeedCalloc(num_elem, &points_per_elem)); 768692716b7SZach Atkins for (CeedInt e = 0; e < num_elem; e++) { 769692716b7SZach Atkins CeedInt num_points_elem; 770692716b7SZach Atkins 771692716b7SZach Atkins CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 772692716b7SZach Atkins points_per_elem[e] = num_points_elem; 773692716b7SZach Atkins } 774692716b7SZach Atkins if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem)); 775692716b7SZach Atkins CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes)); 776692716b7SZach Atkins CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice)); 777692716b7SZach Atkins CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 778692716b7SZach Atkins CeedCallBackend(CeedFree(&points_per_elem)); 779692716b7SZach Atkins } 780692716b7SZach Atkins } 781692716b7SZach Atkins 782692716b7SZach Atkins // Get context data 783692716b7SZach Atkins CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 784692716b7SZach Atkins 785692716b7SZach Atkins // Assembly array 786692716b7SZach Atkins CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array)); 787692716b7SZach Atkins CeedScalar *assembled_offset_array = &assembled_array[offset]; 788692716b7SZach Atkins 789692716b7SZach Atkins // Assemble diagonal 790692716b7SZach Atkins void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, 791692716b7SZach Atkins &data->G, &data->W, &data->points, &assembled_offset_array}; 792692716b7SZach Atkins 793692716b7SZach Atkins CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1}; 794692716b7SZach Atkins 795692716b7SZach Atkins CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 796692716b7SZach Atkins block_sizes[2] = 1; 797692716b7SZach Atkins if (data->dim == 1) { 798692716b7SZach Atkins CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 799692716b7SZach Atkins CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 800692716b7SZach Atkins 801692716b7SZach Atkins CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_full, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, 802692716b7SZach Atkins &is_run_good, opargs)); 803692716b7SZach Atkins } else if (data->dim == 2) { 804692716b7SZach Atkins CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 805692716b7SZach Atkins CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 806692716b7SZach Atkins 807692716b7SZach Atkins CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_full, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, 808692716b7SZach Atkins &is_run_good, opargs)); 809692716b7SZach Atkins } else if (data->dim == 3) { 810692716b7SZach Atkins CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 811692716b7SZach Atkins CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 812692716b7SZach Atkins 813692716b7SZach Atkins CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_full, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, 814692716b7SZach Atkins &is_run_good, opargs)); 815692716b7SZach Atkins } 816692716b7SZach Atkins CeedCallHip(ceed, hipDeviceSynchronize()); 817692716b7SZach Atkins 818692716b7SZach Atkins // Restore input arrays 819692716b7SZach Atkins for (CeedInt i = 0; i < num_input_fields; i++) { 820692716b7SZach Atkins CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 821692716b7SZach Atkins if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 822692716b7SZach Atkins } else { 823692716b7SZach Atkins bool is_active; 824692716b7SZach Atkins CeedVector vec; 825692716b7SZach Atkins 826692716b7SZach Atkins CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 827692716b7SZach Atkins is_active = vec == CEED_VECTOR_ACTIVE; 828692716b7SZach Atkins if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 829692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec)); 830692716b7SZach Atkins } 831692716b7SZach Atkins } 832692716b7SZach Atkins 833692716b7SZach Atkins // Restore point coordinates 834692716b7SZach Atkins { 835692716b7SZach Atkins CeedVector vec; 836692716b7SZach Atkins 837692716b7SZach Atkins CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 838692716b7SZach Atkins CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords)); 839692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec)); 840692716b7SZach Atkins } 841692716b7SZach Atkins 842692716b7SZach Atkins // Restore context data 843692716b7SZach Atkins CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 844692716b7SZach Atkins 845692716b7SZach Atkins // Restore assembly array 846692716b7SZach Atkins CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array)); 847692716b7SZach Atkins 848692716b7SZach Atkins // Cleanup 849692716b7SZach Atkins CeedCallBackend(CeedQFunctionDestroy(&qf)); 850692716b7SZach Atkins if (!is_run_good) { 851692716b7SZach Atkins CeedDebug(ceed, "Single Operator Assemble at Points run failed, using fallback\n"); 852692716b7SZach Atkins data->use_assembly_fallback = true; 853692716b7SZach Atkins } 854692716b7SZach Atkins } 855692716b7SZach Atkins CeedCallBackend(CeedDestroy(&ceed)); 856692716b7SZach Atkins 857692716b7SZach Atkins // Fallback, if needed 858692716b7SZach Atkins if (data->use_assembly_fallback) { 859692716b7SZach Atkins CeedOperator op_fallback; 860692716b7SZach Atkins 861ca38d01dSJeremy L Thompson CeedDebug(CeedOperatorReturnCeed(op), "\nFalling back to /gpu/hip/ref CeedOperator for AtPoints SingleOperatorAssemble\n"); 862692716b7SZach Atkins CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 863ed094490SJeremy L Thompson CeedCallBackend(CeedOperatorAssembleSingle(op_fallback, offset, assembled)); 864692716b7SZach Atkins return CEED_ERROR_SUCCESS; 865692716b7SZach Atkins } 866692716b7SZach Atkins return CEED_ERROR_SUCCESS; 867692716b7SZach Atkins } 868692716b7SZach Atkins 869692716b7SZach Atkins //------------------------------------------------------------------------------ 8707d8d0e25Snbeams // Create operator 8717d8d0e25Snbeams //------------------------------------------------------------------------------ 8727d8d0e25Snbeams int CeedOperatorCreate_Hip_gen(CeedOperator op) { 8730183ed61SJeremy L Thompson bool is_composite, is_at_points; 8747d8d0e25Snbeams Ceed ceed; 8757d8d0e25Snbeams CeedOperator_Hip_gen *impl; 8767d8d0e25Snbeams 877b7453713SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 8782b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl)); 8792b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorSetData(op, impl)); 880c99afcd8SJeremy L Thompson CeedCall(CeedOperatorIsComposite(op, &is_composite)); 881c99afcd8SJeremy L Thompson if (is_composite) { 882c99afcd8SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAddComposite", CeedOperatorApplyAddComposite_Hip_gen)); 883c99afcd8SJeremy L Thompson } else { 8842b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen)); 885c99afcd8SJeremy L Thompson } 8860183ed61SJeremy L Thompson CeedCall(CeedOperatorIsAtPoints(op, &is_at_points)); 8870183ed61SJeremy L Thompson if (is_at_points) { 8880183ed61SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen)); 889ed094490SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedOperatorAssembleSingleAtPoints_Hip_gen)); 8900183ed61SJeremy L Thompson } 8915daefc96SJeremy L Thompson if (!is_at_points) { 8925daefc96SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Hip_gen)); 8935daefc96SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip_gen)); 8945daefc96SJeremy L Thompson } 8952b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen)); 8969bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 897e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 8987d8d0e25Snbeams } 8992a86cc9dSSebastian Grimberg 9007d8d0e25Snbeams //------------------------------------------------------------------------------ 901