1d275d636SJeremy L Thompson // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors. 23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 37d8d0e25Snbeams // 43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause 57d8d0e25Snbeams // 63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed 77d8d0e25Snbeams 849aac155SJeremy L Thompson #include <ceed.h> 9ec3da8bcSJed Brown #include <ceed/backend.h> 1049aac155SJeremy L Thompson #include <ceed/jit-source/hip/hip-types.h> 113d576824SJeremy L Thompson #include <stddef.h> 123a2968d6SJeremy L Thompson #include <hip/hiprtc.h> 132b730f8bSJeremy L Thompson 14b2165e7aSSebastian Grimberg #include "../hip/ceed-hip-common.h" 157d8d0e25Snbeams #include "../hip/ceed-hip-compile.h" 162b730f8bSJeremy L Thompson #include "ceed-hip-gen-operator-build.h" 172b730f8bSJeremy L Thompson #include "ceed-hip-gen.h" 187d8d0e25Snbeams 197d8d0e25Snbeams //------------------------------------------------------------------------------ 207d8d0e25Snbeams // Destroy operator 217d8d0e25Snbeams //------------------------------------------------------------------------------ 227d8d0e25Snbeams static int CeedOperatorDestroy_Hip_gen(CeedOperator op) { 233a2968d6SJeremy L Thompson Ceed ceed; 247d8d0e25Snbeams CeedOperator_Hip_gen *impl; 256eee1ffcSZach Atkins bool is_composite; 26b7453713SJeremy L Thompson 273a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 282b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &impl)); 296eee1ffcSZach Atkins CeedCallBackend(CeedOperatorIsComposite(op, &is_composite)); 306eee1ffcSZach Atkins if (is_composite) { 316eee1ffcSZach Atkins CeedInt num_suboperators; 326eee1ffcSZach Atkins 336eee1ffcSZach Atkins CeedCall(CeedCompositeOperatorGetNumSub(op, &num_suboperators)); 346eee1ffcSZach Atkins for (CeedInt i = 0; i < num_suboperators; i++) { 356eee1ffcSZach Atkins if (impl->streams[i]) CeedCallHip(ceed, hipStreamDestroy(impl->streams[i])); 366eee1ffcSZach Atkins impl->streams[i] = NULL; 376eee1ffcSZach Atkins } 386eee1ffcSZach Atkins } 398b7d3340SJeremy L Thompson if (impl->module) CeedCallHip(ceed, hipModuleUnload(impl->module)); 40*0183ed61SJeremy L Thompson if (impl->module_assemble_full) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_full)); 41*0183ed61SJeremy L Thompson if (impl->module_assemble_diagonal) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_diagonal)); 423a2968d6SJeremy L Thompson if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)impl->points.num_per_elem)); 432b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&impl)); 443a2968d6SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 45e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 467d8d0e25Snbeams } 477d8d0e25Snbeams 487d8d0e25Snbeams //------------------------------------------------------------------------------ 497d8d0e25Snbeams // Apply and add to output 507d8d0e25Snbeams //------------------------------------------------------------------------------ 51e9c76bddSJeremy L Thompson static int CeedOperatorApplyAddCore_Hip_gen(CeedOperator op, hipStream_t stream, const CeedScalar *input_arr, CeedScalar *output_arr, 52e9c76bddSJeremy L Thompson bool *is_run_good, CeedRequest *request) { 53ea04d07fSJeremy L Thompson bool is_at_points, is_tensor; 547d8d0e25Snbeams Ceed ceed; 55b7453713SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields; 56b7453713SJeremy L Thompson CeedEvalMode eval_mode; 57b7453713SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields; 587d8d0e25Snbeams CeedQFunction_Hip_gen *qf_data; 59b7453713SJeremy L Thompson CeedQFunction qf; 60b7453713SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields; 61b7453713SJeremy L Thompson CeedOperator_Hip_gen *data; 62b7453713SJeremy L Thompson 638d12f40eSJeremy L Thompson // Creation of the operator 64ea04d07fSJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, is_run_good)); 65ea04d07fSJeremy L Thompson if (!(*is_run_good)) return CEED_ERROR_SUCCESS; 66f6eafd79SJeremy L Thompson 67c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 68c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data)); 69c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 70c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 71c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 728d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 73c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 74c11e12f4SJeremy L Thompson 757d8d0e25Snbeams // Input vectors 769e201c85SYohann for (CeedInt i = 0; i < num_input_fields; i++) { 772b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 789e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 799e201c85SYohann data->fields.inputs[i] = NULL; 807d8d0e25Snbeams } else { 813efc994bSJeremy L Thompson bool is_active; 82b7453713SJeremy L Thompson CeedVector vec; 83b7453713SJeremy L Thompson 847d8d0e25Snbeams // Get input vector 852b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 863efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 87ea04d07fSJeremy L Thompson if (is_active) data->fields.inputs[i] = input_arr; 88ea04d07fSJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 89ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 907d8d0e25Snbeams } 917d8d0e25Snbeams } 927d8d0e25Snbeams 937d8d0e25Snbeams // Output vectors 949e201c85SYohann for (CeedInt i = 0; i < num_output_fields; i++) { 952b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 969e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 979e201c85SYohann data->fields.outputs[i] = NULL; 987d8d0e25Snbeams } else { 993efc994bSJeremy L Thompson bool is_active; 100b7453713SJeremy L Thompson CeedVector vec; 101b7453713SJeremy L Thompson 1027d8d0e25Snbeams // Get output vector 1032b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 1043efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 105ea04d07fSJeremy L Thompson if (is_active) data->fields.outputs[i] = output_arr; 1060c8fbeedSJeremy L Thompson else CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i])); 107ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1087d8d0e25Snbeams } 1097d8d0e25Snbeams } 1107d8d0e25Snbeams 1113a2968d6SJeremy L Thompson // Point coordinates, if needed 1123a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points)); 1133a2968d6SJeremy L Thompson if (is_at_points) { 1143a2968d6SJeremy L Thompson // Coords 1153a2968d6SJeremy L Thompson CeedVector vec; 1163a2968d6SJeremy L Thompson 1173a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 1183a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords)); 1193a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1203a2968d6SJeremy L Thompson 1213a2968d6SJeremy L Thompson // Points per elem 1223a2968d6SJeremy L Thompson if (num_elem != data->points.num_elem) { 1233a2968d6SJeremy L Thompson CeedInt *points_per_elem; 1243a2968d6SJeremy L Thompson const CeedInt num_bytes = num_elem * sizeof(CeedInt); 1253a2968d6SJeremy L Thompson CeedElemRestriction rstr_points = NULL; 1263a2968d6SJeremy L Thompson 1273a2968d6SJeremy L Thompson data->points.num_elem = num_elem; 1283a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 1293a2968d6SJeremy L Thompson CeedCallBackend(CeedCalloc(num_elem, &points_per_elem)); 1303a2968d6SJeremy L Thompson for (CeedInt e = 0; e < num_elem; e++) { 1313a2968d6SJeremy L Thompson CeedInt num_points_elem; 1323a2968d6SJeremy L Thompson 1333a2968d6SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 1343a2968d6SJeremy L Thompson points_per_elem[e] = num_points_elem; 1353a2968d6SJeremy L Thompson } 1363a2968d6SJeremy L Thompson if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem)); 1373a2968d6SJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes)); 1383a2968d6SJeremy L Thompson CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice)); 1393a2968d6SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 1403a2968d6SJeremy L Thompson CeedCallBackend(CeedFree(&points_per_elem)); 1413a2968d6SJeremy L Thompson } 1423a2968d6SJeremy L Thompson } 1433a2968d6SJeremy L Thompson 1447d8d0e25Snbeams // Get context data 1452b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 1467d8d0e25Snbeams 1477d8d0e25Snbeams // Apply operator 1483a2968d6SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points}; 149b7453713SJeremy L Thompson 1509123fb08SJeremy L Thompson CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor)); 151a61b1c91SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, ((!is_tensor || data->dim == 1) ? 1 : data->thread_1d), -1}; 152f82027a4SJeremy L Thompson 153f82027a4SJeremy L Thompson if (is_tensor) { 15474398b5aSJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 15590c30374SJeremy L Thompson if (is_at_points) block_sizes[2] = 1; 156f82027a4SJeremy L Thompson } else { 157a61b1c91SJeremy L Thompson CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64; 158f82027a4SJeremy L Thompson 159f82027a4SJeremy L Thompson elems_per_block = elems_per_block > 0 ? elems_per_block : 1; 160f82027a4SJeremy L Thompson block_sizes[2] = elems_per_block; 161f82027a4SJeremy L Thompson } 16274398b5aSJeremy L Thompson if (data->dim == 1 || !is_tensor) { 1632b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 164a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 165b7453713SJeremy L Thompson 1668d12f40eSJeremy L Thompson CeedCallBackend( 167e9c76bddSJeremy L Thompson CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs)); 16874398b5aSJeremy L Thompson } else if (data->dim == 2) { 1692b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 170a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 171b7453713SJeremy L Thompson 1728d12f40eSJeremy L Thompson CeedCallBackend( 173e9c76bddSJeremy L Thompson CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs)); 17474398b5aSJeremy L Thompson } else if (data->dim == 3) { 1752b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 176a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 177b7453713SJeremy L Thompson 1788d12f40eSJeremy L Thompson CeedCallBackend( 179e9c76bddSJeremy L Thompson CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs)); 1807d8d0e25Snbeams } 1817d8d0e25Snbeams 1827d8d0e25Snbeams // Restore input arrays 1839e201c85SYohann for (CeedInt i = 0; i < num_input_fields; i++) { 1842b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 1859e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 1867d8d0e25Snbeams } else { 1873efc994bSJeremy L Thompson bool is_active; 188b7453713SJeremy L Thompson CeedVector vec; 189b7453713SJeremy L Thompson 1902b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1913efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 192ea04d07fSJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 193ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1947d8d0e25Snbeams } 1957d8d0e25Snbeams } 1967d8d0e25Snbeams 1977d8d0e25Snbeams // Restore output arrays 1989e201c85SYohann for (CeedInt i = 0; i < num_output_fields; i++) { 1992b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 2009e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 2017d8d0e25Snbeams } else { 2023efc994bSJeremy L Thompson bool is_active; 203b7453713SJeremy L Thompson CeedVector vec; 204b7453713SJeremy L Thompson 2052b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 2063efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 207ea04d07fSJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArray(vec, &data->fields.outputs[i])); 208ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 2097d8d0e25Snbeams } 2107d8d0e25Snbeams } 2117d8d0e25Snbeams 2123a2968d6SJeremy L Thompson // Restore point coordinates, if needed 2133a2968d6SJeremy L Thompson if (is_at_points) { 2143a2968d6SJeremy L Thompson CeedVector vec; 2153a2968d6SJeremy L Thompson 2163a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 2173a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords)); 2183a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 2193a2968d6SJeremy L Thompson } 2203a2968d6SJeremy L Thompson 2217d8d0e25Snbeams // Restore context data 2222b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 2238d12f40eSJeremy L Thompson 2248d12f40eSJeremy L Thompson // Cleanup 2259bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 226c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf)); 227ea04d07fSJeremy L Thompson if (!(*is_run_good)) data->use_fallback = true; 228ea04d07fSJeremy L Thompson return CEED_ERROR_SUCCESS; 229ea04d07fSJeremy L Thompson } 2308d12f40eSJeremy L Thompson 231ea04d07fSJeremy L Thompson static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) { 232ea04d07fSJeremy L Thompson bool is_run_good = false; 233ea04d07fSJeremy L Thompson const CeedScalar *input_arr = NULL; 234ea04d07fSJeremy L Thompson CeedScalar *output_arr = NULL; 235ea04d07fSJeremy L Thompson 236ea04d07fSJeremy L Thompson // Try to run kernel 237ea04d07fSJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr)); 238ea04d07fSJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr)); 239087855afSJeremy L Thompson CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(op, NULL, input_arr, output_arr, &is_run_good, request)); 240ea04d07fSJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr)); 241ea04d07fSJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr)); 242ea04d07fSJeremy L Thompson 243ea04d07fSJeremy L Thompson // Fallback on unsuccessful run 244ea04d07fSJeremy L Thompson if (!is_run_good) { 2458d12f40eSJeremy L Thompson CeedOperator op_fallback; 2468d12f40eSJeremy L Thompson 247ea04d07fSJeremy L Thompson CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator"); 2488d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 2498d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request)); 2508d12f40eSJeremy L Thompson } 251e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 2527d8d0e25Snbeams } 2537d8d0e25Snbeams 254c99afcd8SJeremy L Thompson static int CeedOperatorApplyAddComposite_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) { 2556eee1ffcSZach Atkins bool is_run_good[CEED_COMPOSITE_MAX] = {true}; 256c99afcd8SJeremy L Thompson CeedInt num_suboperators; 257c99afcd8SJeremy L Thompson const CeedScalar *input_arr = NULL; 2586eee1ffcSZach Atkins CeedScalar *output_arr; 259087855afSJeremy L Thompson Ceed ceed; 2606eee1ffcSZach Atkins CeedOperator_Hip_gen *impl; 261c99afcd8SJeremy L Thompson CeedOperator *sub_operators; 262c99afcd8SJeremy L Thompson 263087855afSJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 2646eee1ffcSZach Atkins CeedCallBackend(CeedOperatorGetData(op, &impl)); 2656eee1ffcSZach Atkins CeedCallBackend(CeedCompositeOperatorGetNumSub(op, &num_suboperators)); 2666eee1ffcSZach Atkins CeedCallBackend(CeedCompositeOperatorGetSubList(op, &sub_operators)); 267c99afcd8SJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr)); 268c99afcd8SJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr)); 269c99afcd8SJeremy L Thompson for (CeedInt i = 0; i < num_suboperators; i++) { 270c99afcd8SJeremy L Thompson CeedInt num_elem = 0; 271c99afcd8SJeremy L Thompson 2726eee1ffcSZach Atkins CeedCallBackend(CeedOperatorGetNumElements(sub_operators[i], &num_elem)); 273c99afcd8SJeremy L Thompson if (num_elem > 0) { 2746eee1ffcSZach Atkins if (!impl->streams[i]) CeedCallHip(ceed, hipStreamCreate(&impl->streams[i])); 2756eee1ffcSZach Atkins CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], impl->streams[i], input_arr, output_arr, &is_run_good[i], request)); 2766eee1ffcSZach Atkins } else { 2776eee1ffcSZach Atkins is_run_good[i] = true; 2786eee1ffcSZach Atkins } 2796eee1ffcSZach Atkins } 280087855afSJeremy L Thompson 2816eee1ffcSZach Atkins for (CeedInt i = 0; i < num_suboperators; i++) { 2826eee1ffcSZach Atkins if (impl->streams[i]) { 2836eee1ffcSZach Atkins if (is_run_good[i]) CeedCallHip(ceed, hipStreamSynchronize(impl->streams[i])); 284c99afcd8SJeremy L Thompson } 285c99afcd8SJeremy L Thompson } 286c99afcd8SJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr)); 287c99afcd8SJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr)); 288087855afSJeremy L Thompson CeedCallHip(ceed, hipDeviceSynchronize()); 289c99afcd8SJeremy L Thompson 290c99afcd8SJeremy L Thompson // Fallback on unsuccessful run 291c99afcd8SJeremy L Thompson for (CeedInt i = 0; i < num_suboperators; i++) { 292c99afcd8SJeremy L Thompson if (!is_run_good[i]) { 293c99afcd8SJeremy L Thompson CeedOperator op_fallback; 294c99afcd8SJeremy L Thompson 295087855afSJeremy L Thompson CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator"); 296c99afcd8SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(sub_operators[i], &op_fallback)); 297c99afcd8SJeremy L Thompson CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request)); 298c99afcd8SJeremy L Thompson } 299c99afcd8SJeremy L Thompson } 300087855afSJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 301c99afcd8SJeremy L Thompson return CEED_ERROR_SUCCESS; 302c99afcd8SJeremy L Thompson } 303c99afcd8SJeremy L Thompson 3047d8d0e25Snbeams //------------------------------------------------------------------------------ 305*0183ed61SJeremy L Thompson // AtPoints diagonal assembly 306*0183ed61SJeremy L Thompson //------------------------------------------------------------------------------ 307*0183ed61SJeremy L Thompson static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen(CeedOperator op, CeedVector assembled, CeedRequest *request) { 308*0183ed61SJeremy L Thompson Ceed ceed; 309*0183ed61SJeremy L Thompson CeedOperator_Hip_gen *data; 310*0183ed61SJeremy L Thompson 311*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 312*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data)); 313*0183ed61SJeremy L Thompson 314*0183ed61SJeremy L Thompson // Build the assembly kernel 315*0183ed61SJeremy L Thompson if (!data->assemble_diagonal && !data->use_assembly_fallback) { 316*0183ed61SJeremy L Thompson bool is_build_good = false; 317*0183ed61SJeremy L Thompson CeedInt num_active_bases_in, num_active_bases_out; 318*0183ed61SJeremy L Thompson CeedOperatorAssemblyData assembly_data; 319*0183ed61SJeremy L Thompson 320*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data)); 321*0183ed61SJeremy L Thompson CeedCallBackend( 322*0183ed61SJeremy L Thompson CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL, NULL, NULL)); 323*0183ed61SJeremy L Thompson if (num_active_bases_in == num_active_bases_out) { 324*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good)); 325*0183ed61SJeremy L Thompson if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Hip_gen(op, &is_build_good)); 326*0183ed61SJeremy L Thompson } 327*0183ed61SJeremy L Thompson if (!is_build_good) data->use_assembly_fallback = true; 328*0183ed61SJeremy L Thompson } 329*0183ed61SJeremy L Thompson 330*0183ed61SJeremy L Thompson // Try assembly 331*0183ed61SJeremy L Thompson if (!data->use_assembly_fallback) { 332*0183ed61SJeremy L Thompson bool is_run_good = true; 333*0183ed61SJeremy L Thompson Ceed_Hip *hip_data; 334*0183ed61SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields; 335*0183ed61SJeremy L Thompson CeedEvalMode eval_mode; 336*0183ed61SJeremy L Thompson CeedScalar *assembled_array; 337*0183ed61SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields; 338*0183ed61SJeremy L Thompson CeedQFunction_Hip_gen *qf_data; 339*0183ed61SJeremy L Thompson CeedQFunction qf; 340*0183ed61SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields; 341*0183ed61SJeremy L Thompson 342*0183ed61SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &hip_data)); 343*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 344*0183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 345*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 346*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 347*0183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 348*0183ed61SJeremy L Thompson 349*0183ed61SJeremy L Thompson // Input vectors 350*0183ed61SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 351*0183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 352*0183ed61SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 353*0183ed61SJeremy L Thompson data->fields.inputs[i] = NULL; 354*0183ed61SJeremy L Thompson } else { 355*0183ed61SJeremy L Thompson bool is_active; 356*0183ed61SJeremy L Thompson CeedVector vec; 357*0183ed61SJeremy L Thompson 358*0183ed61SJeremy L Thompson // Get input vector 359*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 360*0183ed61SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 361*0183ed61SJeremy L Thompson if (is_active) data->fields.inputs[i] = NULL; 362*0183ed61SJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 363*0183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 364*0183ed61SJeremy L Thompson } 365*0183ed61SJeremy L Thompson } 366*0183ed61SJeremy L Thompson 367*0183ed61SJeremy L Thompson // Point coordinates 368*0183ed61SJeremy L Thompson { 369*0183ed61SJeremy L Thompson CeedVector vec; 370*0183ed61SJeremy L Thompson 371*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 372*0183ed61SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords)); 373*0183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 374*0183ed61SJeremy L Thompson 375*0183ed61SJeremy L Thompson // Points per elem 376*0183ed61SJeremy L Thompson if (num_elem != data->points.num_elem) { 377*0183ed61SJeremy L Thompson CeedInt *points_per_elem; 378*0183ed61SJeremy L Thompson const CeedInt num_bytes = num_elem * sizeof(CeedInt); 379*0183ed61SJeremy L Thompson CeedElemRestriction rstr_points = NULL; 380*0183ed61SJeremy L Thompson 381*0183ed61SJeremy L Thompson data->points.num_elem = num_elem; 382*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 383*0183ed61SJeremy L Thompson CeedCallBackend(CeedCalloc(num_elem, &points_per_elem)); 384*0183ed61SJeremy L Thompson for (CeedInt e = 0; e < num_elem; e++) { 385*0183ed61SJeremy L Thompson CeedInt num_points_elem; 386*0183ed61SJeremy L Thompson 387*0183ed61SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 388*0183ed61SJeremy L Thompson points_per_elem[e] = num_points_elem; 389*0183ed61SJeremy L Thompson } 390*0183ed61SJeremy L Thompson if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem)); 391*0183ed61SJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes)); 392*0183ed61SJeremy L Thompson CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice)); 393*0183ed61SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 394*0183ed61SJeremy L Thompson CeedCallBackend(CeedFree(&points_per_elem)); 395*0183ed61SJeremy L Thompson } 396*0183ed61SJeremy L Thompson } 397*0183ed61SJeremy L Thompson 398*0183ed61SJeremy L Thompson // Get context data 399*0183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 400*0183ed61SJeremy L Thompson 401*0183ed61SJeremy L Thompson // Assembly array 402*0183ed61SJeremy L Thompson CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array)); 403*0183ed61SJeremy L Thompson 404*0183ed61SJeremy L Thompson // Assemble diagonal 405*0183ed61SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array}; 406*0183ed61SJeremy L Thompson 407*0183ed61SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1}; 408*0183ed61SJeremy L Thompson 409*0183ed61SJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 410*0183ed61SJeremy L Thompson block_sizes[2] = 1; 411*0183ed61SJeremy L Thompson if (data->dim == 1) { 412*0183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 413*0183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 414*0183ed61SJeremy L Thompson 415*0183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 416*0183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 417*0183ed61SJeremy L Thompson } else if (data->dim == 2) { 418*0183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 419*0183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 420*0183ed61SJeremy L Thompson 421*0183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 422*0183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 423*0183ed61SJeremy L Thompson } else if (data->dim == 3) { 424*0183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 425*0183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 426*0183ed61SJeremy L Thompson 427*0183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 428*0183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 429*0183ed61SJeremy L Thompson } 430*0183ed61SJeremy L Thompson 431*0183ed61SJeremy L Thompson // Restore input arrays 432*0183ed61SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 433*0183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 434*0183ed61SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 435*0183ed61SJeremy L Thompson } else { 436*0183ed61SJeremy L Thompson bool is_active; 437*0183ed61SJeremy L Thompson CeedVector vec; 438*0183ed61SJeremy L Thompson 439*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 440*0183ed61SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 441*0183ed61SJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 442*0183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 443*0183ed61SJeremy L Thompson } 444*0183ed61SJeremy L Thompson } 445*0183ed61SJeremy L Thompson 446*0183ed61SJeremy L Thompson // Restore point coordinates 447*0183ed61SJeremy L Thompson { 448*0183ed61SJeremy L Thompson CeedVector vec; 449*0183ed61SJeremy L Thompson 450*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 451*0183ed61SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords)); 452*0183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 453*0183ed61SJeremy L Thompson } 454*0183ed61SJeremy L Thompson 455*0183ed61SJeremy L Thompson // Restore context data 456*0183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 457*0183ed61SJeremy L Thompson 458*0183ed61SJeremy L Thompson // Restore assembly array 459*0183ed61SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array)); 460*0183ed61SJeremy L Thompson 461*0183ed61SJeremy L Thompson // Cleanup 462*0183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf)); 463*0183ed61SJeremy L Thompson if (!is_run_good) data->use_assembly_fallback = true; 464*0183ed61SJeremy L Thompson } 465*0183ed61SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 466*0183ed61SJeremy L Thompson 467*0183ed61SJeremy L Thompson // Fallback, if needed 468*0183ed61SJeremy L Thompson if (data->use_assembly_fallback) { 469*0183ed61SJeremy L Thompson CeedOperator op_fallback; 470*0183ed61SJeremy L Thompson 471*0183ed61SJeremy L Thompson CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator"); 472*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 473*0183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorLinearAssembleAddDiagonal(op_fallback, assembled, request)); 474*0183ed61SJeremy L Thompson return CEED_ERROR_SUCCESS; 475*0183ed61SJeremy L Thompson } 476*0183ed61SJeremy L Thompson return CEED_ERROR_SUCCESS; 477*0183ed61SJeremy L Thompson } 478*0183ed61SJeremy L Thompson 479*0183ed61SJeremy L Thompson //------------------------------------------------------------------------------ 4807d8d0e25Snbeams // Create operator 4817d8d0e25Snbeams //------------------------------------------------------------------------------ 4827d8d0e25Snbeams int CeedOperatorCreate_Hip_gen(CeedOperator op) { 483*0183ed61SJeremy L Thompson bool is_composite, is_at_points; 4847d8d0e25Snbeams Ceed ceed; 4857d8d0e25Snbeams CeedOperator_Hip_gen *impl; 4867d8d0e25Snbeams 487b7453713SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 4882b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl)); 4892b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorSetData(op, impl)); 490c99afcd8SJeremy L Thompson CeedCall(CeedOperatorIsComposite(op, &is_composite)); 491c99afcd8SJeremy L Thompson if (is_composite) { 492c99afcd8SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAddComposite", CeedOperatorApplyAddComposite_Hip_gen)); 493c99afcd8SJeremy L Thompson } else { 4942b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen)); 495c99afcd8SJeremy L Thompson } 496*0183ed61SJeremy L Thompson CeedCall(CeedOperatorIsAtPoints(op, &is_at_points)); 497*0183ed61SJeremy L Thompson if (is_at_points) { 498*0183ed61SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen)); 499*0183ed61SJeremy L Thompson } 5002b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen)); 5019bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 502e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 5037d8d0e25Snbeams } 5042a86cc9dSSebastian Grimberg 5057d8d0e25Snbeams //------------------------------------------------------------------------------ 506