1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include <hip/hip_runtime.h> 11 #include <stdio.h> 12 #include <string.h> 13 #include "ceed-hip-ref.h" 14 #include "ceed-hip-ref-qfunction-load.h" 15 #include "../hip/ceed-hip-compile.h" 16 17 //------------------------------------------------------------------------------ 18 // Apply QFunction 19 //------------------------------------------------------------------------------ 20 static int CeedQFunctionApply_Hip(CeedQFunction qf, CeedInt Q, 21 CeedVector *U, CeedVector *V) { 22 int ierr; 23 Ceed ceed; 24 ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr); 25 26 // Build and compile kernel, if not done 27 ierr = CeedHipBuildQFunction(qf); CeedChkBackend(ierr); 28 29 CeedQFunction_Hip *data; 30 ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr); 31 Ceed_Hip *ceed_Hip; 32 ierr = CeedGetData(ceed, &ceed_Hip); CeedChkBackend(ierr); 33 CeedInt num_input_fields, num_output_fields; 34 ierr = CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields); 35 CeedChkBackend(ierr); 36 const int blocksize = ceed_Hip->opt_block_size; 37 38 // Read vectors 39 for (CeedInt i = 0; i < num_input_fields; i++) { 40 ierr = CeedVectorGetArrayRead(U[i], CEED_MEM_DEVICE, &data->fields.inputs[i]); 41 CeedChkBackend(ierr); 42 } 43 for (CeedInt i = 0; i < num_output_fields; i++) { 44 ierr = CeedVectorGetArrayWrite(V[i], CEED_MEM_DEVICE, &data->fields.outputs[i]); 45 CeedChkBackend(ierr); 46 } 47 48 // Get context data 49 ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &data->d_c); 50 CeedChkBackend(ierr); 51 52 // Run kernel 53 void *args[] = {&data->d_c, (void *) &Q, &data->fields}; 54 ierr = CeedRunKernelHip(ceed, data->QFunction, CeedDivUpInt(Q, blocksize), 55 blocksize, args); CeedChkBackend(ierr); 56 57 // Restore vectors 58 for (CeedInt i = 0; i < num_input_fields; i++) { 59 ierr = CeedVectorRestoreArrayRead(U[i], &data->fields.inputs[i]); 60 CeedChkBackend(ierr); 61 } 62 for (CeedInt i = 0; i < num_output_fields; i++) { 63 ierr = CeedVectorRestoreArray(V[i], &data->fields.outputs[i]); 64 CeedChkBackend(ierr); 65 } 66 67 // Restore context 68 ierr = CeedQFunctionRestoreInnerContextData(qf, &data->d_c); 69 CeedChkBackend(ierr); 70 71 return CEED_ERROR_SUCCESS; 72 } 73 74 //------------------------------------------------------------------------------ 75 // Destroy QFunction 76 //------------------------------------------------------------------------------ 77 static int CeedQFunctionDestroy_Hip(CeedQFunction qf) { 78 int ierr; 79 CeedQFunction_Hip *data; 80 ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr); 81 Ceed ceed; 82 ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr); 83 if (data->module) 84 CeedChk_Hip(ceed, hipModuleUnload(data->module)); 85 ierr = CeedFree(&data); CeedChkBackend(ierr); 86 87 return CEED_ERROR_SUCCESS; 88 } 89 90 //------------------------------------------------------------------------------ 91 // Create QFunction 92 //------------------------------------------------------------------------------ 93 int CeedQFunctionCreate_Hip(CeedQFunction qf) { 94 int ierr; 95 Ceed ceed; 96 CeedQFunctionGetCeed(qf, &ceed); 97 CeedQFunction_Hip *data; 98 ierr = CeedCalloc(1,&data); CeedChkBackend(ierr); 99 ierr = CeedQFunctionSetData(qf, data); CeedChkBackend(ierr); 100 CeedInt num_input_fields, num_output_fields; 101 ierr = CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields); 102 CeedChkBackend(ierr); 103 104 // Read QFunction source 105 ierr = CeedQFunctionGetKernelName(qf, &data->qfunction_name); 106 CeedChkBackend(ierr); 107 CeedDebug256(ceed, 2, "----- Loading QFunction User Source -----\n"); 108 ierr = CeedQFunctionLoadSourceToBuffer(qf, &data->qfunction_source); 109 CeedChkBackend(ierr); 110 CeedDebug256(ceed, 2, "----- Loading QFunction User Source Complete! -----\n"); 111 112 // Register backend functions 113 ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Apply", 114 CeedQFunctionApply_Hip); CeedChkBackend(ierr); 115 ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy", 116 CeedQFunctionDestroy_Hip); CeedChkBackend(ierr); 117 return CEED_ERROR_SUCCESS; 118 } 119 //------------------------------------------------------------------------------ 120