1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-source/cuda/cuda-types.h> 11 #include <cuda.h> 12 13 #include "../cuda/ceed-cuda-common.h" 14 #include "../cuda/ceed-cuda-compile.h" 15 #include "ceed-cuda-ref-qfunction-load.h" 16 #include "ceed-cuda-ref.h" 17 18 //------------------------------------------------------------------------------ 19 // Apply QFunction 20 //------------------------------------------------------------------------------ 21 static int CeedQFunctionApply_Cuda(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) { 22 Ceed ceed; 23 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 24 25 // Build and compile kernel, if not done 26 CeedCallBackend(CeedQFunctionBuildKernel_Cuda_ref(qf)); 27 28 CeedQFunction_Cuda *data; 29 CeedCallBackend(CeedQFunctionGetData(qf, &data)); 30 Ceed_Cuda *ceed_Cuda; 31 CeedCallBackend(CeedGetData(ceed, &ceed_Cuda)); 32 CeedInt num_input_fields, num_output_fields; 33 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 34 35 // Read vectors 36 for (CeedInt i = 0; i < num_input_fields; i++) { 37 CeedCallBackend(CeedVectorGetArrayRead(U[i], CEED_MEM_DEVICE, &data->fields.inputs[i])); 38 } 39 for (CeedInt i = 0; i < num_output_fields; i++) { 40 CeedCallBackend(CeedVectorGetArrayWrite(V[i], CEED_MEM_DEVICE, &data->fields.outputs[i])); 41 } 42 43 // Get context data 44 CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &data->d_c)); 45 46 // Run kernel 47 void *args[] = {&data->d_c, (void *)&Q, &data->fields}; 48 CeedCallBackend(CeedRunKernelAutoblockCuda(ceed, data->QFunction, Q, args)); 49 50 // Restore vectors 51 for (CeedInt i = 0; i < num_input_fields; i++) { 52 CeedCallBackend(CeedVectorRestoreArrayRead(U[i], &data->fields.inputs[i])); 53 } 54 for (CeedInt i = 0; i < num_output_fields; i++) { 55 CeedCallBackend(CeedVectorRestoreArray(V[i], &data->fields.outputs[i])); 56 } 57 58 // Restore context 59 CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &data->d_c)); 60 61 return CEED_ERROR_SUCCESS; 62 } 63 64 //------------------------------------------------------------------------------ 65 // Destroy QFunction 66 //------------------------------------------------------------------------------ 67 static int CeedQFunctionDestroy_Cuda(CeedQFunction qf) { 68 CeedQFunction_Cuda *data; 69 CeedCallBackend(CeedQFunctionGetData(qf, &data)); 70 Ceed ceed; 71 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 72 if (data->module) CeedCallCuda(ceed, cuModuleUnload(data->module)); 73 CeedCallBackend(CeedFree(&data)); 74 75 return CEED_ERROR_SUCCESS; 76 } 77 78 //------------------------------------------------------------------------------ 79 // Set User QFunction 80 //------------------------------------------------------------------------------ 81 static int CeedQFunctionSetCUDAUserFunction_Cuda(CeedQFunction qf, CUfunction f) { 82 CeedQFunction_Cuda *data; 83 CeedCallBackend(CeedQFunctionGetData(qf, &data)); 84 data->QFunction = f; 85 return CEED_ERROR_SUCCESS; 86 } 87 88 //------------------------------------------------------------------------------ 89 // Create QFunction 90 //------------------------------------------------------------------------------ 91 int CeedQFunctionCreate_Cuda(CeedQFunction qf) { 92 Ceed ceed; 93 CeedQFunctionGetCeed(qf, &ceed); 94 CeedQFunction_Cuda *data; 95 CeedCallBackend(CeedCalloc(1, &data)); 96 CeedCallBackend(CeedQFunctionSetData(qf, data)); 97 98 // Read QFunction source 99 CeedCallBackend(CeedQFunctionGetKernelName(qf, &data->qfunction_name)); 100 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source -----\n"); 101 CeedCallBackend(CeedQFunctionLoadSourceToBuffer(qf, &data->qfunction_source)); 102 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading QFunction User Source Complete! -----\n"); 103 104 // Register backend functions 105 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Cuda)); 106 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Cuda)); 107 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "SetCUDAUserFunction", CeedQFunctionSetCUDAUserFunction_Cuda)); 108 return CEED_ERROR_SUCCESS; 109 } 110 111 //------------------------------------------------------------------------------ 112