xref: /libCEED/backends/hip-gen/ceed-hip-gen-operator.c (revision 8b7d3340714d052d4785a048802399d76a15146b)
1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-source/hip/hip-types.h>
11 #include <stddef.h>
12 #include <hip/hiprtc.h>
13 
14 #include "../hip/ceed-hip-common.h"
15 #include "../hip/ceed-hip-compile.h"
16 #include "ceed-hip-gen-operator-build.h"
17 #include "ceed-hip-gen.h"
18 
19 //------------------------------------------------------------------------------
20 // Destroy operator
21 //------------------------------------------------------------------------------
22 static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
23   Ceed                  ceed;
24   CeedOperator_Hip_gen *impl;
25 
26   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
27   CeedCallBackend(CeedOperatorGetData(op, &impl));
28   if (impl->module) CeedCallHip(ceed, hipModuleUnload(impl->module));
29   if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)impl->points.num_per_elem));
30   CeedCallBackend(CeedFree(&impl));
31   CeedCallBackend(CeedDestroy(&ceed));
32   return CEED_ERROR_SUCCESS;
33 }
34 
35 //------------------------------------------------------------------------------
36 // Apply and add to output
37 //------------------------------------------------------------------------------
38 static int CeedOperatorApplyAddCore_Hip_gen(CeedOperator op, hipStream_t stream, const CeedScalar *input_arr, CeedScalar *output_arr,
39                                             bool *is_run_good, CeedRequest *request) {
40   bool                   is_at_points, is_tensor;
41   Ceed                   ceed;
42   CeedInt                num_elem, num_input_fields, num_output_fields;
43   CeedEvalMode           eval_mode;
44   CeedQFunctionField    *qf_input_fields, *qf_output_fields;
45   CeedQFunction_Hip_gen *qf_data;
46   CeedQFunction          qf;
47   CeedOperatorField     *op_input_fields, *op_output_fields;
48   CeedOperator_Hip_gen  *data;
49 
50   // Creation of the operator
51   CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, is_run_good));
52   if (!(*is_run_good)) return CEED_ERROR_SUCCESS;
53 
54   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
55   CeedCallBackend(CeedOperatorGetData(op, &data));
56   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
57   CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
58   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
59   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
60   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
61 
62   // Input vectors
63   for (CeedInt i = 0; i < num_input_fields; i++) {
64     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
65     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
66       data->fields.inputs[i] = NULL;
67     } else {
68       bool       is_active;
69       CeedVector vec;
70 
71       // Get input vector
72       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
73       is_active = vec == CEED_VECTOR_ACTIVE;
74       if (is_active) data->fields.inputs[i] = input_arr;
75       else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
76       CeedCallBackend(CeedVectorDestroy(&vec));
77     }
78   }
79 
80   // Output vectors
81   for (CeedInt i = 0; i < num_output_fields; i++) {
82     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
83     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
84       data->fields.outputs[i] = NULL;
85     } else {
86       bool       is_active;
87       CeedVector vec;
88 
89       // Get output vector
90       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
91       is_active = vec == CEED_VECTOR_ACTIVE;
92       if (is_active) data->fields.outputs[i] = output_arr;
93       else CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i]));
94       CeedCallBackend(CeedVectorDestroy(&vec));
95     }
96   }
97 
98   // Point coordinates, if needed
99   CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points));
100   if (is_at_points) {
101     // Coords
102     CeedVector vec;
103 
104     CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
105     CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
106     CeedCallBackend(CeedVectorDestroy(&vec));
107 
108     // Points per elem
109     if (num_elem != data->points.num_elem) {
110       CeedInt            *points_per_elem;
111       const CeedInt       num_bytes   = num_elem * sizeof(CeedInt);
112       CeedElemRestriction rstr_points = NULL;
113 
114       data->points.num_elem = num_elem;
115       CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
116       CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
117       for (CeedInt e = 0; e < num_elem; e++) {
118         CeedInt num_points_elem;
119 
120         CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
121         points_per_elem[e] = num_points_elem;
122       }
123       if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem));
124       CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
125       CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
126       CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
127       CeedCallBackend(CeedFree(&points_per_elem));
128     }
129   }
130 
131   // Get context data
132   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
133 
134   // Apply operator
135   void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points};
136 
137   CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor));
138   CeedInt block_sizes[3] = {data->thread_1d, ((!is_tensor || data->dim == 1) ? 1 : data->thread_1d), -1};
139 
140   if (is_tensor) {
141     CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
142     if (is_at_points) block_sizes[2] = 1;
143   } else {
144     CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64;
145 
146     elems_per_block = elems_per_block > 0 ? elems_per_block : 1;
147     block_sizes[2]  = elems_per_block;
148   }
149   if (data->dim == 1 || !is_tensor) {
150     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
151     CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar);
152 
153     CeedCallBackend(
154         CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs));
155   } else if (data->dim == 2) {
156     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
157     CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
158 
159     CeedCallBackend(
160         CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs));
161   } else if (data->dim == 3) {
162     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
163     CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar);
164 
165     CeedCallBackend(
166         CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs));
167   }
168 
169   // Restore input arrays
170   for (CeedInt i = 0; i < num_input_fields; i++) {
171     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
172     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
173     } else {
174       bool       is_active;
175       CeedVector vec;
176 
177       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
178       is_active = vec == CEED_VECTOR_ACTIVE;
179       if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
180       CeedCallBackend(CeedVectorDestroy(&vec));
181     }
182   }
183 
184   // Restore output arrays
185   for (CeedInt i = 0; i < num_output_fields; i++) {
186     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
187     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
188     } else {
189       bool       is_active;
190       CeedVector vec;
191 
192       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
193       is_active = vec == CEED_VECTOR_ACTIVE;
194       if (!is_active) CeedCallBackend(CeedVectorRestoreArray(vec, &data->fields.outputs[i]));
195       CeedCallBackend(CeedVectorDestroy(&vec));
196     }
197   }
198 
199   // Restore point coordinates, if needed
200   if (is_at_points) {
201     CeedVector vec;
202 
203     CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
204     CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
205     CeedCallBackend(CeedVectorDestroy(&vec));
206   }
207 
208   // Restore context data
209   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
210 
211   // Cleanup
212   CeedCallBackend(CeedDestroy(&ceed));
213   CeedCallBackend(CeedQFunctionDestroy(&qf));
214   if (!(*is_run_good)) data->use_fallback = true;
215   return CEED_ERROR_SUCCESS;
216 }
217 
218 static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
219   bool              is_run_good = false;
220   const CeedScalar *input_arr   = NULL;
221   CeedScalar       *output_arr  = NULL;
222 
223   // Try to run kernel
224   if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr));
225   if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr));
226   CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(op, NULL, input_arr, output_arr, &is_run_good, request));
227   if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr));
228   if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr));
229 
230   // Fallback on unsuccessful run
231   if (!is_run_good) {
232     CeedOperator op_fallback;
233 
234     CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator");
235     CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
236     CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
237   }
238   return CEED_ERROR_SUCCESS;
239 }
240 
241 static int CeedOperatorApplyAddComposite_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
242   bool              is_run_good[CEED_COMPOSITE_MAX] = {false};
243   CeedInt           num_suboperators;
244   const CeedScalar *input_arr  = NULL;
245   CeedScalar       *output_arr = NULL;
246   Ceed              ceed;
247   CeedOperator     *sub_operators;
248 
249   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
250   CeedCall(CeedCompositeOperatorGetNumSub(op, &num_suboperators));
251   CeedCall(CeedCompositeOperatorGetSubList(op, &sub_operators));
252   if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr));
253   if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr));
254   for (CeedInt i = 0; i < num_suboperators; i++) {
255     CeedInt num_elem = 0;
256 
257     CeedCall(CeedOperatorGetNumElements(sub_operators[i], &num_elem));
258     if (num_elem > 0) {
259       hipStream_t stream = NULL;
260 
261       CeedCallHip(ceed, hipStreamCreate(&stream));
262       CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], stream, input_arr, output_arr, &is_run_good[i], request));
263       CeedCallHip(ceed, hipStreamDestroy(stream));
264     }
265   }
266   if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr));
267   if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr));
268   CeedCallHip(ceed, hipDeviceSynchronize());
269 
270   // Fallback on unsuccessful run
271   for (CeedInt i = 0; i < num_suboperators; i++) {
272     if (!is_run_good[i]) {
273       CeedOperator op_fallback;
274 
275       CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator");
276       CeedCallBackend(CeedOperatorGetFallback(sub_operators[i], &op_fallback));
277       CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
278     }
279   }
280   CeedCallBackend(CeedDestroy(&ceed));
281   return CEED_ERROR_SUCCESS;
282 }
283 
284 //------------------------------------------------------------------------------
285 // Create operator
286 //------------------------------------------------------------------------------
287 int CeedOperatorCreate_Hip_gen(CeedOperator op) {
288   bool                  is_composite;
289   Ceed                  ceed;
290   CeedOperator_Hip_gen *impl;
291 
292   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
293   CeedCallBackend(CeedCalloc(1, &impl));
294   CeedCallBackend(CeedOperatorSetData(op, impl));
295   CeedCall(CeedOperatorIsComposite(op, &is_composite));
296   if (is_composite) {
297     CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAddComposite", CeedOperatorApplyAddComposite_Hip_gen));
298   } else {
299     CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen));
300   }
301   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen));
302   CeedCallBackend(CeedDestroy(&ceed));
303   return CEED_ERROR_SUCCESS;
304 }
305 
306 //------------------------------------------------------------------------------
307