xref: /libCEED/rust/libceed-sys/c-src/backends/hip-ref/ceed-hip-ref-qfunction.c (revision 3d8e882215d238700cdceb37404f76ca7fa24eaa)
1*3d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*3d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
30d0321e0SJeremy L Thompson //
4*3d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
50d0321e0SJeremy L Thompson //
6*3d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
70d0321e0SJeremy L Thompson 
80d0321e0SJeremy L Thompson #include <ceed/ceed.h>
90d0321e0SJeremy L Thompson #include <ceed/backend.h>
100d0321e0SJeremy L Thompson #include <hip/hip_runtime.h>
110d0321e0SJeremy L Thompson #include <stdio.h>
120d0321e0SJeremy L Thompson #include <string.h>
130d0321e0SJeremy L Thompson #include "ceed-hip-ref.h"
140d0321e0SJeremy L Thompson #include "ceed-hip-ref-qfunction-load.h"
150d0321e0SJeremy L Thompson #include "../hip/ceed-hip-compile.h"
160d0321e0SJeremy L Thompson 
170d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
180d0321e0SJeremy L Thompson // Apply QFunction
190d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
200d0321e0SJeremy L Thompson static int CeedQFunctionApply_Hip(CeedQFunction qf, CeedInt Q,
210d0321e0SJeremy L Thompson                                   CeedVector *U, CeedVector *V) {
220d0321e0SJeremy L Thompson   int ierr;
230d0321e0SJeremy L Thompson   Ceed ceed;
240d0321e0SJeremy L Thompson   ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr);
250d0321e0SJeremy L Thompson 
260d0321e0SJeremy L Thompson   // Build and compile kernel, if not done
270d0321e0SJeremy L Thompson   ierr = CeedHipBuildQFunction(qf); CeedChkBackend(ierr);
280d0321e0SJeremy L Thompson 
290d0321e0SJeremy L Thompson   CeedQFunction_Hip *data;
300d0321e0SJeremy L Thompson   ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr);
310d0321e0SJeremy L Thompson   Ceed_Hip *ceed_Hip;
320d0321e0SJeremy L Thompson   ierr = CeedGetData(ceed, &ceed_Hip); CeedChkBackend(ierr);
33437930d1SJeremy L Thompson   CeedInt num_input_fields, num_output_fields;
34437930d1SJeremy L Thompson   ierr = CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields);
350d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
360d0321e0SJeremy L Thompson   const int blocksize = ceed_Hip->opt_block_size;
370d0321e0SJeremy L Thompson 
380d0321e0SJeremy L Thompson   // Read vectors
39437930d1SJeremy L Thompson   for (CeedInt i = 0; i < num_input_fields; i++) {
400d0321e0SJeremy L Thompson     ierr = CeedVectorGetArrayRead(U[i], CEED_MEM_DEVICE, &data->fields.inputs[i]);
410d0321e0SJeremy L Thompson     CeedChkBackend(ierr);
420d0321e0SJeremy L Thompson   }
43437930d1SJeremy L Thompson   for (CeedInt i = 0; i < num_output_fields; i++) {
440d0321e0SJeremy L Thompson     ierr = CeedVectorGetArrayWrite(V[i], CEED_MEM_DEVICE, &data->fields.outputs[i]);
450d0321e0SJeremy L Thompson     CeedChkBackend(ierr);
460d0321e0SJeremy L Thompson   }
470d0321e0SJeremy L Thompson 
480d0321e0SJeremy L Thompson   // Get context data
49441428dfSJeremy L Thompson   ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &data->d_c);
500d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
510d0321e0SJeremy L Thompson 
520d0321e0SJeremy L Thompson   // Run kernel
530d0321e0SJeremy L Thompson   void *args[] = {&data->d_c, (void *) &Q, &data->fields};
54437930d1SJeremy L Thompson   ierr = CeedRunKernelHip(ceed, data->QFunction, CeedDivUpInt(Q, blocksize),
550d0321e0SJeremy L Thompson                           blocksize, args); CeedChkBackend(ierr);
560d0321e0SJeremy L Thompson 
570d0321e0SJeremy L Thompson   // Restore vectors
58437930d1SJeremy L Thompson   for (CeedInt i = 0; i < num_input_fields; i++) {
590d0321e0SJeremy L Thompson     ierr = CeedVectorRestoreArrayRead(U[i], &data->fields.inputs[i]);
600d0321e0SJeremy L Thompson     CeedChkBackend(ierr);
610d0321e0SJeremy L Thompson   }
62437930d1SJeremy L Thompson   for (CeedInt i = 0; i < num_output_fields; i++) {
630d0321e0SJeremy L Thompson     ierr = CeedVectorRestoreArray(V[i], &data->fields.outputs[i]);
640d0321e0SJeremy L Thompson     CeedChkBackend(ierr);
650d0321e0SJeremy L Thompson   }
660d0321e0SJeremy L Thompson 
670d0321e0SJeremy L Thompson   // Restore context
68441428dfSJeremy L Thompson   ierr = CeedQFunctionRestoreInnerContextData(qf, &data->d_c);
690d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
70441428dfSJeremy L Thompson 
710d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
720d0321e0SJeremy L Thompson }
730d0321e0SJeremy L Thompson 
740d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
750d0321e0SJeremy L Thompson // Destroy QFunction
760d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
770d0321e0SJeremy L Thompson static int CeedQFunctionDestroy_Hip(CeedQFunction qf) {
780d0321e0SJeremy L Thompson   int ierr;
790d0321e0SJeremy L Thompson   CeedQFunction_Hip *data;
800d0321e0SJeremy L Thompson   ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr);
810d0321e0SJeremy L Thompson   Ceed ceed;
820d0321e0SJeremy L Thompson   ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr);
830d0321e0SJeremy L Thompson   if  (data->module)
840d0321e0SJeremy L Thompson     CeedChk_Hip(ceed, hipModuleUnload(data->module));
850d0321e0SJeremy L Thompson   ierr = CeedFree(&data); CeedChkBackend(ierr);
86437930d1SJeremy L Thompson 
870d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
880d0321e0SJeremy L Thompson }
890d0321e0SJeremy L Thompson 
900d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
910d0321e0SJeremy L Thompson // Create QFunction
920d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
930d0321e0SJeremy L Thompson int CeedQFunctionCreate_Hip(CeedQFunction qf) {
940d0321e0SJeremy L Thompson   int ierr;
950d0321e0SJeremy L Thompson   Ceed ceed;
960d0321e0SJeremy L Thompson   CeedQFunctionGetCeed(qf, &ceed);
970d0321e0SJeremy L Thompson   CeedQFunction_Hip *data;
980d0321e0SJeremy L Thompson   ierr = CeedCalloc(1,&data); CeedChkBackend(ierr);
990d0321e0SJeremy L Thompson   ierr = CeedQFunctionSetData(qf, data); CeedChkBackend(ierr);
100437930d1SJeremy L Thompson   CeedInt num_input_fields, num_output_fields;
101437930d1SJeremy L Thompson   ierr = CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields);
1020d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
1030d0321e0SJeremy L Thompson 
1040d0321e0SJeremy L Thompson   // Read QFunction source
105437930d1SJeremy L Thompson   ierr = CeedQFunctionGetKernelName(qf, &data->qfunction_name);
1060d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
10746dc0734SJeremy L Thompson   CeedDebug256(ceed, 2, "----- Loading QFunction User Source -----\n");
108437930d1SJeremy L Thompson   ierr = CeedQFunctionLoadSourceToBuffer(qf, &data->qfunction_source);
1090d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
11046dc0734SJeremy L Thompson   CeedDebug256(ceed, 2, "----- Loading QFunction User Source Complete! -----\n");
1110d0321e0SJeremy L Thompson 
1120d0321e0SJeremy L Thompson   // Register backend functions
1130d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Apply",
1140d0321e0SJeremy L Thompson                                 CeedQFunctionApply_Hip); CeedChkBackend(ierr);
1150d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy",
1160d0321e0SJeremy L Thompson                                 CeedQFunctionDestroy_Hip); CeedChkBackend(ierr);
1170d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1180d0321e0SJeremy L Thompson }
1190d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
120