xref: /libCEED/backends/hip-ref/ceed-hip-ref-qfunction.c (revision 019b76820d7ff306c177822c4e76ffe5939c204b)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <hip/hip_runtime.h>
11 #include <stdio.h>
12 #include <string.h>
13 #include "ceed-hip-ref.h"
14 #include "ceed-hip-ref-qfunction-load.h"
15 #include "../hip/ceed-hip-compile.h"
16 
17 //------------------------------------------------------------------------------
18 // Apply QFunction
19 //------------------------------------------------------------------------------
20 static int CeedQFunctionApply_Hip(CeedQFunction qf, CeedInt Q,
21                                   CeedVector *U, CeedVector *V) {
22   int ierr;
23   Ceed ceed;
24   ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr);
25 
26   // Build and compile kernel, if not done
27   ierr = CeedHipBuildQFunction(qf); CeedChkBackend(ierr);
28 
29   CeedQFunction_Hip *data;
30   ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr);
31   Ceed_Hip *ceed_Hip;
32   ierr = CeedGetData(ceed, &ceed_Hip); CeedChkBackend(ierr);
33   CeedInt num_input_fields, num_output_fields;
34   ierr = CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields);
35   CeedChkBackend(ierr);
36   const int blocksize = ceed_Hip->opt_block_size;
37 
38   // Read vectors
39   for (CeedInt i = 0; i < num_input_fields; i++) {
40     ierr = CeedVectorGetArrayRead(U[i], CEED_MEM_DEVICE, &data->fields.inputs[i]);
41     CeedChkBackend(ierr);
42   }
43   for (CeedInt i = 0; i < num_output_fields; i++) {
44     ierr = CeedVectorGetArrayWrite(V[i], CEED_MEM_DEVICE, &data->fields.outputs[i]);
45     CeedChkBackend(ierr);
46   }
47 
48   // Get context data
49   ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &data->d_c);
50   CeedChkBackend(ierr);
51 
52   // Run kernel
53   void *args[] = {&data->d_c, (void *) &Q, &data->fields};
54   ierr = CeedRunKernelHip(ceed, data->QFunction, CeedDivUpInt(Q, blocksize),
55                           blocksize, args); CeedChkBackend(ierr);
56 
57   // Restore vectors
58   for (CeedInt i = 0; i < num_input_fields; i++) {
59     ierr = CeedVectorRestoreArrayRead(U[i], &data->fields.inputs[i]);
60     CeedChkBackend(ierr);
61   }
62   for (CeedInt i = 0; i < num_output_fields; i++) {
63     ierr = CeedVectorRestoreArray(V[i], &data->fields.outputs[i]);
64     CeedChkBackend(ierr);
65   }
66 
67   // Restore context
68   ierr = CeedQFunctionRestoreInnerContextData(qf, &data->d_c);
69   CeedChkBackend(ierr);
70 
71   return CEED_ERROR_SUCCESS;
72 }
73 
74 //------------------------------------------------------------------------------
75 // Destroy QFunction
76 //------------------------------------------------------------------------------
77 static int CeedQFunctionDestroy_Hip(CeedQFunction qf) {
78   int ierr;
79   CeedQFunction_Hip *data;
80   ierr = CeedQFunctionGetData(qf, &data); CeedChkBackend(ierr);
81   Ceed ceed;
82   ierr = CeedQFunctionGetCeed(qf, &ceed); CeedChkBackend(ierr);
83   if  (data->module)
84     CeedChk_Hip(ceed, hipModuleUnload(data->module));
85   ierr = CeedFree(&data); CeedChkBackend(ierr);
86 
87   return CEED_ERROR_SUCCESS;
88 }
89 
90 //------------------------------------------------------------------------------
91 // Create QFunction
92 //------------------------------------------------------------------------------
93 int CeedQFunctionCreate_Hip(CeedQFunction qf) {
94   int ierr;
95   Ceed ceed;
96   CeedQFunctionGetCeed(qf, &ceed);
97   CeedQFunction_Hip *data;
98   ierr = CeedCalloc(1,&data); CeedChkBackend(ierr);
99   ierr = CeedQFunctionSetData(qf, data); CeedChkBackend(ierr);
100   CeedInt num_input_fields, num_output_fields;
101   ierr = CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields);
102   CeedChkBackend(ierr);
103 
104   // Read QFunction source
105   ierr = CeedQFunctionGetKernelName(qf, &data->qfunction_name);
106   CeedChkBackend(ierr);
107   CeedDebug256(ceed, 2, "----- Loading QFunction User Source -----\n");
108   ierr = CeedQFunctionLoadSourceToBuffer(qf, &data->qfunction_source);
109   CeedChkBackend(ierr);
110   CeedDebug256(ceed, 2, "----- Loading QFunction User Source Complete! -----\n");
111 
112   // Register backend functions
113   ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Apply",
114                                 CeedQFunctionApply_Hip); CeedChkBackend(ierr);
115   ierr = CeedSetBackendFunction(ceed, "QFunction", qf, "Destroy",
116                                 CeedQFunctionDestroy_Hip); CeedChkBackend(ierr);
117   return CEED_ERROR_SUCCESS;
118 }
119 //------------------------------------------------------------------------------
120