xref: /libCEED/backends/hip-gen/ceed-hip-gen-operator.c (revision 8c03e814a8aedd48736bf8454f3df41e37fe2fcc)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-source/hip/hip-types.h>
11 #include <stddef.h>
12 #include <hip/hiprtc.h>
13 
14 #include "../hip/ceed-hip-common.h"
15 #include "../hip/ceed-hip-compile.h"
16 #include "ceed-hip-gen-operator-build.h"
17 #include "ceed-hip-gen.h"
18 
19 //------------------------------------------------------------------------------
20 // Destroy operator
21 //------------------------------------------------------------------------------
22 static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
23   Ceed                  ceed;
24   CeedOperator_Hip_gen *impl;
25 
26   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
27   CeedCallBackend(CeedOperatorGetData(op, &impl));
28   if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)impl->points.num_per_elem));
29   CeedCallBackend(CeedFree(&impl));
30   CeedCallBackend(CeedDestroy(&ceed));
31   return CEED_ERROR_SUCCESS;
32 }
33 
34 //------------------------------------------------------------------------------
35 // Apply and add to output
36 //------------------------------------------------------------------------------
37 static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
38   bool                   is_at_points, is_tensor, is_good_run = true;
39   Ceed                   ceed;
40   CeedInt                num_elem, num_input_fields, num_output_fields;
41   CeedEvalMode           eval_mode;
42   CeedVector             output_vecs[CEED_FIELD_MAX] = {NULL};
43   CeedQFunctionField    *qf_input_fields, *qf_output_fields;
44   CeedQFunction_Hip_gen *qf_data;
45   CeedQFunction          qf;
46   CeedOperatorField     *op_input_fields, *op_output_fields;
47   CeedOperator_Hip_gen  *data;
48 
49   // Creation of the operator
50   {
51     bool is_good_build = false;
52 
53     CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_good_build));
54     if (!is_good_build) {
55       CeedOperator op_fallback;
56 
57       CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator due to code generation issue");
58       CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
59       CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
60       return CEED_ERROR_SUCCESS;
61     }
62   }
63 
64   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
65   CeedCallBackend(CeedOperatorGetData(op, &data));
66   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
67   CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
68   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
69   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
70   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
71 
72   // Input vectors
73   for (CeedInt i = 0; i < num_input_fields; i++) {
74     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
75     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
76       data->fields.inputs[i] = NULL;
77     } else {
78       bool       is_active;
79       CeedVector vec;
80 
81       // Get input vector
82       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
83       is_active = vec == CEED_VECTOR_ACTIVE;
84       if (is_active) vec = input_vec;
85       CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
86       if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
87     }
88   }
89 
90   // Output vectors
91   for (CeedInt i = 0; i < num_output_fields; i++) {
92     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
93     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
94       data->fields.outputs[i] = NULL;
95     } else {
96       bool       is_active;
97       CeedVector vec;
98 
99       // Get output vector
100       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
101       is_active = vec == CEED_VECTOR_ACTIVE;
102       if (is_active) vec = output_vec;
103       output_vecs[i] = vec;
104       // Check for multiple output modes
105       CeedInt index = -1;
106 
107       for (CeedInt j = 0; j < i; j++) {
108         if (vec == output_vecs[j]) {
109           index = j;
110           break;
111         }
112       }
113       if (index == -1) {
114         CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i]));
115       } else {
116         data->fields.outputs[i] = data->fields.outputs[index];
117       }
118       if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
119     }
120   }
121 
122   // Point coordinates, if needed
123   CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points));
124   if (is_at_points) {
125     // Coords
126     CeedVector vec;
127 
128     CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
129     CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
130     CeedCallBackend(CeedVectorDestroy(&vec));
131 
132     // Points per elem
133     if (num_elem != data->points.num_elem) {
134       CeedInt            *points_per_elem;
135       const CeedInt       num_bytes   = num_elem * sizeof(CeedInt);
136       CeedElemRestriction rstr_points = NULL;
137 
138       data->points.num_elem = num_elem;
139       CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
140       CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
141       for (CeedInt e = 0; e < num_elem; e++) {
142         CeedInt num_points_elem;
143 
144         CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
145         points_per_elem[e] = num_points_elem;
146       }
147       if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem));
148       CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
149       CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
150       CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
151       CeedCallBackend(CeedFree(&points_per_elem));
152     }
153   }
154 
155   // Get context data
156   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
157 
158   // Apply operator
159   void         *opargs[]  = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points};
160   const CeedInt dim       = data->dim;
161   const CeedInt Q_1d      = data->Q_1d;
162   const CeedInt P_1d      = data->max_P_1d;
163   const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
164 
165   CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor));
166   CeedInt block_sizes[3] = {thread_1d, ((!is_tensor || dim == 1) ? 1 : thread_1d), -1};
167 
168   if (is_tensor) {
169     CeedCallBackend(BlockGridCalculate_Hip_gen(is_tensor ? dim : 1, num_elem, P_1d, Q_1d, block_sizes));
170   } else {
171     CeedInt elems_per_block = 64 * thread_1d > 256 ? 256 / thread_1d : 64;
172 
173     elems_per_block = elems_per_block > 0 ? elems_per_block : 1;
174     block_sizes[2]  = elems_per_block;
175   }
176   if (dim == 1 || !is_tensor) {
177     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
178     CeedInt sharedMem = block_sizes[2] * thread_1d * sizeof(CeedScalar);
179 
180     CeedCallBackend(
181         CeedTryRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, &is_good_run, opargs));
182   } else if (dim == 2) {
183     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
184     CeedInt sharedMem = block_sizes[2] * thread_1d * thread_1d * sizeof(CeedScalar);
185 
186     CeedCallBackend(
187         CeedTryRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, &is_good_run, opargs));
188   } else if (dim == 3) {
189     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
190     CeedInt sharedMem = block_sizes[2] * thread_1d * thread_1d * sizeof(CeedScalar);
191 
192     CeedCallBackend(
193         CeedTryRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, &is_good_run, opargs));
194   }
195 
196   // Restore input arrays
197   for (CeedInt i = 0; i < num_input_fields; i++) {
198     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
199     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
200     } else {
201       bool       is_active;
202       CeedVector vec;
203 
204       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
205       is_active = vec == CEED_VECTOR_ACTIVE;
206       if (is_active) vec = input_vec;
207       CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
208       if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
209     }
210   }
211 
212   // Restore output arrays
213   for (CeedInt i = 0; i < num_output_fields; i++) {
214     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
215     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
216     } else {
217       bool       is_active;
218       CeedVector vec;
219 
220       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
221       is_active = vec == CEED_VECTOR_ACTIVE;
222       if (is_active) vec = output_vec;
223       // Check for multiple output modes
224       CeedInt index = -1;
225 
226       for (CeedInt j = 0; j < i; j++) {
227         if (vec == output_vecs[j]) {
228           index = j;
229           break;
230         }
231       }
232       if (index == -1) {
233         CeedCallBackend(CeedVectorRestoreArray(vec, &data->fields.outputs[i]));
234       }
235       if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
236     }
237   }
238 
239   // Restore point coordinates, if needed
240   if (is_at_points) {
241     CeedVector vec;
242 
243     CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
244     CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
245     CeedCallBackend(CeedVectorDestroy(&vec));
246   }
247 
248   // Restore context data
249   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
250 
251   // Cleanup
252   CeedCallBackend(CeedDestroy(&ceed));
253   CeedCallBackend(CeedQFunctionDestroy(&qf));
254 
255   // Fallback if run was bad (out of resources)
256   if (!is_good_run) {
257     CeedOperator op_fallback;
258 
259     data->use_fallback = true;
260     CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator due to kernel execution issue");
261     CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
262     CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
263     return CEED_ERROR_SUCCESS;
264   }
265   return CEED_ERROR_SUCCESS;
266 }
267 
268 //------------------------------------------------------------------------------
269 // Create operator
270 //------------------------------------------------------------------------------
271 int CeedOperatorCreate_Hip_gen(CeedOperator op) {
272   Ceed                  ceed;
273   CeedOperator_Hip_gen *impl;
274 
275   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
276   CeedCallBackend(CeedCalloc(1, &impl));
277   CeedCallBackend(CeedOperatorSetData(op, impl));
278   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen));
279   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen));
280   CeedCallBackend(CeedDestroy(&ceed));
281   return CEED_ERROR_SUCCESS;
282 }
283 
284 //------------------------------------------------------------------------------
285