xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-ref/ceed-cuda-ref-qfunction.c (revision 3d8e882215d238700cdceb37404f76ca7fa24eaa)
1*3d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*3d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
30d0321e0SJeremy L Thompson //
4*3d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
50d0321e0SJeremy L Thompson //
6*3d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
70d0321e0SJeremy L Thompson 
80d0321e0SJeremy L Thompson #include <ceed/ceed.h>
90d0321e0SJeremy L Thompson #include <ceed/backend.h>
100d0321e0SJeremy L Thompson #include <cuda.h>
110d0321e0SJeremy L Thompson #include <stdio.h>
120d0321e0SJeremy L Thompson #include <string.h>
130d0321e0SJeremy L Thompson #include "ceed-cuda-ref.h"
140d0321e0SJeremy L Thompson #include "ceed-cuda-ref-qfunction-load.h"
150d0321e0SJeremy L Thompson #include "../cuda/ceed-cuda-compile.h"
160d0321e0SJeremy L Thompson 
170d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
180d0321e0SJeremy L Thompson // Apply QFunction
190d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
200d0321e0SJeremy L Thompson static int CeedQFunctionApply_Cuda(CeedQFunction qf, CeedInt Q,
210d0321e0SJeremy L Thompson                                    CeedVector *U, CeedVector *V) {
220d0321e0SJeremy L Thompson   int ierr;
230d0321e0SJeremy L Thompson   Ceed ceed;
240d0321e0SJeremy L Thompson   ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr);
250d0321e0SJeremy L Thompson 
260d0321e0SJeremy L Thompson   // Build and compile kernel, if not done
270d0321e0SJeremy L Thompson   ierr = CeedCudaBuildQFunction(qf); CeedChkBackend(ierr);
280d0321e0SJeremy L Thompson 
290d0321e0SJeremy L Thompson   CeedQFunction_Cuda *data;
300d0321e0SJeremy L Thompson   ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr);
310d0321e0SJeremy L Thompson   Ceed_Cuda *ceed_Cuda;
320d0321e0SJeremy L Thompson   ierr = CeedGetData(ceed, &ceed_Cuda); CeedChkBackend(ierr);
33437930d1SJeremy L Thompson   CeedInt num_input_fields, num_output_fields;
34437930d1SJeremy L Thompson   ierr = CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields);
350d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
360d0321e0SJeremy L Thompson 
370d0321e0SJeremy L Thompson   // Read vectors
38437930d1SJeremy L Thompson   for (CeedInt i = 0; i < num_input_fields; i++) {
390d0321e0SJeremy L Thompson     ierr = CeedVectorGetArrayRead(U[i], CEED_MEM_DEVICE, &data->fields.inputs[i]);
400d0321e0SJeremy L Thompson     CeedChkBackend(ierr);
410d0321e0SJeremy L Thompson   }
42437930d1SJeremy L Thompson   for (CeedInt i = 0; i < num_output_fields; i++) {
430d0321e0SJeremy L Thompson     ierr = CeedVectorGetArrayWrite(V[i], CEED_MEM_DEVICE, &data->fields.outputs[i]);
440d0321e0SJeremy L Thompson     CeedChkBackend(ierr);
450d0321e0SJeremy L Thompson   }
460d0321e0SJeremy L Thompson 
470d0321e0SJeremy L Thompson   // Get context data
48441428dfSJeremy L Thompson   ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &data->d_c);
490d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
500d0321e0SJeremy L Thompson 
510d0321e0SJeremy L Thompson   // Run kernel
520d0321e0SJeremy L Thompson   void *args[] = {&data->d_c, (void *) &Q, &data->fields};
53437930d1SJeremy L Thompson   ierr = CeedRunKernelAutoblockCuda(ceed, data->QFunction, Q, args);
540d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
550d0321e0SJeremy L Thompson 
560d0321e0SJeremy L Thompson   // Restore vectors
57437930d1SJeremy L Thompson   for (CeedInt i = 0; i < num_input_fields; i++) {
580d0321e0SJeremy L Thompson     ierr = CeedVectorRestoreArrayRead(U[i], &data->fields.inputs[i]);
590d0321e0SJeremy L Thompson     CeedChkBackend(ierr);
600d0321e0SJeremy L Thompson   }
61437930d1SJeremy L Thompson   for (CeedInt i = 0; i < num_output_fields; i++) {
620d0321e0SJeremy L Thompson     ierr = CeedVectorRestoreArray(V[i], &data->fields.outputs[i]);
630d0321e0SJeremy L Thompson     CeedChkBackend(ierr);
640d0321e0SJeremy L Thompson   }
650d0321e0SJeremy L Thompson 
660d0321e0SJeremy L Thompson   // Restore context
67441428dfSJeremy L Thompson   ierr = CeedQFunctionRestoreInnerContextData(qf, &data->d_c);
680d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
69441428dfSJeremy L Thompson 
700d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
710d0321e0SJeremy L Thompson }
720d0321e0SJeremy L Thompson 
730d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
740d0321e0SJeremy L Thompson // Destroy QFunction
750d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
760d0321e0SJeremy L Thompson static int CeedQFunctionDestroy_Cuda(CeedQFunction qf) {
770d0321e0SJeremy L Thompson   int ierr;
780d0321e0SJeremy L Thompson   CeedQFunction_Cuda *data;
790d0321e0SJeremy L Thompson   ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr);
800d0321e0SJeremy L Thompson   Ceed ceed;
810d0321e0SJeremy L Thompson   ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr);
820d0321e0SJeremy L Thompson   if (data->module)
830d0321e0SJeremy L Thompson     CeedChk_Cu(ceed, cuModuleUnload(data->module));
840d0321e0SJeremy L Thompson   ierr = CeedFree(&data); CeedChkBackend(ierr);
85437930d1SJeremy L Thompson 
860d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
870d0321e0SJeremy L Thompson }
880d0321e0SJeremy L Thompson 
890d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
900d0321e0SJeremy L Thompson // Set User QFunction
910d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
920d0321e0SJeremy L Thompson static int CeedQFunctionSetCUDAUserFunction_Cuda(CeedQFunction qf,
930d0321e0SJeremy L Thompson     CUfunction f) {
940d0321e0SJeremy L Thompson   int ierr;
950d0321e0SJeremy L Thompson   CeedQFunction_Cuda *data;
960d0321e0SJeremy L Thompson   ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr);
97437930d1SJeremy L Thompson   data->QFunction = f;
98437930d1SJeremy L Thompson 
990d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1000d0321e0SJeremy L Thompson }
1010d0321e0SJeremy L Thompson 
1020d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1030d0321e0SJeremy L Thompson // Create QFunction
1040d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1050d0321e0SJeremy L Thompson int CeedQFunctionCreate_Cuda(CeedQFunction qf) {
1060d0321e0SJeremy L Thompson   int ierr;
1070d0321e0SJeremy L Thompson   Ceed ceed;
1080d0321e0SJeremy L Thompson   CeedQFunctionGetCeed(qf, &ceed);
1090d0321e0SJeremy L Thompson   CeedQFunction_Cuda *data;
1100d0321e0SJeremy L Thompson   ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);
1110d0321e0SJeremy L Thompson   ierr = CeedQFunctionSetData(qf, data); CeedChkBackend(ierr);
1120d0321e0SJeremy L Thompson 
1130d0321e0SJeremy L Thompson   // Read QFunction source
114437930d1SJeremy L Thompson   ierr = CeedQFunctionGetKernelName(qf, &data->qfunction_name);
1150d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
11646dc0734SJeremy L Thompson   CeedDebug256(ceed, 2, "----- Loading QFunction User Source -----\n");
117437930d1SJeremy L Thompson   ierr = CeedQFunctionLoadSourceToBuffer(qf, &data->qfunction_source);
1180d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
11946dc0734SJeremy L Thompson   CeedDebug256(ceed, 2, "----- Loading QFunction User Source Complete! -----\n");
1200d0321e0SJeremy L Thompson 
1210d0321e0SJeremy L Thompson   // Register backend functions
1220d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Apply",
1230d0321e0SJeremy L Thompson                                 CeedQFunctionApply_Cuda); CeedChkBackend(ierr);
1240d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy",
1250d0321e0SJeremy L Thompson                                 CeedQFunctionDestroy_Cuda); CeedChkBackend(ierr);
1260d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "SetCUDAUserFunction",
1270d0321e0SJeremy L Thompson                                 CeedQFunctionSetCUDAUserFunction_Cuda);
1280d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
1290d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1300d0321e0SJeremy L Thompson }
1310d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
132