1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include <cuda.h> 11 #include <stdio.h> 12 #include <string.h> 13 #include "ceed-cuda-ref.h" 14 #include "ceed-cuda-ref-qfunction-load.h" 15 #include "../cuda/ceed-cuda-compile.h" 16 17 //------------------------------------------------------------------------------ 18 // Apply QFunction 19 //------------------------------------------------------------------------------ 20 static int CeedQFunctionApply_Cuda(CeedQFunction qf, CeedInt Q, 21 CeedVector *U, CeedVector *V) { 22 int ierr; 23 Ceed ceed; 24 ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr); 25 26 // Build and compile kernel, if not done 27 ierr = CeedCudaBuildQFunction(qf); CeedChkBackend(ierr); 28 29 CeedQFunction_Cuda *data; 30 ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr); 31 Ceed_Cuda *ceed_Cuda; 32 ierr = CeedGetData(ceed, &ceed_Cuda); CeedChkBackend(ierr); 33 CeedInt num_input_fields, num_output_fields; 34 ierr = CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields); 35 CeedChkBackend(ierr); 36 37 // Read vectors 38 for (CeedInt i = 0; i < num_input_fields; i++) { 39 ierr = CeedVectorGetArrayRead(U[i], CEED_MEM_DEVICE, &data->fields.inputs[i]); 40 CeedChkBackend(ierr); 41 } 42 for (CeedInt i = 0; i < num_output_fields; i++) { 43 ierr = CeedVectorGetArrayWrite(V[i], CEED_MEM_DEVICE, &data->fields.outputs[i]); 44 CeedChkBackend(ierr); 45 } 46 47 // Get context data 48 ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &data->d_c); 49 CeedChkBackend(ierr); 50 51 // Run kernel 52 void *args[] = {&data->d_c, (void *) &Q, &data->fields}; 53 ierr = CeedRunKernelAutoblockCuda(ceed, data->QFunction, Q, args); 54 CeedChkBackend(ierr); 55 56 // Restore vectors 57 for (CeedInt i = 0; i < num_input_fields; i++) { 58 ierr = CeedVectorRestoreArrayRead(U[i], &data->fields.inputs[i]); 59 CeedChkBackend(ierr); 60 } 61 for (CeedInt i = 0; i < num_output_fields; i++) { 62 ierr = CeedVectorRestoreArray(V[i], &data->fields.outputs[i]); 63 CeedChkBackend(ierr); 64 } 65 66 // Restore context 67 ierr = CeedQFunctionRestoreInnerContextData(qf, &data->d_c); 68 CeedChkBackend(ierr); 69 70 return CEED_ERROR_SUCCESS; 71 } 72 73 //------------------------------------------------------------------------------ 74 // Destroy QFunction 75 //------------------------------------------------------------------------------ 76 static int CeedQFunctionDestroy_Cuda(CeedQFunction qf) { 77 int ierr; 78 CeedQFunction_Cuda *data; 79 ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr); 80 Ceed ceed; 81 ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr); 82 if (data->module) 83 CeedChk_Cu(ceed, cuModuleUnload(data->module)); 84 ierr = CeedFree(&data); CeedChkBackend(ierr); 85 86 return CEED_ERROR_SUCCESS; 87 } 88 89 //------------------------------------------------------------------------------ 90 // Set User QFunction 91 //------------------------------------------------------------------------------ 92 static int CeedQFunctionSetCUDAUserFunction_Cuda(CeedQFunction qf, 93 CUfunction f) { 94 int ierr; 95 CeedQFunction_Cuda *data; 96 ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr); 97 data->QFunction = f; 98 99 return CEED_ERROR_SUCCESS; 100 } 101 102 //------------------------------------------------------------------------------ 103 // Create QFunction 104 //------------------------------------------------------------------------------ 105 int CeedQFunctionCreate_Cuda(CeedQFunction qf) { 106 int ierr; 107 Ceed ceed; 108 CeedQFunctionGetCeed(qf, &ceed); 109 CeedQFunction_Cuda *data; 110 ierr = CeedCalloc(1, &data); CeedChkBackend(ierr); 111 ierr = CeedQFunctionSetData(qf, data); CeedChkBackend(ierr); 112 113 // Read QFunction source 114 ierr = CeedQFunctionGetKernelName(qf, &data->qfunction_name); 115 CeedChkBackend(ierr); 116 CeedDebug256(ceed, 2, "----- Loading QFunction User Source -----\n"); 117 ierr = CeedQFunctionLoadSourceToBuffer(qf, &data->qfunction_source); 118 CeedChkBackend(ierr); 119 CeedDebug256(ceed, 2, "----- Loading QFunction User Source Complete! -----\n"); 120 121 // Register backend functions 122 ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Apply", 123 CeedQFunctionApply_Cuda); CeedChkBackend(ierr); 124 ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy", 125 CeedQFunctionDestroy_Cuda); CeedChkBackend(ierr); 126 ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "SetCUDAUserFunction", 127 CeedQFunctionSetCUDAUserFunction_Cuda); 128 CeedChkBackend(ierr); 129 return CEED_ERROR_SUCCESS; 130 } 131 //------------------------------------------------------------------------------ 132