1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-source/cuda/cuda-types.h>
11 #include <cuda.h>
12
13 #include "../cuda/ceed-cuda-common.h"
14 #include "../cuda/ceed-cuda-compile.h"
15 #include "ceed-cuda-ref-qfunction-load.h"
16 #include "ceed-cuda-ref.h"
17
18 //------------------------------------------------------------------------------
19 // Apply QFunction
20 //------------------------------------------------------------------------------
// Apply a QFunction through its compiled CUDA kernel.
//
// Parameters:
//   qf - QFunction whose kernel is launched
//   Q  - point count; passed to the kernel as its second argument and to the launcher
//   U  - input vectors, indexed by input field
//   V  - output vectors, indexed by output field
//
// Returns CEED_ERROR_SUCCESS when every step succeeds; a failing step returns early through CeedCallBackend.
static int CeedQFunctionApply_Cuda(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) {
  // Borrowed handle: nothing to release, so an early error return cannot leak a Ceed reference
  Ceed                ceed = CeedQFunctionReturnCeed(qf);
  CeedInt             num_input_fields, num_output_fields;
  CeedQFunction_Cuda *data;

  // Build and compile kernel, if not done
  CeedCallBackend(CeedQFunctionBuildKernel_Cuda_ref(qf));

  CeedCallBackend(CeedQFunctionGetData(qf, &data));
  CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));

  // Read vectors: collect device pointers for every input and output field
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedCallBackend(CeedVectorGetArrayRead(U[i], CEED_MEM_DEVICE, &data->fields.inputs[i]));
  }
  for (CeedInt i = 0; i < num_output_fields; i++) {
    CeedCallBackend(CeedVectorGetArrayWrite(V[i], CEED_MEM_DEVICE, &data->fields.outputs[i]));
  }

  // Get context data
  CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &data->d_c));

  // Run kernel; argument order is context pointer, Q, then the field pointer struct
  void *args[] = {&data->d_c, (void *)&Q, &data->fields};

  CeedCallBackend(CeedRunKernelAutoblockCuda(ceed, data->QFunction, Q, args));

  // Restore vectors
  for (CeedInt i = 0; i < num_input_fields; i++) {
    CeedCallBackend(CeedVectorRestoreArrayRead(U[i], &data->fields.inputs[i]));
  }
  for (CeedInt i = 0; i < num_output_fields; i++) {
    CeedCallBackend(CeedVectorRestoreArray(V[i], &data->fields.outputs[i]));
  }

  // Restore context
  CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &data->d_c));
  return CEED_ERROR_SUCCESS;
}
64
65 //------------------------------------------------------------------------------
66 // Destroy QFunction
67 //------------------------------------------------------------------------------
// Tear down the backend data attached to a QFunction.
//
// Unloads the CUDA module when a module handle is set, then frees the backend data struct.
// Returns CEED_ERROR_SUCCESS on success; a failing step returns early through the call macros.
static int CeedQFunctionDestroy_Cuda(CeedQFunction qf) {
  CeedQFunction_Cuda *impl;

  CeedCallBackend(CeedQFunctionGetData(qf, &impl));

  // Only unload when a module handle is present
  if (impl->module) {
    CeedCallCuda(CeedQFunctionReturnCeed(qf), cuModuleUnload(impl->module));
  }

  CeedCallBackend(CeedFree(&impl));
  return CEED_ERROR_SUCCESS;
}
76
77 //------------------------------------------------------------------------------
78 // Set User QFunction
79 //------------------------------------------------------------------------------
// Store a caller-supplied CUDA function handle as this QFunction's kernel.
//
// Parameters:
//   qf - QFunction whose backend data is updated
//   f  - CUDA function handle written to the backend data's QFunction member
//
// Returns CEED_ERROR_SUCCESS on success.
static int CeedQFunctionSetCUDAUserFunction_Cuda(CeedQFunction qf, CUfunction f) {
  CeedQFunction_Cuda *impl;

  CeedCallBackend(CeedQFunctionGetData(qf, &impl));
  impl->QFunction = f;
  return CEED_ERROR_SUCCESS;
}
87
88 //------------------------------------------------------------------------------
89 // Create QFunction
90 //------------------------------------------------------------------------------
// Set up the CUDA reference backend for a QFunction.
//
// Allocates zero-initialized backend data, attaches it to the QFunction, records the kernel name,
// and registers the Apply, Destroy, and SetCUDAUserFunction backend functions.
//
// Returns CEED_ERROR_SUCCESS on success; a failing step returns early through CeedCallBackend.
int CeedQFunctionCreate_Cuda(CeedQFunction qf) {
  // Borrowed handle: nothing to release, so an early error return cannot leak a Ceed reference
  Ceed                ceed = CeedQFunctionReturnCeed(qf);
  CeedQFunction_Cuda *data;

  CeedCallBackend(CeedCalloc(1, &data));
  CeedCallBackend(CeedQFunctionSetData(qf, data));

  CeedCallBackend(CeedQFunctionGetKernelName(qf, &data->qfunction_name));

  // Register backend functions
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunction", qf, "SetCUDAUserFunction", CeedQFunctionSetCUDAUserFunction_Cuda));
  return CEED_ERROR_SUCCESS;
}
108
109 //------------------------------------------------------------------------------
110