1d275d636SJeremy L Thompson // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors. 23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 37d8d0e25Snbeams // 43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause 57d8d0e25Snbeams // 63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed 77d8d0e25Snbeams 849aac155SJeremy L Thompson #include <ceed.h> 9ec3da8bcSJed Brown #include <ceed/backend.h> 1049aac155SJeremy L Thompson #include <ceed/jit-source/hip/hip-types.h> 113d576824SJeremy L Thompson #include <stddef.h> 123a2968d6SJeremy L Thompson #include <hip/hiprtc.h> 132b730f8bSJeremy L Thompson 14b2165e7aSSebastian Grimberg #include "../hip/ceed-hip-common.h" 157d8d0e25Snbeams #include "../hip/ceed-hip-compile.h" 162b730f8bSJeremy L Thompson #include "ceed-hip-gen-operator-build.h" 172b730f8bSJeremy L Thompson #include "ceed-hip-gen.h" 187d8d0e25Snbeams 197d8d0e25Snbeams //------------------------------------------------------------------------------ 207d8d0e25Snbeams // Destroy operator 217d8d0e25Snbeams //------------------------------------------------------------------------------ 227d8d0e25Snbeams static int CeedOperatorDestroy_Hip_gen(CeedOperator op) { 233a2968d6SJeremy L Thompson Ceed ceed; 247d8d0e25Snbeams CeedOperator_Hip_gen *impl; 256eee1ffcSZach Atkins bool is_composite; 26b7453713SJeremy L Thompson 273a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 282b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &impl)); 296eee1ffcSZach Atkins CeedCallBackend(CeedOperatorIsComposite(op, &is_composite)); 306eee1ffcSZach Atkins if (is_composite) { 316eee1ffcSZach Atkins CeedInt num_suboperators; 326eee1ffcSZach Atkins 336eee1ffcSZach Atkins CeedCall(CeedCompositeOperatorGetNumSub(op, &num_suboperators)); 346eee1ffcSZach Atkins for (CeedInt i = 0; i < num_suboperators; i++) { 356eee1ffcSZach Atkins if (impl->streams[i]) CeedCallHip(ceed, hipStreamDestroy(impl->streams[i])); 366eee1ffcSZach Atkins impl->streams[i] = NULL; 376eee1ffcSZach Atkins } 386eee1ffcSZach Atkins } 398b7d3340SJeremy L Thompson if (impl->module) CeedCallHip(ceed, hipModuleUnload(impl->module)); 400183ed61SJeremy L Thompson if (impl->module_assemble_full) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_full)); 410183ed61SJeremy L Thompson if (impl->module_assemble_diagonal) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_diagonal)); 42*5daefc96SJeremy L Thompson if (impl->module_assemble_qfunction) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_qfunction)); 433a2968d6SJeremy L Thompson if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)impl->points.num_per_elem)); 442b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&impl)); 453a2968d6SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 46e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 477d8d0e25Snbeams } 487d8d0e25Snbeams 497d8d0e25Snbeams //------------------------------------------------------------------------------ 507d8d0e25Snbeams // Apply and add to output 517d8d0e25Snbeams //------------------------------------------------------------------------------ 52e9c76bddSJeremy L Thompson static int CeedOperatorApplyAddCore_Hip_gen(CeedOperator op, hipStream_t stream, const CeedScalar *input_arr, CeedScalar *output_arr, 53e9c76bddSJeremy L Thompson bool *is_run_good, CeedRequest *request) { 54ea04d07fSJeremy L Thompson bool is_at_points, is_tensor; 557d8d0e25Snbeams Ceed ceed; 56b7453713SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields; 57b7453713SJeremy L Thompson CeedEvalMode eval_mode; 58b7453713SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields; 597d8d0e25Snbeams CeedQFunction_Hip_gen *qf_data; 60b7453713SJeremy L Thompson CeedQFunction qf; 61b7453713SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields; 62b7453713SJeremy L Thompson CeedOperator_Hip_gen *data; 63b7453713SJeremy L Thompson 648d12f40eSJeremy L Thompson // Creation of the operator 65ea04d07fSJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, is_run_good)); 66ea04d07fSJeremy L Thompson if (!(*is_run_good)) return CEED_ERROR_SUCCESS; 67f6eafd79SJeremy L Thompson 68c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 69c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data)); 70c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 71c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 72c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 738d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 74c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 75c11e12f4SJeremy L Thompson 767d8d0e25Snbeams // Input vectors 779e201c85SYohann for (CeedInt i = 0; i < num_input_fields; i++) { 782b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 799e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 809e201c85SYohann data->fields.inputs[i] = NULL; 817d8d0e25Snbeams } else { 823efc994bSJeremy L Thompson bool is_active; 83b7453713SJeremy L Thompson CeedVector vec; 84b7453713SJeremy L Thompson 857d8d0e25Snbeams // Get input vector 862b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 873efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 88ea04d07fSJeremy L Thompson if (is_active) data->fields.inputs[i] = input_arr; 89ea04d07fSJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 90ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 917d8d0e25Snbeams } 927d8d0e25Snbeams } 937d8d0e25Snbeams 947d8d0e25Snbeams // Output vectors 959e201c85SYohann for (CeedInt i = 0; i < num_output_fields; i++) { 962b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 979e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 989e201c85SYohann data->fields.outputs[i] = NULL; 997d8d0e25Snbeams } else { 1003efc994bSJeremy L Thompson bool is_active; 101b7453713SJeremy L Thompson CeedVector vec; 102b7453713SJeremy L Thompson 1037d8d0e25Snbeams // Get output vector 1042b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 1053efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 106ea04d07fSJeremy L Thompson if (is_active) data->fields.outputs[i] = output_arr; 1070c8fbeedSJeremy L Thompson else CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i])); 108ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1097d8d0e25Snbeams } 1107d8d0e25Snbeams } 1117d8d0e25Snbeams 1123a2968d6SJeremy L Thompson // Point coordinates, if needed 1133a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points)); 1143a2968d6SJeremy L Thompson if (is_at_points) { 1153a2968d6SJeremy L Thompson // Coords 1163a2968d6SJeremy L Thompson CeedVector vec; 1173a2968d6SJeremy L Thompson 1183a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 1193a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords)); 1203a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1213a2968d6SJeremy L Thompson 1223a2968d6SJeremy L Thompson // Points per elem 1233a2968d6SJeremy L Thompson if (num_elem != data->points.num_elem) { 1243a2968d6SJeremy L Thompson CeedInt *points_per_elem; 1253a2968d6SJeremy L Thompson const CeedInt num_bytes = num_elem * sizeof(CeedInt); 1263a2968d6SJeremy L Thompson CeedElemRestriction rstr_points = NULL; 1273a2968d6SJeremy L Thompson 1283a2968d6SJeremy L Thompson data->points.num_elem = num_elem; 1293a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 1303a2968d6SJeremy L Thompson CeedCallBackend(CeedCalloc(num_elem, &points_per_elem)); 1313a2968d6SJeremy L Thompson for (CeedInt e = 0; e < num_elem; e++) { 1323a2968d6SJeremy L Thompson CeedInt num_points_elem; 1333a2968d6SJeremy L Thompson 1343a2968d6SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 1353a2968d6SJeremy L Thompson points_per_elem[e] = num_points_elem; 1363a2968d6SJeremy L Thompson } 1373a2968d6SJeremy L Thompson if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem)); 1383a2968d6SJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes)); 1393a2968d6SJeremy L Thompson CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice)); 1403a2968d6SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 1413a2968d6SJeremy L Thompson CeedCallBackend(CeedFree(&points_per_elem)); 1423a2968d6SJeremy L Thompson } 1433a2968d6SJeremy L Thompson } 1443a2968d6SJeremy L Thompson 1457d8d0e25Snbeams // Get context data 1462b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 1477d8d0e25Snbeams 1487d8d0e25Snbeams // Apply operator 1493a2968d6SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points}; 150b7453713SJeremy L Thompson 1519123fb08SJeremy L Thompson CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor)); 152a61b1c91SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, ((!is_tensor || data->dim == 1) ? 1 : data->thread_1d), -1}; 153f82027a4SJeremy L Thompson 154f82027a4SJeremy L Thompson if (is_tensor) { 15574398b5aSJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 15690c30374SJeremy L Thompson if (is_at_points) block_sizes[2] = 1; 157f82027a4SJeremy L Thompson } else { 158a61b1c91SJeremy L Thompson CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64; 159f82027a4SJeremy L Thompson 160f82027a4SJeremy L Thompson elems_per_block = elems_per_block > 0 ? elems_per_block : 1; 161f82027a4SJeremy L Thompson block_sizes[2] = elems_per_block; 162f82027a4SJeremy L Thompson } 16374398b5aSJeremy L Thompson if (data->dim == 1 || !is_tensor) { 1642b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 165a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 166b7453713SJeremy L Thompson 1678d12f40eSJeremy L Thompson CeedCallBackend( 168e9c76bddSJeremy L Thompson CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs)); 16974398b5aSJeremy L Thompson } else if (data->dim == 2) { 1702b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 171a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 172b7453713SJeremy L Thompson 1738d12f40eSJeremy L Thompson CeedCallBackend( 174e9c76bddSJeremy L Thompson CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs)); 17574398b5aSJeremy L Thompson } else if (data->dim == 3) { 1762b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 177a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 178b7453713SJeremy L Thompson 1798d12f40eSJeremy L Thompson CeedCallBackend( 180e9c76bddSJeremy L Thompson CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs)); 1817d8d0e25Snbeams } 1827d8d0e25Snbeams 1837d8d0e25Snbeams // Restore input arrays 1849e201c85SYohann for (CeedInt i = 0; i < num_input_fields; i++) { 1852b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 1869e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 1877d8d0e25Snbeams } else { 1883efc994bSJeremy L Thompson bool is_active; 189b7453713SJeremy L Thompson CeedVector vec; 190b7453713SJeremy L Thompson 1912b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1923efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 193ea04d07fSJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 194ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1957d8d0e25Snbeams } 1967d8d0e25Snbeams } 1977d8d0e25Snbeams 1987d8d0e25Snbeams // Restore output arrays 1999e201c85SYohann for (CeedInt i = 0; i < num_output_fields; i++) { 2002b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 2019e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 2027d8d0e25Snbeams } else { 2033efc994bSJeremy L Thompson bool is_active; 204b7453713SJeremy L Thompson CeedVector vec; 205b7453713SJeremy L Thompson 2062b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 2073efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 208ea04d07fSJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArray(vec, &data->fields.outputs[i])); 209ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 2107d8d0e25Snbeams } 2117d8d0e25Snbeams } 2127d8d0e25Snbeams 2133a2968d6SJeremy L Thompson // Restore point coordinates, if needed 2143a2968d6SJeremy L Thompson if (is_at_points) { 2153a2968d6SJeremy L Thompson CeedVector vec; 2163a2968d6SJeremy L Thompson 2173a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 2183a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords)); 2193a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 2203a2968d6SJeremy L Thompson } 2213a2968d6SJeremy L Thompson 2227d8d0e25Snbeams // Restore context data 2232b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 2248d12f40eSJeremy L Thompson 2258d12f40eSJeremy L Thompson // Cleanup 2269bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 227c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf)); 228ea04d07fSJeremy L Thompson if (!(*is_run_good)) data->use_fallback = true; 229ea04d07fSJeremy L Thompson return CEED_ERROR_SUCCESS; 230ea04d07fSJeremy L Thompson } 2318d12f40eSJeremy L Thompson 232ea04d07fSJeremy L Thompson static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) { 233ea04d07fSJeremy L Thompson bool is_run_good = false; 234ea04d07fSJeremy L Thompson const CeedScalar *input_arr = NULL; 235ea04d07fSJeremy L Thompson CeedScalar *output_arr = NULL; 236ea04d07fSJeremy L Thompson 237ea04d07fSJeremy L Thompson // Try to run kernel 238ea04d07fSJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr)); 239ea04d07fSJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr)); 240087855afSJeremy L Thompson CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(op, NULL, input_arr, output_arr, &is_run_good, request)); 241ea04d07fSJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr)); 242ea04d07fSJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr)); 243ea04d07fSJeremy L Thompson 244ea04d07fSJeremy L Thompson // Fallback on unsuccessful run 245ea04d07fSJeremy L Thompson if (!is_run_good) { 2468d12f40eSJeremy L Thompson CeedOperator op_fallback; 2478d12f40eSJeremy L Thompson 248ea04d07fSJeremy L Thompson CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator"); 2498d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 2508d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request)); 2518d12f40eSJeremy L Thompson } 252e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 2537d8d0e25Snbeams } 2547d8d0e25Snbeams 255c99afcd8SJeremy L Thompson static int CeedOperatorApplyAddComposite_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) { 2566eee1ffcSZach Atkins bool is_run_good[CEED_COMPOSITE_MAX] = {true}; 257c99afcd8SJeremy L Thompson CeedInt num_suboperators; 258c99afcd8SJeremy L Thompson const CeedScalar *input_arr = NULL; 2596eee1ffcSZach Atkins CeedScalar *output_arr; 260087855afSJeremy L Thompson Ceed ceed; 2616eee1ffcSZach Atkins CeedOperator_Hip_gen *impl; 262c99afcd8SJeremy L Thompson CeedOperator *sub_operators; 263c99afcd8SJeremy L Thompson 264087855afSJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 2656eee1ffcSZach Atkins CeedCallBackend(CeedOperatorGetData(op, &impl)); 2666eee1ffcSZach Atkins CeedCallBackend(CeedCompositeOperatorGetNumSub(op, &num_suboperators)); 2676eee1ffcSZach Atkins CeedCallBackend(CeedCompositeOperatorGetSubList(op, &sub_operators)); 268c99afcd8SJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr)); 269c99afcd8SJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr)); 270c99afcd8SJeremy L Thompson for (CeedInt i = 0; i < num_suboperators; i++) { 271c99afcd8SJeremy L Thompson CeedInt num_elem = 0; 272c99afcd8SJeremy L Thompson 2736eee1ffcSZach Atkins CeedCallBackend(CeedOperatorGetNumElements(sub_operators[i], &num_elem)); 274c99afcd8SJeremy L Thompson if (num_elem > 0) { 2756eee1ffcSZach Atkins if (!impl->streams[i]) CeedCallHip(ceed, hipStreamCreate(&impl->streams[i])); 2766eee1ffcSZach Atkins CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], impl->streams[i], input_arr, output_arr, &is_run_good[i], request)); 2776eee1ffcSZach Atkins } else { 2786eee1ffcSZach Atkins is_run_good[i] = true; 2796eee1ffcSZach Atkins } 2806eee1ffcSZach Atkins } 281087855afSJeremy L Thompson 2826eee1ffcSZach Atkins for (CeedInt i = 0; i < num_suboperators; i++) { 2836eee1ffcSZach Atkins if (impl->streams[i]) { 2846eee1ffcSZach Atkins if (is_run_good[i]) CeedCallHip(ceed, hipStreamSynchronize(impl->streams[i])); 285c99afcd8SJeremy L Thompson } 286c99afcd8SJeremy L Thompson } 287c99afcd8SJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr)); 288c99afcd8SJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr)); 289087855afSJeremy L Thompson CeedCallHip(ceed, hipDeviceSynchronize()); 290c99afcd8SJeremy L Thompson 291c99afcd8SJeremy L Thompson // Fallback on unsuccessful run 292c99afcd8SJeremy L Thompson for (CeedInt i = 0; i < num_suboperators; i++) { 293c99afcd8SJeremy L Thompson if (!is_run_good[i]) { 294c99afcd8SJeremy L Thompson CeedOperator op_fallback; 295c99afcd8SJeremy L Thompson 296087855afSJeremy L Thompson CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator"); 297c99afcd8SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(sub_operators[i], &op_fallback)); 298c99afcd8SJeremy L Thompson CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request)); 299c99afcd8SJeremy L Thompson } 300c99afcd8SJeremy L Thompson } 301087855afSJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 302c99afcd8SJeremy L Thompson return CEED_ERROR_SUCCESS; 303c99afcd8SJeremy L Thompson } 304c99afcd8SJeremy L Thompson 3057d8d0e25Snbeams //------------------------------------------------------------------------------ 306*5daefc96SJeremy L Thompson // QFunction assembly 307*5daefc96SJeremy L Thompson //------------------------------------------------------------------------------ 308*5daefc96SJeremy L Thompson static int CeedOperatorLinearAssembleQFunctionCore_Hip_gen(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr, 309*5daefc96SJeremy L Thompson CeedRequest *request) { 310*5daefc96SJeremy L Thompson Ceed ceed; 311*5daefc96SJeremy L Thompson CeedOperator_Hip_gen *data; 312*5daefc96SJeremy L Thompson 313*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 314*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data)); 315*5daefc96SJeremy L Thompson 316*5daefc96SJeremy L Thompson // Build the assembly kernel 317*5daefc96SJeremy L Thompson if (!data->assemble_qfunction && !data->use_assembly_fallback) { 318*5daefc96SJeremy L Thompson bool is_build_good = false; 319*5daefc96SJeremy L Thompson 320*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good)); 321*5daefc96SJeremy L Thompson if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelLinearAssembleQFunction_Hip_gen(op, &is_build_good)); 322*5daefc96SJeremy L Thompson if (!is_build_good) data->use_assembly_fallback = true; 323*5daefc96SJeremy L Thompson } 324*5daefc96SJeremy L Thompson 325*5daefc96SJeremy L Thompson // Try assembly 326*5daefc96SJeremy L Thompson if (!data->use_assembly_fallback) { 327*5daefc96SJeremy L Thompson bool is_run_good = true; 328*5daefc96SJeremy L Thompson Ceed_Hip *hip_data; 329*5daefc96SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields; 330*5daefc96SJeremy L Thompson CeedEvalMode eval_mode; 331*5daefc96SJeremy L Thompson CeedScalar *assembled_array; 332*5daefc96SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields; 333*5daefc96SJeremy L Thompson CeedQFunction_Hip_gen *qf_data; 334*5daefc96SJeremy L Thompson CeedQFunction qf; 335*5daefc96SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields; 336*5daefc96SJeremy L Thompson 337*5daefc96SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &hip_data)); 338*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 339*5daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 340*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 341*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 342*5daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 343*5daefc96SJeremy L Thompson 344*5daefc96SJeremy L Thompson // Input vectors 345*5daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 346*5daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 347*5daefc96SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 348*5daefc96SJeremy L Thompson data->fields.inputs[i] = NULL; 349*5daefc96SJeremy L Thompson } else { 350*5daefc96SJeremy L Thompson bool is_active; 351*5daefc96SJeremy L Thompson CeedVector vec; 352*5daefc96SJeremy L Thompson 353*5daefc96SJeremy L Thompson // Get input vector 354*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 355*5daefc96SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 356*5daefc96SJeremy L Thompson if (is_active) data->fields.inputs[i] = NULL; 357*5daefc96SJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 358*5daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 359*5daefc96SJeremy L Thompson } 360*5daefc96SJeremy L Thompson } 361*5daefc96SJeremy L Thompson 362*5daefc96SJeremy L Thompson // Get context data 363*5daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 364*5daefc96SJeremy L Thompson 365*5daefc96SJeremy L Thompson // Build objects if needed 366*5daefc96SJeremy L Thompson if (build_objects) { 367*5daefc96SJeremy L Thompson CeedInt qf_size_in = 0, qf_size_out = 0, Q; 368*5daefc96SJeremy L Thompson 369*5daefc96SJeremy L Thompson // Count number of active input fields 370*5daefc96SJeremy L Thompson { 371*5daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 372*5daefc96SJeremy L Thompson CeedInt field_size; 373*5daefc96SJeremy L Thompson CeedVector vec; 374*5daefc96SJeremy L Thompson 375*5daefc96SJeremy L Thompson // Get input vector 376*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 377*5daefc96SJeremy L Thompson // Check if active input 378*5daefc96SJeremy L Thompson if (vec == CEED_VECTOR_ACTIVE) { 379*5daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size)); 380*5daefc96SJeremy L Thompson qf_size_in += field_size; 381*5daefc96SJeremy L Thompson } 382*5daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 383*5daefc96SJeremy L Thompson } 384*5daefc96SJeremy L Thompson CeedCheck(qf_size_in > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 385*5daefc96SJeremy L Thompson } 386*5daefc96SJeremy L Thompson 387*5daefc96SJeremy L Thompson // Count number of active output fields 388*5daefc96SJeremy L Thompson { 389*5daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_output_fields; i++) { 390*5daefc96SJeremy L Thompson CeedInt field_size; 391*5daefc96SJeremy L Thompson CeedVector vec; 392*5daefc96SJeremy L Thompson 393*5daefc96SJeremy L Thompson // Get output vector 394*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 395*5daefc96SJeremy L Thompson // Check if active output 396*5daefc96SJeremy L Thompson if (vec == CEED_VECTOR_ACTIVE) { 397*5daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size)); 398*5daefc96SJeremy L Thompson qf_size_out += field_size; 399*5daefc96SJeremy L Thompson } 400*5daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 401*5daefc96SJeremy L Thompson } 402*5daefc96SJeremy L Thompson CeedCheck(qf_size_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs"); 403*5daefc96SJeremy L Thompson } 404*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); 405*5daefc96SJeremy L Thompson 406*5daefc96SJeremy L Thompson // Actually build objects now 407*5daefc96SJeremy L Thompson const CeedSize l_size = (CeedSize)num_elem * Q * qf_size_in * qf_size_out; 408*5daefc96SJeremy L Thompson CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */ 409*5daefc96SJeremy L Thompson 410*5daefc96SJeremy L Thompson // Create output restriction 411*5daefc96SJeremy L Thompson CeedCallBackend(CeedElemRestrictionCreateStrided(ceed, num_elem, Q, qf_size_in * qf_size_out, 412*5daefc96SJeremy L Thompson (CeedSize)qf_size_in * (CeedSize)qf_size_out * (CeedSize)num_elem * (CeedSize)Q, strides, 413*5daefc96SJeremy L Thompson rstr)); 414*5daefc96SJeremy L Thompson // Create assembled vector 415*5daefc96SJeremy L Thompson CeedCallBackend(CeedVectorCreate(ceed, l_size, assembled)); 416*5daefc96SJeremy L Thompson } 417*5daefc96SJeremy L Thompson 418*5daefc96SJeremy L Thompson // Assembly array 419*5daefc96SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayWrite(*assembled, CEED_MEM_DEVICE, &assembled_array)); 420*5daefc96SJeremy L Thompson 421*5daefc96SJeremy L Thompson // Assemble QFunction 422*5daefc96SJeremy L Thompson bool is_tensor = false; 423*5daefc96SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array}; 424*5daefc96SJeremy L Thompson 425*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor)); 426*5daefc96SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, ((!is_tensor || data->dim == 1) ? 1 : data->thread_1d), -1}; 427*5daefc96SJeremy L Thompson 428*5daefc96SJeremy L Thompson if (is_tensor) { 429*5daefc96SJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 430*5daefc96SJeremy L Thompson } else { 431*5daefc96SJeremy L Thompson CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64; 432*5daefc96SJeremy L Thompson 433*5daefc96SJeremy L Thompson elems_per_block = elems_per_block > 0 ? elems_per_block : 1; 434*5daefc96SJeremy L Thompson block_sizes[2] = elems_per_block; 435*5daefc96SJeremy L Thompson } 436*5daefc96SJeremy L Thompson if (data->dim == 1 || !is_tensor) { 437*5daefc96SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 438*5daefc96SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 439*5daefc96SJeremy L Thompson 440*5daefc96SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 441*5daefc96SJeremy L Thompson sharedMem, &is_run_good, opargs)); 442*5daefc96SJeremy L Thompson } else if (data->dim == 2) { 443*5daefc96SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 444*5daefc96SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 445*5daefc96SJeremy L Thompson 446*5daefc96SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 447*5daefc96SJeremy L Thompson sharedMem, &is_run_good, opargs)); 448*5daefc96SJeremy L Thompson } else if (data->dim == 3) { 449*5daefc96SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 450*5daefc96SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 451*5daefc96SJeremy L Thompson 452*5daefc96SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_qfunction, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 453*5daefc96SJeremy L Thompson sharedMem, &is_run_good, opargs)); 454*5daefc96SJeremy L Thompson } 455*5daefc96SJeremy L Thompson 456*5daefc96SJeremy L Thompson // Restore input arrays 457*5daefc96SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 458*5daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 459*5daefc96SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 460*5daefc96SJeremy L Thompson } else { 461*5daefc96SJeremy L Thompson bool is_active; 462*5daefc96SJeremy L Thompson CeedVector vec; 463*5daefc96SJeremy L Thompson 464*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 465*5daefc96SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 466*5daefc96SJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 467*5daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 468*5daefc96SJeremy L Thompson } 469*5daefc96SJeremy L Thompson } 470*5daefc96SJeremy L Thompson 471*5daefc96SJeremy L Thompson // Restore context data 472*5daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 473*5daefc96SJeremy L Thompson 474*5daefc96SJeremy L Thompson // Restore assembly array 475*5daefc96SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArray(*assembled, &assembled_array)); 476*5daefc96SJeremy L Thompson 477*5daefc96SJeremy L Thompson // Cleanup 478*5daefc96SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf)); 479*5daefc96SJeremy L Thompson if (!is_run_good) { 480*5daefc96SJeremy L Thompson data->use_assembly_fallback = true; 481*5daefc96SJeremy L Thompson if (build_objects) { 482*5daefc96SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(assembled)); 483*5daefc96SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(rstr)); 484*5daefc96SJeremy L Thompson } 485*5daefc96SJeremy L Thompson } 486*5daefc96SJeremy L Thompson } 487*5daefc96SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 488*5daefc96SJeremy L Thompson 489*5daefc96SJeremy L Thompson // Fallback, if needed 490*5daefc96SJeremy L Thompson if (data->use_assembly_fallback) { 491*5daefc96SJeremy L Thompson CeedOperator op_fallback; 492*5daefc96SJeremy L Thompson 493*5daefc96SJeremy L Thompson CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator"); 494*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 495*5daefc96SJeremy L Thompson CeedCallBackend(CeedOperatorFallbackLinearAssembleQFunctionBuildOrUpdate(op_fallback, assembled, rstr, request)); 496*5daefc96SJeremy L Thompson return CEED_ERROR_SUCCESS; 497*5daefc96SJeremy L Thompson } 498*5daefc96SJeremy L Thompson return CEED_ERROR_SUCCESS; 499*5daefc96SJeremy L Thompson } 500*5daefc96SJeremy L Thompson 501*5daefc96SJeremy L Thompson static int CeedOperatorLinearAssembleQFunction_Hip_gen(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) { 502*5daefc96SJeremy L Thompson return CeedOperatorLinearAssembleQFunctionCore_Hip_gen(op, true, assembled, rstr, request); 503*5daefc96SJeremy L Thompson } 504*5daefc96SJeremy L Thompson 505*5daefc96SJeremy L Thompson static int CeedOperatorLinearAssembleQFunctionUpdate_Hip_gen(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) { 506*5daefc96SJeremy L Thompson return CeedOperatorLinearAssembleQFunctionCore_Hip_gen(op, false, &assembled, &rstr, request); 507*5daefc96SJeremy L Thompson } 508*5daefc96SJeremy L Thompson 509*5daefc96SJeremy L Thompson //------------------------------------------------------------------------------ 5100183ed61SJeremy L Thompson // AtPoints diagonal assembly 5110183ed61SJeremy L Thompson //------------------------------------------------------------------------------ 5120183ed61SJeremy L Thompson static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen(CeedOperator op, CeedVector assembled, CeedRequest *request) { 5130183ed61SJeremy L Thompson Ceed ceed; 5140183ed61SJeremy L Thompson CeedOperator_Hip_gen *data; 5150183ed61SJeremy L Thompson 5160183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 5170183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data)); 5180183ed61SJeremy L Thompson 5190183ed61SJeremy L Thompson // Build the assembly kernel 5200183ed61SJeremy L Thompson if (!data->assemble_diagonal && !data->use_assembly_fallback) { 5210183ed61SJeremy L Thompson bool is_build_good = false; 5220183ed61SJeremy L Thompson CeedInt num_active_bases_in, num_active_bases_out; 5230183ed61SJeremy L Thompson CeedOperatorAssemblyData assembly_data; 5240183ed61SJeremy L Thompson 5250183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data)); 5260183ed61SJeremy L Thompson CeedCallBackend( 5270183ed61SJeremy L Thompson CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL, NULL, NULL)); 5280183ed61SJeremy L Thompson if (num_active_bases_in == num_active_bases_out) { 5290183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good)); 5300183ed61SJeremy L Thompson if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Hip_gen(op, &is_build_good)); 5310183ed61SJeremy L Thompson } 5320183ed61SJeremy L Thompson if (!is_build_good) data->use_assembly_fallback = true; 5330183ed61SJeremy L Thompson } 5340183ed61SJeremy L Thompson 5350183ed61SJeremy L Thompson // Try assembly 5360183ed61SJeremy L Thompson if (!data->use_assembly_fallback) { 5370183ed61SJeremy L Thompson bool is_run_good = true; 5380183ed61SJeremy L Thompson Ceed_Hip *hip_data; 5390183ed61SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields; 5400183ed61SJeremy L Thompson CeedEvalMode eval_mode; 5410183ed61SJeremy L Thompson CeedScalar *assembled_array; 5420183ed61SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields; 5430183ed61SJeremy L Thompson CeedQFunction_Hip_gen *qf_data; 5440183ed61SJeremy L Thompson CeedQFunction qf; 5450183ed61SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields; 5460183ed61SJeremy L Thompson 5470183ed61SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &hip_data)); 5480183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 5490183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 5500183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 5510183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 5520183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 5530183ed61SJeremy L Thompson 5540183ed61SJeremy L Thompson // Input vectors 5550183ed61SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 5560183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 5570183ed61SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 5580183ed61SJeremy L Thompson data->fields.inputs[i] = NULL; 5590183ed61SJeremy L Thompson } else { 5600183ed61SJeremy L Thompson bool is_active; 5610183ed61SJeremy L Thompson CeedVector vec; 5620183ed61SJeremy L Thompson 5630183ed61SJeremy L Thompson // Get input vector 5640183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 5650183ed61SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 5660183ed61SJeremy L Thompson if (is_active) data->fields.inputs[i] = NULL; 5670183ed61SJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 5680183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 5690183ed61SJeremy L Thompson } 5700183ed61SJeremy L Thompson } 5710183ed61SJeremy L Thompson 5720183ed61SJeremy L Thompson // Point coordinates 5730183ed61SJeremy L Thompson { 5740183ed61SJeremy L Thompson CeedVector vec; 5750183ed61SJeremy L Thompson 5760183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 5770183ed61SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords)); 5780183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 5790183ed61SJeremy L Thompson 5800183ed61SJeremy L Thompson // Points per elem 5810183ed61SJeremy L Thompson if (num_elem != data->points.num_elem) { 5820183ed61SJeremy L Thompson CeedInt *points_per_elem; 5830183ed61SJeremy L Thompson const CeedInt num_bytes = num_elem * sizeof(CeedInt); 5840183ed61SJeremy L Thompson CeedElemRestriction rstr_points = NULL; 5850183ed61SJeremy L Thompson 5860183ed61SJeremy L Thompson data->points.num_elem = num_elem; 5870183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 5880183ed61SJeremy L Thompson CeedCallBackend(CeedCalloc(num_elem, &points_per_elem)); 5890183ed61SJeremy L Thompson for (CeedInt e = 0; e < num_elem; e++) { 5900183ed61SJeremy L Thompson CeedInt num_points_elem; 5910183ed61SJeremy L Thompson 5920183ed61SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 5930183ed61SJeremy L Thompson points_per_elem[e] = num_points_elem; 5940183ed61SJeremy L Thompson } 5950183ed61SJeremy L Thompson if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem)); 5960183ed61SJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes)); 5970183ed61SJeremy L Thompson CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice)); 5980183ed61SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 5990183ed61SJeremy L Thompson CeedCallBackend(CeedFree(&points_per_elem)); 6000183ed61SJeremy L Thompson } 6010183ed61SJeremy L Thompson } 6020183ed61SJeremy L Thompson 6030183ed61SJeremy L Thompson // Get context data 6040183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 6050183ed61SJeremy L Thompson 6060183ed61SJeremy L Thompson // Assembly array 6070183ed61SJeremy L Thompson CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array)); 6080183ed61SJeremy L Thompson 6090183ed61SJeremy L Thompson // Assemble diagonal 6100183ed61SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array}; 6110183ed61SJeremy L Thompson 6120183ed61SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1}; 6130183ed61SJeremy L Thompson 6140183ed61SJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 6150183ed61SJeremy L Thompson block_sizes[2] = 1; 6160183ed61SJeremy L Thompson if (data->dim == 1) { 6170183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 6180183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 6190183ed61SJeremy L Thompson 6200183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 6210183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 6220183ed61SJeremy L Thompson } else if (data->dim == 2) { 6230183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 6240183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 6250183ed61SJeremy L Thompson 6260183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 6270183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 6280183ed61SJeremy L Thompson } else if (data->dim == 3) { 6290183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 6300183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 6310183ed61SJeremy L Thompson 6320183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 6330183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 6340183ed61SJeremy L Thompson } 635692716b7SZach Atkins CeedCallHip(ceed, hipDeviceSynchronize()); 6360183ed61SJeremy L Thompson 6370183ed61SJeremy L Thompson // Restore input arrays 6380183ed61SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 6390183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 6400183ed61SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 6410183ed61SJeremy L Thompson } else { 6420183ed61SJeremy L Thompson bool is_active; 6430183ed61SJeremy L Thompson CeedVector vec; 6440183ed61SJeremy L Thompson 6450183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 6460183ed61SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 6470183ed61SJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 6480183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 6490183ed61SJeremy L Thompson } 6500183ed61SJeremy L Thompson } 6510183ed61SJeremy L Thompson 6520183ed61SJeremy L Thompson // Restore point coordinates 6530183ed61SJeremy L Thompson { 6540183ed61SJeremy L Thompson CeedVector vec; 6550183ed61SJeremy L Thompson 6560183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 6570183ed61SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords)); 6580183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 6590183ed61SJeremy L Thompson } 6600183ed61SJeremy L Thompson 6610183ed61SJeremy L Thompson // Restore context data 6620183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 6630183ed61SJeremy L Thompson 6640183ed61SJeremy L Thompson // Restore assembly array 6650183ed61SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array)); 6660183ed61SJeremy L Thompson 6670183ed61SJeremy L Thompson // Cleanup 6680183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf)); 6690183ed61SJeremy L Thompson if (!is_run_good) data->use_assembly_fallback = true; 6700183ed61SJeremy L Thompson } 6710183ed61SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 6720183ed61SJeremy L Thompson 6730183ed61SJeremy L Thompson // Fallback, if needed 6740183ed61SJeremy L Thompson if (data->use_assembly_fallback) { 6750183ed61SJeremy L Thompson CeedOperator op_fallback; 6760183ed61SJeremy L Thompson 6770183ed61SJeremy L Thompson CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator"); 6780183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 6790183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorLinearAssembleAddDiagonal(op_fallback, assembled, request)); 6800183ed61SJeremy L Thompson return CEED_ERROR_SUCCESS; 6810183ed61SJeremy L Thompson } 6820183ed61SJeremy L Thompson return CEED_ERROR_SUCCESS; 6830183ed61SJeremy L Thompson } 6840183ed61SJeremy L Thompson 6850183ed61SJeremy L Thompson //------------------------------------------------------------------------------ 686692716b7SZach Atkins // AtPoints full assembly 687692716b7SZach Atkins //------------------------------------------------------------------------------ 688692716b7SZach Atkins static int CeedSingleOperatorAssembleAtPoints_Hip_gen(CeedOperator op, CeedInt offset, CeedVector assembled) { 689692716b7SZach Atkins Ceed ceed; 690692716b7SZach Atkins CeedOperator_Hip_gen *data; 691692716b7SZach Atkins 692692716b7SZach Atkins CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 693692716b7SZach Atkins CeedCallBackend(CeedOperatorGetData(op, &data)); 694692716b7SZach Atkins 695692716b7SZach Atkins // Build the assembly kernel 696692716b7SZach Atkins if (!data->assemble_full && !data->use_assembly_fallback) { 697692716b7SZach Atkins bool is_build_good = false; 698692716b7SZach Atkins CeedInt num_active_bases_in, num_active_bases_out; 699692716b7SZach Atkins CeedOperatorAssemblyData assembly_data; 700692716b7SZach Atkins 701692716b7SZach Atkins CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data)); 702692716b7SZach Atkins CeedCallBackend( 703692716b7SZach Atkins CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL, NULL, NULL)); 704692716b7SZach Atkins if (num_active_bases_in == num_active_bases_out) { 705692716b7SZach Atkins CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good)); 706692716b7SZach Atkins if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelFullAssemblyAtPoints_Hip_gen(op, &is_build_good)); 707692716b7SZach Atkins } 708692716b7SZach Atkins if (!is_build_good) { 709692716b7SZach Atkins CeedDebug(ceed, "Single Operator Assemble at Points compile failed, using fallback\n"); 710692716b7SZach Atkins data->use_assembly_fallback = true; 711692716b7SZach Atkins } 712692716b7SZach Atkins } 713692716b7SZach Atkins 714692716b7SZach Atkins // Try assembly 715692716b7SZach Atkins if (!data->use_assembly_fallback) { 716692716b7SZach Atkins bool is_run_good = true; 717692716b7SZach Atkins Ceed_Hip *Hip_data; 718692716b7SZach Atkins CeedInt num_elem, num_input_fields, num_output_fields; 719692716b7SZach Atkins CeedEvalMode eval_mode; 720692716b7SZach Atkins CeedScalar *assembled_array; 721692716b7SZach Atkins CeedQFunctionField *qf_input_fields, *qf_output_fields; 722692716b7SZach Atkins CeedQFunction_Hip_gen *qf_data; 723692716b7SZach Atkins CeedQFunction qf; 724692716b7SZach Atkins CeedOperatorField *op_input_fields, *op_output_fields; 725692716b7SZach Atkins 726692716b7SZach Atkins CeedCallBackend(CeedGetData(ceed, &Hip_data)); 727692716b7SZach Atkins CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 728692716b7SZach Atkins CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 729692716b7SZach Atkins CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 730692716b7SZach Atkins CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 731692716b7SZach Atkins CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 732692716b7SZach Atkins CeedDebug(ceed, "Running single operator assemble for /gpu/hip/gen\n"); 733692716b7SZach Atkins 734692716b7SZach Atkins // Input vectors 735692716b7SZach Atkins for (CeedInt i = 0; i < num_input_fields; i++) { 736692716b7SZach Atkins CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 737692716b7SZach Atkins if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 738692716b7SZach Atkins data->fields.inputs[i] = NULL; 739692716b7SZach Atkins } else { 740692716b7SZach Atkins bool is_active; 741692716b7SZach Atkins CeedVector vec; 742692716b7SZach Atkins 743692716b7SZach Atkins // Get input vector 744692716b7SZach Atkins CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 745692716b7SZach Atkins is_active = vec == CEED_VECTOR_ACTIVE; 746692716b7SZach Atkins if (is_active) data->fields.inputs[i] = NULL; 747692716b7SZach Atkins else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 748692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec)); 749692716b7SZach Atkins } 750692716b7SZach Atkins } 751692716b7SZach Atkins 752692716b7SZach Atkins // Point coordinates 753692716b7SZach Atkins { 754692716b7SZach Atkins CeedVector vec; 755692716b7SZach Atkins 756692716b7SZach Atkins CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 757692716b7SZach Atkins CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords)); 758692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec)); 759692716b7SZach Atkins 760692716b7SZach Atkins // Points per elem 761692716b7SZach Atkins if (num_elem != data->points.num_elem) { 762692716b7SZach Atkins CeedInt *points_per_elem; 763692716b7SZach Atkins const CeedInt num_bytes = num_elem * sizeof(CeedInt); 764692716b7SZach Atkins CeedElemRestriction rstr_points = NULL; 765692716b7SZach Atkins 766692716b7SZach Atkins data->points.num_elem = num_elem; 767692716b7SZach Atkins CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 768692716b7SZach Atkins CeedCallBackend(CeedCalloc(num_elem, &points_per_elem)); 769692716b7SZach Atkins for (CeedInt e = 0; e < num_elem; e++) { 770692716b7SZach Atkins CeedInt num_points_elem; 771692716b7SZach Atkins 772692716b7SZach Atkins CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 773692716b7SZach Atkins points_per_elem[e] = num_points_elem; 774692716b7SZach Atkins } 775692716b7SZach Atkins if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem)); 776692716b7SZach Atkins CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes)); 777692716b7SZach Atkins CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice)); 778692716b7SZach Atkins CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 779692716b7SZach Atkins CeedCallBackend(CeedFree(&points_per_elem)); 780692716b7SZach Atkins } 781692716b7SZach Atkins } 782692716b7SZach Atkins 783692716b7SZach Atkins // Get context data 784692716b7SZach Atkins CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 785692716b7SZach Atkins 786692716b7SZach Atkins // Assembly array 787692716b7SZach Atkins CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array)); 788692716b7SZach Atkins CeedScalar *assembled_offset_array = &assembled_array[offset]; 789692716b7SZach Atkins 790692716b7SZach Atkins // Assemble diagonal 791692716b7SZach Atkins void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, 792692716b7SZach Atkins &data->G, &data->W, &data->points, &assembled_offset_array}; 793692716b7SZach Atkins 794692716b7SZach Atkins CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1}; 795692716b7SZach Atkins 796692716b7SZach Atkins CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 797692716b7SZach Atkins block_sizes[2] = 1; 798692716b7SZach Atkins if (data->dim == 1) { 799692716b7SZach Atkins CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 800692716b7SZach Atkins CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 801692716b7SZach Atkins 802692716b7SZach Atkins CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_full, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, 803692716b7SZach Atkins &is_run_good, opargs)); 804692716b7SZach Atkins } else if (data->dim == 2) { 805692716b7SZach Atkins CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 806692716b7SZach Atkins CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 807692716b7SZach Atkins 808692716b7SZach Atkins CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_full, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, 809692716b7SZach Atkins &is_run_good, opargs)); 810692716b7SZach Atkins } else if (data->dim == 3) { 811692716b7SZach Atkins CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 812692716b7SZach Atkins CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 813692716b7SZach Atkins 814692716b7SZach Atkins CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_full, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, 815692716b7SZach Atkins &is_run_good, opargs)); 816692716b7SZach Atkins } 817692716b7SZach Atkins CeedCallHip(ceed, hipDeviceSynchronize()); 818692716b7SZach Atkins 819692716b7SZach Atkins // Restore input arrays 820692716b7SZach Atkins for (CeedInt i = 0; i < num_input_fields; i++) { 821692716b7SZach Atkins CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 822692716b7SZach Atkins if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 823692716b7SZach Atkins } else { 824692716b7SZach Atkins bool is_active; 825692716b7SZach Atkins CeedVector vec; 826692716b7SZach Atkins 827692716b7SZach Atkins CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 828692716b7SZach Atkins is_active = vec == CEED_VECTOR_ACTIVE; 829692716b7SZach Atkins if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 830692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec)); 831692716b7SZach Atkins } 832692716b7SZach Atkins } 833692716b7SZach Atkins 834692716b7SZach Atkins // Restore point coordinates 835692716b7SZach Atkins { 836692716b7SZach Atkins CeedVector vec; 837692716b7SZach Atkins 838692716b7SZach Atkins CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 839692716b7SZach Atkins CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords)); 840692716b7SZach Atkins CeedCallBackend(CeedVectorDestroy(&vec)); 841692716b7SZach Atkins } 842692716b7SZach Atkins 843692716b7SZach Atkins // Restore context data 844692716b7SZach Atkins CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 845692716b7SZach Atkins 846692716b7SZach Atkins // Restore assembly array 847692716b7SZach Atkins CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array)); 848692716b7SZach Atkins 849692716b7SZach Atkins // Cleanup 850692716b7SZach Atkins CeedCallBackend(CeedQFunctionDestroy(&qf)); 851692716b7SZach Atkins if (!is_run_good) { 852692716b7SZach Atkins CeedDebug(ceed, "Single Operator Assemble at Points run failed, using fallback\n"); 853692716b7SZach Atkins data->use_assembly_fallback = true; 854692716b7SZach Atkins } 855692716b7SZach Atkins } 856692716b7SZach Atkins CeedCallBackend(CeedDestroy(&ceed)); 857692716b7SZach Atkins 858692716b7SZach Atkins // Fallback, if needed 859692716b7SZach Atkins if (data->use_assembly_fallback) { 860692716b7SZach Atkins CeedOperator op_fallback; 861692716b7SZach Atkins 862692716b7SZach Atkins CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator"); 863692716b7SZach Atkins CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 864692716b7SZach Atkins CeedCallBackend(CeedSingleOperatorAssemble(op_fallback, offset, assembled)); 865692716b7SZach Atkins return CEED_ERROR_SUCCESS; 866692716b7SZach Atkins } 867692716b7SZach Atkins return CEED_ERROR_SUCCESS; 868692716b7SZach Atkins } 869692716b7SZach Atkins 870692716b7SZach Atkins //------------------------------------------------------------------------------ 8717d8d0e25Snbeams // Create operator 8727d8d0e25Snbeams //------------------------------------------------------------------------------ 8737d8d0e25Snbeams int CeedOperatorCreate_Hip_gen(CeedOperator op) { 8740183ed61SJeremy L Thompson bool is_composite, is_at_points; 8757d8d0e25Snbeams Ceed ceed; 8767d8d0e25Snbeams CeedOperator_Hip_gen *impl; 8777d8d0e25Snbeams 878b7453713SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 8792b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl)); 8802b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorSetData(op, impl)); 881c99afcd8SJeremy L Thompson CeedCall(CeedOperatorIsComposite(op, &is_composite)); 882c99afcd8SJeremy L Thompson if (is_composite) { 883c99afcd8SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAddComposite", CeedOperatorApplyAddComposite_Hip_gen)); 884c99afcd8SJeremy L Thompson } else { 8852b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen)); 886c99afcd8SJeremy L Thompson } 8870183ed61SJeremy L Thompson CeedCall(CeedOperatorIsAtPoints(op, &is_at_points)); 8880183ed61SJeremy L Thompson if (is_at_points) { 8890183ed61SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen)); 890692716b7SZach Atkins CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssembleAtPoints_Hip_gen)); 8910183ed61SJeremy L Thompson } 892*5daefc96SJeremy L Thompson if (!is_at_points) { 893*5daefc96SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunction", CeedOperatorLinearAssembleQFunction_Hip_gen)); 894*5daefc96SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionUpdate_Hip_gen)); 895*5daefc96SJeremy L Thompson } 8962b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen)); 8979bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 898e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 8997d8d0e25Snbeams } 9002a86cc9dSSebastian Grimberg 9017d8d0e25Snbeams //------------------------------------------------------------------------------ 902