1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-source/hip/hip-types.h> 11 #include <hip/hip_runtime.h> 12 13 #include "../hip/ceed-hip-common.h" 14 #include "../hip/ceed-hip-compile.h" 15 #include "ceed-hip-ref-qfunction-load.h" 16 #include "ceed-hip-ref.h" 17 18 //------------------------------------------------------------------------------ 19 // Apply QFunction 20 //------------------------------------------------------------------------------ 21 static int CeedQFunctionApply_Hip(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) { 22 Ceed ceed; 23 Ceed_Hip *ceed_Hip; 24 CeedInt num_input_fields, num_output_fields; 25 CeedQFunction_Hip *data; 26 27 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 28 29 // Build and compile kernel, if not done 30 CeedCallBackend(CeedQFunctionBuildKernel_Hip_ref(qf)); 31 32 CeedCallBackend(CeedQFunctionGetData(qf, &data)); 33 CeedCallBackend(CeedGetData(ceed, &ceed_Hip)); 34 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 35 const int block_size = ceed_Hip->opt_block_size; 36 37 // Read vectors 38 for (CeedInt i = 0; i < num_input_fields; i++) { 39 CeedCallBackend(CeedVectorGetArrayRead(U[i], CEED_MEM_DEVICE, &data->fields.inputs[i])); 40 } 41 for (CeedInt i = 0; i < num_output_fields; i++) { 42 CeedCallBackend(CeedVectorGetArrayWrite(V[i], CEED_MEM_DEVICE, &data->fields.outputs[i])); 43 } 44 45 // Get context data 46 CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &data->d_c)); 47 48 // Run kernel 49 void *args[] = {&data->d_c, (void *)&Q, &data->fields}; 50 51 CeedCallBackend(CeedRunKernel_Hip(ceed, data->QFunction, CeedDivUpInt(Q, block_size), block_size, args)); 52 53 // Restore vectors 54 for (CeedInt i = 0; i < num_input_fields; i++) { 55 CeedCallBackend(CeedVectorRestoreArrayRead(U[i], &data->fields.inputs[i])); 56 } 57 for (CeedInt i = 0; i < num_output_fields; i++) { 58 CeedCallBackend(CeedVectorRestoreArray(V[i], &data->fields.outputs[i])); 59 } 60 61 // Restore context 62 CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &data->d_c)); 63 return CEED_ERROR_SUCCESS; 64 } 65 66 //------------------------------------------------------------------------------ 67 // Destroy QFunction 68 //------------------------------------------------------------------------------ 69 static int CeedQFunctionDestroy_Hip(CeedQFunction qf) { 70 Ceed ceed; 71 CeedQFunction_Hip *data; 72 73 CeedCallBackend(CeedQFunctionGetData(qf, &data)); 74 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 75 if (data->module) CeedCallHip(ceed, hipModuleUnload(data->module)); 76 CeedCallBackend(CeedFree(&data)); 77 return CEED_ERROR_SUCCESS; 78 } 79 80 //------------------------------------------------------------------------------ 81 // Create QFunction 82 //------------------------------------------------------------------------------ 83 int CeedQFunctionCreate_Hip(CeedQFunction qf) { 84 Ceed ceed; 85 CeedInt num_input_fields, num_output_fields; 86 CeedQFunction_Hip *data; 87 88 CeedQFunctionGetCeed(qf, &ceed); 89 CeedCallBackend(CeedCalloc(1, &data)); 90 CeedCallBackend(CeedQFunctionSetData(qf, data)); 91 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 92 93 // Read QFunction source 94 CeedCallBackend(CeedQFunctionGetKernelName(qf, &data->qfunction_name)); 95 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source -----\n"); 96 CeedCallBackend(CeedQFunctionLoadSourceToBuffer(qf, &data->qfunction_source)); 97 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source Complete! -----\n"); 98 99 // Register backend functions 100 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Hip)); 101 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Hip)); 102 return CEED_ERROR_SUCCESS; 103 } 104 105 //------------------------------------------------------------------------------ 106