1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <stdbool.h>
11 #include <string.h>
12 #include <hip/hip_runtime.h>
13
14 #include "../hip/ceed-hip-common.h"
15 #include "ceed-hip-ref.h"
16
17 //------------------------------------------------------------------------------
18 // Sync host to device
19 //------------------------------------------------------------------------------
CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx)20 static inline int CeedQFunctionContextSyncH2D_Hip(const CeedQFunctionContext ctx) {
21 Ceed ceed;
22 size_t ctx_size;
23 CeedQFunctionContext_Hip *impl;
24
25 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
26 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
27
28 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29
30 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31 if (impl->d_data_borrowed) {
32 impl->d_data = impl->d_data_borrowed;
33 } else if (impl->d_data_owned) {
34 impl->d_data = impl->d_data_owned;
35 } else {
36 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size));
37 impl->d_data = impl->d_data_owned;
38 }
39 CeedCallHip(ceed, hipMemcpy(impl->d_data, impl->h_data, ctx_size, hipMemcpyHostToDevice));
40 CeedCallBackend(CeedDestroy(&ceed));
41 return CEED_ERROR_SUCCESS;
42 }
43
44 //------------------------------------------------------------------------------
45 // Sync device to host
46 //------------------------------------------------------------------------------
CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx)47 static inline int CeedQFunctionContextSyncD2H_Hip(const CeedQFunctionContext ctx) {
48 Ceed ceed;
49 size_t ctx_size;
50 CeedQFunctionContext_Hip *impl;
51
52 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
53 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
54
55 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
56
57 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
58 if (impl->h_data_borrowed) {
59 impl->h_data = impl->h_data_borrowed;
60 } else if (impl->h_data_owned) {
61 impl->h_data = impl->h_data_owned;
62 } else {
63 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
64 impl->h_data = impl->h_data_owned;
65 }
66 CeedCallHip(ceed, hipMemcpy(impl->h_data, impl->d_data, ctx_size, hipMemcpyDeviceToHost));
67 CeedCallBackend(CeedDestroy(&ceed));
68 return CEED_ERROR_SUCCESS;
69 }
70
71 //------------------------------------------------------------------------------
// Sync data to the requested memory type
73 //------------------------------------------------------------------------------
static inline int CeedQFunctionContextSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type) {
  // Dispatch on the destination memory space: host data is refreshed from the device and vice versa.
  // Any other mem_type value is reported as unsupported.
  if (mem_type == CEED_MEM_HOST) return CeedQFunctionContextSyncD2H_Hip(ctx);
  if (mem_type == CEED_MEM_DEVICE) return CeedQFunctionContextSyncH2D_Hip(ctx);
  return CEED_ERROR_UNSUPPORTED;
}
83
84 //------------------------------------------------------------------------------
85 // Set all pointers as invalid
86 //------------------------------------------------------------------------------
CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx)87 static inline int CeedQFunctionContextSetAllInvalid_Hip(const CeedQFunctionContext ctx) {
88 CeedQFunctionContext_Hip *impl;
89
90 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
91 impl->h_data = NULL;
92 impl->d_data = NULL;
93 return CEED_ERROR_SUCCESS;
94 }
95
96 //------------------------------------------------------------------------------
97 // Check for valid data
98 //------------------------------------------------------------------------------
CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx,bool * has_valid_data)99 static inline int CeedQFunctionContextHasValidData_Hip(const CeedQFunctionContext ctx, bool *has_valid_data) {
100 CeedQFunctionContext_Hip *impl;
101
102 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
103 *has_valid_data = impl && (impl->h_data || impl->d_data);
104 return CEED_ERROR_SUCCESS;
105 }
106
107 //------------------------------------------------------------------------------
// Check if ctx has borrowed data of the given memory type
109 //------------------------------------------------------------------------------
static inline int CeedQFunctionContextHasBorrowedDataOfType_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type,
                                                                 bool *has_borrowed_data_of_type) {
  // Report whether a caller-provided (CEED_USE_POINTER) buffer is attached for mem_type.
  CeedQFunctionContext_Hip *impl;

  CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
  // Default so the out-parameter is never left unset for an unexpected mem_type value
  *has_borrowed_data_of_type = false;
  switch (mem_type) {
    case CEED_MEM_HOST:
      *has_borrowed_data_of_type = impl->h_data_borrowed != NULL;
      break;
    case CEED_MEM_DEVICE:
      *has_borrowed_data_of_type = impl->d_data_borrowed != NULL;
      break;
  }
  return CEED_ERROR_SUCCESS;
}
125
126 //------------------------------------------------------------------------------
127 // Check if data of given type needs sync
128 //------------------------------------------------------------------------------
static inline int CeedQFunctionContextNeedSync_Hip(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
  // A sync is needed when the context holds valid data somewhere but the copy for mem_type is not current.
  bool                      has_valid_data = true;
  CeedQFunctionContext_Hip *impl;

  CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
  CeedCallBackend(CeedQFunctionContextHasValidData_Hip(ctx, &has_valid_data));
  // Default so the out-parameter is never left unset for an unexpected mem_type value
  *need_sync = false;
  switch (mem_type) {
    case CEED_MEM_HOST:
      *need_sync = has_valid_data && !impl->h_data;
      break;
    case CEED_MEM_DEVICE:
      *need_sync = has_valid_data && !impl->d_data;
      break;
  }
  return CEED_ERROR_SUCCESS;
}
145
146 //------------------------------------------------------------------------------
147 // Set data from host
148 //------------------------------------------------------------------------------
CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx,const CeedCopyMode copy_mode,void * data)149 static int CeedQFunctionContextSetDataHost_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
150 CeedQFunctionContext_Hip *impl;
151
152 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
153 CeedCallBackend(CeedFree(&impl->h_data_owned));
154 switch (copy_mode) {
155 case CEED_COPY_VALUES: {
156 size_t ctx_size;
157
158 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
159 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
160 impl->h_data_borrowed = NULL;
161 impl->h_data = impl->h_data_owned;
162 memcpy(impl->h_data, data, ctx_size);
163 } break;
164 case CEED_OWN_POINTER:
165 impl->h_data_owned = data;
166 impl->h_data_borrowed = NULL;
167 impl->h_data = data;
168 break;
169 case CEED_USE_POINTER:
170 impl->h_data_borrowed = data;
171 impl->h_data = data;
172 break;
173 }
174 return CEED_ERROR_SUCCESS;
175 }
176
177 //------------------------------------------------------------------------------
178 // Set data from device
179 //------------------------------------------------------------------------------
CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx,const CeedCopyMode copy_mode,void * data)180 static int CeedQFunctionContextSetDataDevice_Hip(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
181 Ceed ceed;
182 CeedQFunctionContext_Hip *impl;
183
184 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
185 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
186
187 CeedCallHip(ceed, hipFree(impl->d_data_owned));
188 impl->d_data_owned = NULL;
189 switch (copy_mode) {
190 case CEED_COPY_VALUES: {
191 size_t ctx_size;
192 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
193 CeedCallHip(ceed, hipMalloc((void **)&impl->d_data_owned, ctx_size));
194 impl->d_data_borrowed = NULL;
195 impl->d_data = impl->d_data_owned;
196 CeedCallHip(ceed, hipMemcpy(impl->d_data, data, ctx_size, hipMemcpyDeviceToDevice));
197 } break;
198 case CEED_OWN_POINTER:
199 impl->d_data_owned = data;
200 impl->d_data_borrowed = NULL;
201 impl->d_data = data;
202 break;
203 case CEED_USE_POINTER:
204 impl->d_data_owned = NULL;
205 impl->d_data_borrowed = data;
206 impl->d_data = data;
207 break;
208 }
209 CeedCallBackend(CeedDestroy(&ceed));
210 return CEED_ERROR_SUCCESS;
211 }
212
213 //------------------------------------------------------------------------------
214 // Set the data used by a user context,
215 // freeing any previously allocated data if applicable
216 //------------------------------------------------------------------------------
CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx,const CeedMemType mem_type,const CeedCopyMode copy_mode,void * data)217 static int CeedQFunctionContextSetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
218 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
219 switch (mem_type) {
220 case CEED_MEM_HOST:
221 return CeedQFunctionContextSetDataHost_Hip(ctx, copy_mode, data);
222 case CEED_MEM_DEVICE:
223 return CeedQFunctionContextSetDataDevice_Hip(ctx, copy_mode, data);
224 }
225 return CEED_ERROR_UNSUPPORTED;
226 }
227
228 //------------------------------------------------------------------------------
229 // Take data
230 //------------------------------------------------------------------------------
CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)231 static int CeedQFunctionContextTakeData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
232 bool need_sync = false;
233 CeedQFunctionContext_Hip *impl;
234
235 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
236
237 // Sync data to requested mem_type
238 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
239 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
240
241 // Update pointer
242 switch (mem_type) {
243 case CEED_MEM_HOST:
244 *(void **)data = impl->h_data_borrowed;
245 impl->h_data_borrowed = NULL;
246 impl->h_data = NULL;
247 break;
248 case CEED_MEM_DEVICE:
249 *(void **)data = impl->d_data_borrowed;
250 impl->d_data_borrowed = NULL;
251 impl->d_data = NULL;
252 break;
253 }
254 return CEED_ERROR_SUCCESS;
255 }
256
257 //------------------------------------------------------------------------------
258 // Core logic for GetData.
259 // If a different memory type is most up to date, this will perform a copy
260 //------------------------------------------------------------------------------
CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)261 static int CeedQFunctionContextGetDataCore_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
262 bool need_sync = false;
263 CeedQFunctionContext_Hip *impl;
264
265 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
266
267 // Sync data to requested mem_type
268 CeedCallBackend(CeedQFunctionContextNeedSync_Hip(ctx, mem_type, &need_sync));
269 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Hip(ctx, mem_type));
270
271 // Update pointer
272 switch (mem_type) {
273 case CEED_MEM_HOST:
274 *(void **)data = impl->h_data;
275 break;
276 case CEED_MEM_DEVICE:
277 *(void **)data = impl->d_data;
278 break;
279 }
280 return CEED_ERROR_SUCCESS;
281 }
282
283 //------------------------------------------------------------------------------
284 // Get read-only access to the data
285 //------------------------------------------------------------------------------
static int CeedQFunctionContextGetDataRead_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
  // Unlike the read/write accessor, no copy is marked invalid here; the shared core path is all that is needed
  const int ierr = CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data);

  return ierr;
}
289
290 //------------------------------------------------------------------------------
291 // Get read/write access to the data
292 //------------------------------------------------------------------------------
CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)293 static int CeedQFunctionContextGetData_Hip(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
294 CeedQFunctionContext_Hip *impl;
295
296 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
297 CeedCallBackend(CeedQFunctionContextGetDataCore_Hip(ctx, mem_type, data));
298
299 // Mark only pointer for requested memory as valid
300 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Hip(ctx));
301 switch (mem_type) {
302 case CEED_MEM_HOST:
303 impl->h_data = *(void **)data;
304 break;
305 case CEED_MEM_DEVICE:
306 impl->d_data = *(void **)data;
307 break;
308 }
309 return CEED_ERROR_SUCCESS;
310 }
311
312 //------------------------------------------------------------------------------
313 // Destroy the user context
314 //------------------------------------------------------------------------------
CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx)315 static int CeedQFunctionContextDestroy_Hip(const CeedQFunctionContext ctx) {
316 CeedQFunctionContext_Hip *impl;
317
318 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
319 CeedCallHip(CeedQFunctionContextReturnCeed(ctx), hipFree(impl->d_data_owned));
320 CeedCallBackend(CeedFree(&impl->h_data_owned));
321 CeedCallBackend(CeedFree(&impl));
322 return CEED_ERROR_SUCCESS;
323 }
324
325 //------------------------------------------------------------------------------
326 // QFunctionContext Create
327 //------------------------------------------------------------------------------
int CeedQFunctionContextCreate_Hip(CeedQFunctionContext ctx) {
  // Register the HIP implementations of the QFunctionContext backend interface and attach
  // a zero-initialized backend data struct to ctx.
  // The Ceed handle comes from CeedQFunctionContextReturnCeed (non-owning, no matching CeedDestroy),
  // so a failing registration below no longer leaks a Ceed reference.
  Ceed                      ceed = CeedQFunctionContextReturnCeed(ctx);
  CeedQFunctionContext_Hip *impl;

  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Hip));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Hip));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Hip));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Hip));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Hip));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Hip));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Hip));
  // Calloc leaves all owned/borrowed/current pointers NULL, i.e. no valid data yet
  CeedCallBackend(CeedCalloc(1, &impl));
  CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
  return CEED_ERROR_SUCCESS;
}
345
346 //------------------------------------------------------------------------------
347