xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-qfunction.c (revision 381e65939e85104561074440c4dd3dd99bd0efff)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda.h>
11 #include <stdio.h>
12 #include <string.h>
13 #include "ceed-cuda-ref.h"
14 #include "ceed-cuda-ref-qfunction-load.h"
15 #include "../cuda/ceed-cuda-compile.h"
16 
17 //------------------------------------------------------------------------------
18 // Apply QFunction
19 //------------------------------------------------------------------------------
20 static int CeedQFunctionApply_Cuda(CeedQFunction qf, CeedInt Q,
21                                    CeedVector *U, CeedVector *V) {
22   int ierr;
23   Ceed ceed;
24   ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr);
25 
26   // Build and compile kernel, if not done
27   ierr = CeedCudaBuildQFunction(qf); CeedChkBackend(ierr);
28 
29   CeedQFunction_Cuda *data;
30   ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr);
31   Ceed_Cuda *ceed_Cuda;
32   ierr = CeedGetData(ceed, &ceed_Cuda); CeedChkBackend(ierr);
33   CeedInt num_input_fields, num_output_fields;
34   ierr = CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields);
35   CeedChkBackend(ierr);
36 
37   // Read vectors
38   for (CeedInt i = 0; i < num_input_fields; i++) {
39     ierr = CeedVectorGetArrayRead(U[i], CEED_MEM_DEVICE, &data->fields.inputs[i]);
40     CeedChkBackend(ierr);
41   }
42   for (CeedInt i = 0; i < num_output_fields; i++) {
43     ierr = CeedVectorGetArrayWrite(V[i], CEED_MEM_DEVICE, &data->fields.outputs[i]);
44     CeedChkBackend(ierr);
45   }
46 
47   // Get context data
48   ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &data->d_c);
49   CeedChkBackend(ierr);
50 
51   // Run kernel
52   void *args[] = {&data->d_c, (void *) &Q, &data->fields};
53   ierr = CeedRunKernelAutoblockCuda(ceed, data->QFunction, Q, args);
54   CeedChkBackend(ierr);
55 
56   // Restore vectors
57   for (CeedInt i = 0; i < num_input_fields; i++) {
58     ierr = CeedVectorRestoreArrayRead(U[i], &data->fields.inputs[i]);
59     CeedChkBackend(ierr);
60   }
61   for (CeedInt i = 0; i < num_output_fields; i++) {
62     ierr = CeedVectorRestoreArray(V[i], &data->fields.outputs[i]);
63     CeedChkBackend(ierr);
64   }
65 
66   // Restore context
67   ierr = CeedQFunctionRestoreInnerContextData(qf, &data->d_c);
68   CeedChkBackend(ierr);
69 
70   return CEED_ERROR_SUCCESS;
71 }
72 
73 //------------------------------------------------------------------------------
74 // Destroy QFunction
75 //------------------------------------------------------------------------------
76 static int CeedQFunctionDestroy_Cuda(CeedQFunction qf) {
77   int ierr;
78   CeedQFunction_Cuda *data;
79   ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr);
80   Ceed ceed;
81   ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr);
82   if (data->module)
83     CeedChk_Cu(ceed, cuModuleUnload(data->module));
84   ierr = CeedFree(&data); CeedChkBackend(ierr);
85 
86   return CEED_ERROR_SUCCESS;
87 }
88 
89 //------------------------------------------------------------------------------
90 // Set User QFunction
91 //------------------------------------------------------------------------------
92 static int CeedQFunctionSetCUDAUserFunction_Cuda(CeedQFunction qf,
93     CUfunction f) {
94   int ierr;
95   CeedQFunction_Cuda *data;
96   ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr);
97   data->QFunction = f;
98 
99   return CEED_ERROR_SUCCESS;
100 }
101 
102 //------------------------------------------------------------------------------
103 // Create QFunction
104 //------------------------------------------------------------------------------
105 int CeedQFunctionCreate_Cuda(CeedQFunction qf) {
106   int ierr;
107   Ceed ceed;
108   CeedQFunctionGetCeed(qf, &ceed);
109   CeedQFunction_Cuda *data;
110   ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);
111   ierr = CeedQFunctionSetData(qf, data); CeedChkBackend(ierr);
112 
113   // Read QFunction source
114   ierr = CeedQFunctionGetKernelName(qf, &data->qfunction_name);
115   CeedChkBackend(ierr);
116   CeedDebug256(ceed, 2, "----- Loading QFunction User Source -----\n");
117   ierr = CeedQFunctionLoadSourceToBuffer(qf, &data->qfunction_source);
118   CeedChkBackend(ierr);
119   CeedDebug256(ceed, 2, "----- Loading QFunction User Source Complete! -----\n");
120 
121   // Register backend functions
122   ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Apply",
123                                 CeedQFunctionApply_Cuda); CeedChkBackend(ierr);
124   ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy",
125                                 CeedQFunctionDestroy_Cuda); CeedChkBackend(ierr);
126   ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "SetCUDAUserFunction",
127                                 CeedQFunctionSetCUDAUserFunction_Cuda);
128   CeedChkBackend(ierr);
129   return CEED_ERROR_SUCCESS;
130 }
131 //------------------------------------------------------------------------------
132